Skip to content

Commit

Permalink
Use a Map from Snap to Snap to represent a magic wand snapshot. (#836)
Browse files Browse the repository at this point in the history
* Use a Map from Snap to Snap to represent a magic wand snapshot.

* Fix failing test cases.

* Fix broken links in comments.

* First optimization: Create sort MagicWandSnapFunction (MWSF) with its own function definitions to replace wand maps.

* Second optimization: When there is no applying expression use the original approach using MagicWandSnapSingleton.

* Revert "Second optimization: When there is no applying expression use the original approach using MagicWandSnapSingleton."

This reverts commit fd87482.

* Apply suggested changes from code review.

* Fix test cases with quasihavoc statements.

* Update submodule to use branch with both testcase changes

* Remove abstractLhs and rhsSnapshot from MagicWandSnapshot.

* Reduce diff

* Simplify havoc

* Simplify production of a MWSF.

* Rename variable in Producer.

* Update silver branch magic-wand-fixes.

---------

Co-authored-by: Jonáš Fiala <jonas.fiala@inf.ethz.ch>
  • Loading branch information
manud99 and JonasAlaif authored Jun 12, 2024
1 parent 4a0c07e commit f8cc484
Show file tree
Hide file tree
Showing 17 changed files with 442 additions and 242 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
(declare-fun MWSF_apply ($MWSF $Snap) $Snap)
1 change: 1 addition & 0 deletions src/main/scala/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ package object utils {
case class ViperEmbedding(embeddedSort: Sort) extends silver.ast.ExtensionType {
def substitute(typVarsMap: Predef.Map[silver.ast.TypeVar, silver.ast.Type]): silver.ast.Type = this
def isConcrete: Boolean = true
override def toString: String = s"ViperEmbedding(sorts.$embeddedSort)"
}
}

Expand Down
7 changes: 6 additions & 1 deletion src/main/scala/decider/TermToSMTLib2Converter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ class TermToSMTLib2Converter

case sorts.FieldPermFunction() => text("$FPM")
case sorts.PredicatePermFunction() => text("$PPM")

case sorts.MagicWandSnapFunction => text("$MWSF")
}

def convert(d: Decl): String = {
Expand Down Expand Up @@ -263,7 +265,7 @@ class TermToSMTLib2Converter

case Lookup(field, fvf, at) => //fvf.sort match {
// case _: sorts.PartialFieldValueFunction =>
parens(text("$FVF.lookup_") <> field <+> render(fvf) <+> render(at))
parens(text("$FVF.lookup_") <> field <+> render(fvf) <+> render(at))
// case _: sorts.TotalFieldValueFunction =>
// render(Apply(fvf, Seq(at)))
// parens("$FVF.lookup_" <> field <+> render(fvf) <+> render(at))
Expand Down Expand Up @@ -313,6 +315,9 @@ class TermToSMTLib2Converter
val docBindings = ssep((bindings.toSeq map (p => parens(render(p._1) <+> render(p._2)))).to(collection.immutable.Seq), space)
parens(text("let") <+> parens(docBindings) <+> render(body))

case MagicWandSnapshot(mwsf) => render(mwsf)
case MWSFLookup(mwsf, snap) => renderApp("MWSF_apply", Seq(mwsf, snap), sorts.Snap)

case _: MagicWandChunkTerm
| _: Quantification =>
sys.error(s"Unexpected term $term cannot be translated to SMTLib code")
Expand Down
2 changes: 2 additions & 0 deletions src/main/scala/decider/TermToZ3APIConverter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,8 @@ class TermToZ3APIConverter
case Let(bindings, body) =>
convert(body.replace(bindings))

case MWSFLookup(mwsf, snap) => createApp("MWSF_apply", Seq(mwsf, snap), sorts.Snap)

case _: MagicWandChunkTerm
| _: Quantification =>
sys.error(s"Unexpected term $term cannot be translated to SMTLib code")
Expand Down
15 changes: 11 additions & 4 deletions src/main/scala/rules/HavocSupporter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,17 @@ object havocSupporter extends SymbolicExecutionRules {
val id = ChunkIdentifier(resource, s.program)
val (relevantChunks, otherChunks) = chunkSupporter.splitHeap[NonQuantifiedChunk](s.h, id)

val newChunks = relevantChunks.map { ch =>
val havockedSnap = freshSnap(ch.snap.sort, v)
val cond = replacementCond(lhs, ch.args, condInfo)
ch.withSnap(Ite(cond, havockedSnap, ch.snap))
val newChunks = relevantChunks.map {
case ch: MagicWandChunk =>
val havockedSnap = v.decider.fresh("mwsf", sorts.MagicWandSnapFunction)
val cond = replacementCond(lhs, ch.args, condInfo)
val magicWandSnapshot = MagicWandSnapshot(Ite(cond, havockedSnap, ch.snap.mwsf))
ch.withSnap(magicWandSnapshot)

case ch =>
val havockedSnap = freshSnap(ch.snap.sort, v)
val cond = replacementCond(lhs, ch.args, condInfo)
ch.withSnap(Ite(cond, havockedSnap, ch.snap))
}
otherChunks ++ newChunks
}
Expand Down
333 changes: 201 additions & 132 deletions src/main/scala/rules/MagicWandSupporter.scala

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/main/scala/rules/MoreCompleteExhaleSupporter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ object moreCompleteExhaleSupporter extends SymbolicExecutionRules {
resource match {
case f: ast.Field => v.symbolConverter.toSort(f.typ)
case _: ast.Predicate => sorts.Snap
case _: ast.MagicWand => sorts.Snap
case _: ast.MagicWand => sorts.MagicWandSnapFunction
}

val `?s` = Var(Identifier("?s"), sort, false)
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/rules/Producer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ object producer extends ProductionRules {
Q(s2, v1)})

case wand: ast.MagicWand =>
val snap = sf(sorts.Snap, v)
val snap = sf(sorts.MagicWandSnapFunction, v)
magicWandSupporter.createChunk(s, wand, MagicWandSnapshot(snap), pve, v)((s1, chWand, v1) =>
chunkSupporter.produce(s1, s1.h, chWand, v1)((s2, h2, v2) =>
Q(s2.copy(h = h2), v2)))
Expand Down
1 change: 0 additions & 1 deletion src/main/scala/rules/QuantifiedChunkSupport.scala
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,6 @@ object quantifiedChunkSupporter extends QuantifiedChunkSupport {

/* Snapshots */

/** @inheritdoc */
def singletonSnapshotMap(s: State,
resource: ast.Resource,
arguments: Seq[Term],
Expand Down
7 changes: 5 additions & 2 deletions src/main/scala/state/Chunks.scala
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,16 @@ case class MagicWandChunk(id: MagicWandIdentifier,
override val resourceID = MagicWandID

override def withPerm(newPerm: Term) = MagicWandChunk(id, bindings, args, snap, newPerm)
override def withSnap(newSnap: Term) = MagicWandChunk(id, bindings, args, MagicWandSnapshot(newSnap), perm)
override def withSnap(newSnap: Term) = newSnap match {
case s: MagicWandSnapshot => MagicWandChunk(id, bindings, args, s, perm)
case _ => sys.error(s"MagicWand snapshot has to be of type MagicWandSnapshot but found ${newSnap.getClass}")
}

override lazy val toString = {
val pos = id.ghostFreeWand.pos match {
case rp: viper.silver.ast.HasLineColumn => s"${rp.line}:${rp.column}"
case other => other.toString
}
s"wand@$pos[$snap; ${args.mkString(",")}]"
s"wand@$pos[$snap; ${args.mkString(", ")}]"
}
}
108 changes: 64 additions & 44 deletions src/main/scala/state/Terms.scala
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ object sorts {
override lazy val toString = id.toString
}

object MagicWandSnapFunction extends Sort {
val id: Identifier = Identifier("MWSF")
override lazy val toString: String = id.toString
}

case class FieldPermFunction() extends Sort {
val id = Identifier("FPM")
override lazy val toString = id.toString
Expand Down Expand Up @@ -321,39 +326,39 @@ sealed trait Term extends Node {

lazy val subterms: Seq[Term] = state.utils.subterms(this)

/** @see [[ast.utility.Visitor.visit()]] */
/** @see [[ast.utility.Visitor.visit]] */
def visit(f: PartialFunction[Term, Any]): Unit =
ast.utility.Visitor.visit(this, state.utils.subterms)(f)

/** @see [[ast.utility.Visitor.visitOpt()]] */
/** @see [[ast.utility.Visitor.visitOpt]] */
def visitOpt(f: Term => Boolean): Unit =
ast.utility.Visitor.visitOpt(this, state.utils.subterms)(f)

/** @see [[ast.utility.Visitor.reduceTree()]] */
/** @see [[ast.utility.Visitor.reduceTree]] */
def reduceTree[R](f: (Term, Seq[R]) => R): R =
ast.utility.Visitor.reduceTree(this, state.utils.subterms)(f)

/** @see [[ast.utility.Visitor.existsDefined()]] */
/** @see [[ast.utility.Visitor.existsDefined]] */
def existsDefined(f: PartialFunction[Term, Any]): Boolean =
ast.utility.Visitor.existsDefined(this, state.utils.subterms)(f)

/** @see [[ast.utility.Visitor.hasSubnode()]] */
/** @see [[ast.utility.Visitor.hasSubnode]] */
def hasSubterm(subterm: Term): Boolean =
ast.utility.Visitor.hasSubnode(this, subterm, state.utils.subterms)

/** @see [[ast.utility.Visitor.deepCollect()]] */
/** @see [[ast.utility.Visitor.deepCollect]] */
def deepCollect[R](f: PartialFunction[Term, R]) : Seq[R] =
ast.utility.Visitor.deepCollect(Seq(this), state.utils.subterms)(f)

/** @see [[ast.utility.Visitor.shallowCollect()]] */
/** @see [[ast.utility.Visitor.shallowCollect]] */
def shallowCollect[R](f: PartialFunction[Term, R]): Seq[R] =
ast.utility.Visitor.shallowCollect(Seq(this), state.utils.subterms)(f)

/** @see [[ast.utility.Visitor.find()]] */
/** @see [[ast.utility.Visitor.find]] */
def find[R](f: PartialFunction[Term, R]): Option[R] =
ast.utility.Visitor.find(this, state.utils.subterms)(f)

/** @see [[state.utils.transform()]] */
/** @see [[state.utils.transform]] */
def transform(pre: PartialFunction[Term, Term] = PartialFunction.empty)
(recursive: Term => Boolean = !pre.isDefinedAt(_),
post: PartialFunction[Term, Term] = PartialFunction.empty)
Expand Down Expand Up @@ -2295,48 +2300,63 @@ object PredicateTrigger extends PreciseCondFlyweightFactory[(String, Term, Seq[T

/* Magic wands */

class MagicWandSnapshot(val abstractLhs: Term, val rhsSnapshot: Term) extends Combine(abstractLhs, rhsSnapshot) {
utils.assertSort(abstractLhs, "abstract lhs", sorts.Snap)
utils.assertSort(rhsSnapshot, "rhs", sorts.Snap)
/**
* Represents a snapshot of a magic wand, which is a function from `Snap` to `Snap`.
*
* @param mwsf The function that represents the snapshot of the magic wand. It is a variable of sort [[sorts.MagicWandSnapFunction]].
* In the symbolic execution when we apply a magic wand, it consumes the left-hand side
* and uses this function and the resulting snapshot to look up which right-hand side to produce.
*/
class MagicWandSnapshot(val mwsf: Term) extends Term with ConditionalFlyweight[Term, MagicWandSnapshot] {
utils.assertSort(mwsf, "magic wand snap function", sorts.MagicWandSnapFunction)

override lazy val toString = s"wandSnap(lhs = $abstractLhs, rhs = $rhsSnapshot)"
override val sort: Sort = sorts.MagicWandSnapFunction

def merge(other: MagicWandSnapshot, branchConditions: Stack[Term]): MagicWandSnapshot = {
assert(this.abstractLhs == other.abstractLhs)
val condition = And(branchConditions)
MagicWandSnapshot(this.abstractLhs, if (this.rhsSnapshot == other.rhsSnapshot)
this.rhsSnapshot
else
Ite(condition, other.rhsSnapshot, this.rhsSnapshot))
}
override lazy val toString = s"wandSnap($mwsf)"

override val equalityDefiningMembers: Term = mwsf

/**
* Apply the given snapshot of the left-hand side to the magic wand map to get the snapshot of the right-hand side
* which includes the values of the left-hand side.
*
* @param snapLhs The snapshot of the left-hand side that should be applied to the magic wand map.
* @return The snapshot of the right-hand side that preserves the values of the left-hand side.
*/
def applyToMWSF(snapLhs: Term): Term = MWSFLookup(mwsf, snapLhs)
}

object MagicWandSnapshot {
def apply(snapshot: Term): MagicWandSnapshot = {
assert(snapshot.sort == sorts.Snap)
snapshot match {
case snap: MagicWandSnapshot => snap
case _ =>
MagicWandSnapshot(First(snapshot), Second(snapshot))
}
}
object MagicWandSnapshot extends PreciseCondFlyweightFactory[Term, MagicWandSnapshot] {
/** Create an instance of [[viper.silicon.state.terms.MagicWandSnapshot]]. */
override def actualCreate(arg: Term): MagicWandSnapshot =
new MagicWandSnapshot(arg)
}

// Since MagicWandSnapshot subclasses Combine, we apparently cannot inherit the normal subclass, so we
// have to copy paste the code here.
var pool = new TrieMap[(Term, Term), MagicWandSnapshot]()
/**
* Term that applies a [[sorts.MagicWandSnapFunction]] to a snapshot.
* It returns a snapshot for the RHS of a magic wand that includes that values of the given snapshot.
*
* @param mwsf Term of sort [[sorts.MagicWandSnapFunction]]. Function from `Snap` to `Snap`.
* @param snap Term of sort [[sorts.Snap]] to which the MWSF is applied to. It represents the values of the wand's LHS.
*/
class MWSFLookup(val mwsf: Term, val snap: Term) extends Term with ConditionalFlyweightBinaryOp[MWSFLookup] {
val sort: Sort = sorts.Snap
override def p0: Term = mwsf
override def p1: Term = snap
override lazy val toString = s"$mwsf[$snap]"
}

def createIfNonExistent(args: (Term, Term)): MagicWandSnapshot = {
if (Verifier.config.useFlyweight) {
pool.getOrElseUpdate(args, actualCreate(args))
} else {
actualCreate(args)
}
object MWSFLookup extends PreciseCondFlyweightFactory[(Term, Term), MWSFLookup] {
override def apply(pair: (Term, Term)): MWSFLookup = {
val (mwsf, snap) = pair
utils.assertSort(mwsf, "mwsf", sorts.MagicWandSnapFunction)
utils.assertSort(snap, "snap", sorts.Snap)
createIfNonExistent(pair)
}

def actualCreate(tuple: (Term, Term)) = new MagicWandSnapshot(tuple._1, tuple._2)
def apply(fst: Term, snd: Term): MagicWandSnapshot = createIfNonExistent((fst, snd))

def unapply(mws: MagicWandSnapshot) = Some((mws.abstractLhs, mws.rhsSnapshot))
/** Create an instance of [[viper.silicon.state.terms.MWSFLookup]]. */
override def actualCreate(args: (Term, Term)): MWSFLookup =
new MWSFLookup(args._1, args._2)
}

class MagicWandChunkTerm(val chunk: MagicWandChunk) extends Term with ConditionalFlyweight[MagicWandChunk, MagicWandChunkTerm] {
Expand Down Expand Up @@ -2461,7 +2481,7 @@ object PsfTop extends (String => Identifier) {
*/

/* Note: Sort wrappers should probably not be used as (outermost) triggers
* because they are optimised away if wrappee `t` already has sort `to`.
* because they are optimised away if wrapped `t` already has sort `to`.
*/
class SortWrapper(val t: Term, val to: Sort)
extends Term
Expand Down
5 changes: 3 additions & 2 deletions src/main/scala/state/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ package object utils {

}

/** @see [[viper.silver.ast.utility.Transformer.simplify()]] */
/** @see [[viper.silver.ast.utility.Simplifier.simplify]] */
def transform[T <: Term](term: T,
pre: PartialFunction[Term, Term] = PartialFunction.empty)
(recursive: Term => Boolean = !pre.isDefinedAt(_),
Expand Down Expand Up @@ -196,7 +196,8 @@ package object utils {
case MapUpdate(t0, t1, t2) => MapUpdate(go(t0), go(t1), go(t2))
case MapDomain(t) => MapDomain(go(t))
case MapRange(t) => MapRange(go(t))
case MagicWandSnapshot(lhs, rhs) => MagicWandSnapshot(go(lhs), go(rhs))
case MagicWandSnapshot(t) => MagicWandSnapshot(go(t))
case MWSFLookup(t0, t1) => MWSFLookup(go(t0), go(t1))
case Combine(t0, t1) => Combine(go(t0), go(t1))
case First(t) => First(go(t))
case Second(t) => Second(go(t))
Expand Down
Loading

0 comments on commit f8cc484

Please sign in to comment.