From 3cb853704677212369073dd3e248acb59a9de921 Mon Sep 17 00:00:00 2001 From: Alexander Stekelenburg Date: Fri, 29 Nov 2024 16:07:07 +0100 Subject: [PATCH] Fix heap variables with pointer types in C, fixes #1286 --- .../vct/rewrite/EncodeByValueClassUsage.scala | 8 ++-- .../vct/rewrite/VariableToPointer.scala | 38 +++++++++++++++---- 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/src/rewrite/vct/rewrite/EncodeByValueClassUsage.scala b/src/rewrite/vct/rewrite/EncodeByValueClassUsage.scala index 0a8b9c354..eba9f9c17 100644 --- a/src/rewrite/vct/rewrite/EncodeByValueClassUsage.scala +++ b/src/rewrite/vct/rewrite/EncodeByValueClassUsage.scala @@ -212,11 +212,9 @@ case class EncodeByValueClassUsage[Pre <: Generation]() extends Rewriter[Pre] { case _ if inAssignment.nonEmpty => node.rewriteDefault() case Perm(ByValueClassLocation(e), p) => unwrapClassPerm(dispatch(e), dispatch(p), e.t.asByValueClass.get) - case Perm(pl @ PointerLocation(dhv @ DerefHeapVariable(Ref(v))), p) => - assert( - v.t.isInstanceOf[TNonNullPointer[Pre]], - "Frontends should ensure that HeapVariables are non-null pointers", - ) + // Only doing this for TNonNullPointer pointers since those originate from the frontend and users can define heap variables of the normal TPointer pointer type + case Perm(pl @ PointerLocation(dhv @ DerefHeapVariable(Ref(v))), p) + if v.t.isInstanceOf[TNonNullPointer[Pre]] => val t = v.t.asInstanceOf[TNonNullPointer[Pre]] if (t.element.asByValueClass.isDefined) { val newV: Ref[Post, HeapVariable[Post]] = succ(v) diff --git a/src/rewrite/vct/rewrite/VariableToPointer.scala b/src/rewrite/vct/rewrite/VariableToPointer.scala index 8ca7eb689..d76fb8502 100644 --- a/src/rewrite/vct/rewrite/VariableToPointer.scala +++ b/src/rewrite/vct/rewrite/VariableToPointer.scala @@ -48,7 +48,8 @@ case class VariableToPointer[Pre <: Generation]() extends Rewriter[Pre] { SuccessionMap() val fieldMap: SuccessionMap[InstanceField[Pre], InstanceField[Post]] = SuccessionMap() - val noTransform: ScopedStack[Unit] = ScopedStack() + val noTransform: ScopedStack[scala.collection.Set[Variable[Pre]]] = + ScopedStack() override def dispatch(program: Program[Pre]): Program[Rewritten[Pre]] = { // TODO: Replace the asByReferenceClass checks with something that more clearly communicates that we want to exclude all reference types @@ -72,6 +73,7 @@ case class VariableToPointer[Pre <: Generation]() extends Rewriter[Pre] { globalDeclarations.succeed(func, func.rewriteDefault()) } case proc: Procedure[Pre] => { + val skipVars = mutable.Set[Variable[Pre]]() val extraVars = mutable.ArrayBuffer[(Variable[Post], Variable[Post])]() // Relies on args being evaluated before body allScopes.anySucceed( @@ -84,6 +86,7 @@ case class VariableToPointer[Pre <: Generation]() extends Rewriter[Pre] { if (addressedSet.contains(v)) { variableMap(v) = new Variable(TNonNullPointer(dispatch(v.t)))(v.o) + skipVars += v extraVars += ((newV, variableMap(v))) } } @@ -114,7 +117,9 @@ case class VariableToPointer[Pre <: Generation]() extends Rewriter[Pre] { } } }, - contract = { noTransform.having(()) { dispatch(proc.contract) } }, + contract = { + noTransform.having(skipVars) { dispatch(proc.contract) } + }, ), ) } @@ -199,14 +204,13 @@ case class VariableToPointer[Pre <: Generation]() extends Rewriter[Pre] { override def dispatch(expr: Expr[Pre]): Expr[Post] = { implicit val o: Origin = expr.o - if (noTransform.nonEmpty) - return expr.rewriteDefault() expr match { case deref @ DerefHeapVariable(Ref(v)) if addressedSet.contains(v) => DerefPointer( DerefHeapVariable[Post](heapVariableMap.ref(v))(deref.blame) )(PanicBlame("Should always be accessible")) - case Local(Ref(v)) if addressedSet.contains(v) => + case Local(Ref(v)) + if addressedSet.contains(v) && !noTransform.exists(_.contains(v)) => DerefPointer(Local[Post](variableMap.ref(v)))(PanicBlame( "Should always be accessible" )) @@ -251,14 +255,34 @@ case class VariableToPointer[Pre <: Generation]() extends Rewriter[Pre] { obj.get, ), ) + case Perm(PointerLocation(AddrOf(DerefHeapVariable(Ref(v)))), perm) + if addressedSet.contains(v) => + val newPerm = dispatch(perm) + Star( + Perm(HeapVariableLocation[Post](heapVariableMap.ref(v)), newPerm), + Perm( + PointerLocation(DerefHeapVariable[Post](heapVariableMap.ref(v))( + PanicBlame("Access is framed") + ))(PanicBlame("Cannot be null")), + newPerm, + ), + ) + case Value(PointerLocation(AddrOf(DerefHeapVariable(Ref(v))))) + if addressedSet.contains(v) => + Star( + Value(HeapVariableLocation[Post](heapVariableMap.ref(v))), + Value( + PointerLocation(DerefHeapVariable[Post](heapVariableMap.ref(v))( + PanicBlame("Access is framed") + ))(PanicBlame("cannot be null")) + ), + ) case other => other.rewriteDefault() } } override def dispatch(loc: Location[Pre]): Location[Post] = { implicit val o: Origin = loc.o - if (noTransform.nonEmpty) - return loc.rewriteDefault() loc match { case HeapVariableLocation(Ref(v)) if addressedSet.contains(v) => PointerLocation(