Skip to content

Commit

Permalink
Making progress on endpoint projection
Browse files Browse the repository at this point in the history
  • Loading branch information
bobismijnnaam committed Dec 12, 2024
1 parent fde1216 commit 19bbf72
Show file tree
Hide file tree
Showing 6 changed files with 201 additions and 39 deletions.
31 changes: 28 additions & 3 deletions src/col/vct/col/origin/Blame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1415,9 +1415,30 @@ case class BlameUnreachable(message: String, failure: VerificationFailure)
.toString.split('\n').mkString(" > ", "\n > ", "")}"
}

case class PanicBlame(message: String) extends Blame[VerificationFailure] {
override def blame(error: VerificationFailure): Unit =
throw BlameUnreachable(message, error)
// Base PanicBlame object, not for general use. Instead, use
// PanicBlame.apply or one of the extended PanicBlame object below.
class PanicBlame(file: String, line: Int, message: String)
extends Blame[VerificationFailure] {
override def blame(error: VerificationFailure): Unit = {
if (file == null) { // Assuming file & line are either both null or both non-null
throw BlameUnreachable(message, error)
} else
throw BlameUnreachable(s"At $file:$line: $message", error)
}

def this(file: sourcecode.File, line: sourcecode.Line, message: String) =
this(file.value, line.value, message)

def this(message: String) = this(null, null, message)
}

// Plain PanicBlame for general use, if there is no pre-made PanicBlame
// or your blame is too specific.
object PanicBlame {
def apply(
message: String
)(implicit file: sourcecode.File, line: sourcecode.Line): PanicBlame =
new PanicBlame(file, line, message)
}

object NeverNone
Expand Down Expand Up @@ -1479,6 +1500,10 @@ object TriggerPatternBlame
)
object TrueSatisfiable
extends PanicBlame("`requires true` is always satisfiable.")
case class TrivialContract()(
implicit file: sourcecode.File,
line: sourcecode.Line,
) extends PanicBlame(file, line, "the trivial contract cannot fail")
object FramedPtrOffset
extends PanicBlame(
"pointer arithmetic in (0 <= \\pointer_block_offset(p)+i < \\pointer_block_length(p)) ? p+i : _ should always be ok."
Expand Down
30 changes: 25 additions & 5 deletions src/col/vct/col/util/AstBuildHelpers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -965,11 +965,31 @@ object AstBuildHelpers {
def foldAnd[G](exprs: Seq[Expr[G]])(implicit o: Origin): Expr[G] =
exprs.reduceOption(And(_, _)).getOrElse(tt)

def foldAny[G](t: Type[_])(exprs: Seq[Expr[G]])(implicit o: Origin): Expr[G] =
t match {
case TBool() => foldAnd(exprs)
case TResource() => foldStar(exprs)
case _ => ???
def foldAny[G](
t: Type[_]
)(exprs: Seq[Expr[G]])(implicit o: Origin): Option[Expr[G]] = {
exprs match {
case Seq() => None
case exprs =>
t match {
case TBool() => Some(foldAnd(exprs))
case TResource() => Some(foldStar(exprs))
case _ => ???
}
}
}

def foldAny1[G](t: Type[_])(exprs: Seq[Expr[G]])(
implicit o: Origin
): Expr[G] = foldAny(t)(exprs).get

// This is basically unfoldStar, except that it does not make the "true" case
// disappear into Nil.
def unfoldAny[G](expr: Expr[G]): Seq[Expr[G]] =
expr match {
case Star(left, right) => unfoldAny(left) ++ unfoldAny(right)
case And(left, right) => unfoldAny(left) ++ unfoldAny(right)
case other => Seq(other)
}

def loop[G](
Expand Down
11 changes: 2 additions & 9 deletions src/rewrite/vct/rewrite/veymont/StratifyExpressions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,8 @@ case class StratifyExpressions[Pre <: Generation]()
case statement => statement.rewriteDefault()
}

def dumbUnfoldStar[G](expr: Expr[G]): Seq[Expr[G]] =
expr match {
case Star(left, right) => dumbUnfoldStar(left) ++ dumbUnfoldStar(right)
case And(left, right) => dumbUnfoldStar(left) ++ dumbUnfoldStar(right)
case other => Seq(other)
}

def stratifyExpr(e: Expr[Pre]): Expr[Post] = {
val exprs = dumbUnfoldStar(e)
val exprs = unfoldAny(e)
foldAny(e.t)(
exprs.map {
case e: ChorExpr[Pre] => (None, e)
Expand All @@ -134,7 +127,7 @@ case class StratifyExpressions[Pre <: Generation]()
EndpointExpr[Post](dispatch(endpoint), dispatch(expr))(expr.o)
case (None, expr) => expr.rewriteDefault()
}
)(e.o)
)(e.o).get
}

// "Points" an expression in the direction of an endpoint if possible
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ case class StratifyUnpointedExpressions[Pre <: Generation]()

def stratifyExpr(expr: Expr[Pre]): Expr[Post] = {
implicit val o = expr.o
foldAny(expr.t)(unfoldStar(expr).flatMap {
foldAny1(expr.t)(unfoldAny(expr).flatMap {
case expr @ (_: EndpointExpr[Pre] | _: ChorExpr[Pre]) =>
Seq(expr.rewriteDefault())
case expr =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ case class EncodeChannels[Pre <: Generation]()
// Helper for rewriting the invariant. Regular expressions we wrap in the EndpointExpr of the sender/receiver
// ChorExpr's we leave untouched. Those will be encoded by the EncodeStratifiedPermissions pass.
def wrapEndpointExpr(expr: Expr[Pre], ep: Endpoint[Pre]): Expr[Post] =
foldAny(expr.t)(unfoldStar(expr).map {
foldAny1(expr.t)(unfoldStar(expr).map {
case e: ChorExpr[Pre] => dispatch(e)
case e => EndpointExpr(CommTargetEndpoint(succ(ep)), dispatch(e))
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,34 @@ package vct.rewrite.veymont.verification

import com.typesafe.scalalogging.LazyLogging
import hre.util.ScopedStack
import vct.col.ast._
import vct.col.ast.{
ChorStatement,
Declaration,
Expr,
LoopContract,
EndpointExpr,
CommTargetEndpoint,
CommTargetIndex,
Declaration,
Choreography,
Program,
Loop,
Branch,
Statement,
Block,
Assert,
LoopInvariant,
IterationContract,
CommunicateTarget,
RangeBinder,
CommTargetRange,
Select,
Variable,
TInt,
TBool,
}
import vct.col.origin._
import vct.col.ref.{Ref, DirectRef}
import vct.col.ref.{DirectRef, Ref}
import vct.col.rewrite.{Generation, Rewriter, RewriterBuilderArg}
import vct.col.util.AstBuildHelpers._
import vct.rewrite.veymont.VeymontContext
Expand Down Expand Up @@ -139,30 +164,129 @@ case class EncodeChorBranchUnanimity[Pre <: Generation](enabled: Boolean)
case _ => contract.rewriteDefault()
}

def min(a: Expr[Post], b: Expr[Post]): Expr[Post] = ???
def max(a: Expr[Post], b: Expr[Post]): Expr[Post] = ???
def unanimous(exprs: Expr[Pre]): Expr[Post] = {
val subExprs: Seq[EndpointExpr[Pre]] = unfoldAny(exprs).map {
case expr: EndpointExpr[Pre] => expr
case _ => ???
}
require(subExprs.nonEmpty)
val alphas = subExprs.map(_.endpoint)
// This is either an endpoint e or an indexed family F[i], meaning a singular endpoint
// All sub-conditions in the total condition will be checked to be equal to the condition of this endpoint
// For aesthetic purposes, we try to pick a singular endpoint from the list of alphas, but this is not strictly
// necessary. You could pick any of the endpoints participating in the condition.
val grounder = {
val singularCandidate = alphas.collectFirst {
case endpoint: CommTargetEndpoint[Pre] => endpoint
case index: CommTargetIndex[Pre] => index
}
singularCandidate match {
case Some(singular) => singular
case None =>
// No singular candidate in condition. Since the condition is nonempty,
// and there were no singular CommunicationTarget, it is guaranteed that the first is a CommTargetRange.
// We just pick that
alphas.head
}
}
// Drop all endpoint exprs that mention the grounder. This is incomplete, as we do only a syntactic check. But in the case
// of a singular CommunicateTarget it reduces the length of the filteredSubExprs list nicely.
val filteredSubExprs = subExprs.filter(_.endpoint != grounder)
val filteredAlphas = filteredSubExprs.map { _.endpoint }
implicit val o: Origin = TraceOrigin()
val groundCondition =
??? // Kind of need to apply ground here, but that requires a groundCondition. I just want to apply the partial projection operator here... Hmm
foldAny1(exprs.t)(filteredAlphas.map { alpha =>
ground(filteredSubExprs, grounder, alpha)
})
}

def ground(
groundCondition: Expr[Post],
exprs: Seq[EndpointExpr[Pre]],
mainTarget: CommunicateTarget[Pre],
)(implicit o: Origin): Expr[Post] =
groundCondition === foldAnd(exprs.map { expr =>
narrowCommunicateTarget(expr.endpoint, mainTarget).map { subTarget =>
expr.rewrite(endpoint = subTarget)
}.collect { case Some(expr) => expr }
})

def intersect(
left: CommunicateTarget[Pre],
right: CommunicateTarget[Pre],
): CommunicateTarget[Post] =
(left, right) match {
case (left, right) if left.isSingle && right.isSingle && left == right =>
dispatch(left)
// Either narrows a target in accordance with some context, or returns None if the two targets are
// not related - e.g. when narrowing an endpoint to the context of a endpoint range.
def narrowCommunicateTarget(
target: CommunicateTarget[Pre],
context: CommunicateTarget[Pre],
): Option[CommunicateTarget[Post]] =
(target, context) match {
case (target, context) if target.isSingle && target == context =>
Some(dispatch(target))
case (
CommTargetRange(Ref(f), RangeBinder(_, fLow, fHigh)),
CommTargetRange(Ref(g), RangeBinder(_, gLow, gHigh)),
) if f == g =>
CommTargetRange(Ref(a), RangeBinder(binder, fLow, fHigh)),
CommTargetRange(Ref(b), RangeBinder(_, gLow, gHigh)),
) if a == b =>
implicit val o: Origin = TraceOrigin()
CommTargetRange[Post](
succ(f),
Some(CommTargetRange[Post](
succ(a),
RangeBinder(
new Variable(TInt()),
variables.dispatch(binder),
max(fLow.rewriteDefault(), gLow.rewriteDefault()),
min(fHigh.rewriteDefault(), gHigh.rewriteDefault()),
),
)
case (left, right) if right.isSingle => intersect(right, left)
case (left, right) => assert(left.isSingle && right.isRange)
))
case (
CommTargetIndex(ref @ Ref(a), i),
CommTargetRange(Ref(b), RangeBinder(_, low, high)),
) if a == b =>
implicit val o: Origin = TraceOrigin()
// Implement support for this case by simulating the case of F[i] as a range F[i' := i .. i + 1]
// (in this cased, the i' variable does not need to be used: it is directly equal to i, so i can safely be used instead)
// (except that i is an expr and i' a var, but that's irrelevant here)
Some(CommTargetRange[Post](
endpoints.dispatch(ref),
RangeBinder(
new Variable(TInt()),
max(dispatch(i), dispatch(low)),
min(dispatch(i) + const(1), dispatch(high)),
),
))

}

def compareFun(
op: (Expr[Post], Expr[Post]) => Expr[Post]
)(implicit o: Origin): Function[Post] = {
val x = new Variable[Post](TInt())(o.sourceName("x"))
val y = new Variable[Post](TInt())(o.sourceName("y"))
function[Post](
args = Seq(x, y),
returnType = TInt(),
body = Some(Select[Post](op(x.get, y.get), x.get, y.get)),
blame = TrivialContract(),
contractBlame = TrueSatisfiable,
).declare()
}

lazy val minFun = {
implicit val o: Origin = TraceOrigin().sourceName("min")
compareFun((x, y) => x < y)
}
lazy val maxFun = {
implicit val o: Origin = TraceOrigin().sourceName("max")
compareFun((x, y) => x > y)
}

def min(a: Expr[Post], b: Expr[Post])(implicit o: Origin): Expr[Post] =
functionInvocation(
ref = minFun.ref,
args = Seq(a, b),
blame = TrivialContract(),
)
def max(a: Expr[Post], b: Expr[Post])(implicit o: Origin): Expr[Post] =
functionInvocation(
ref = maxFun.ref,
args = Seq(a, b),
blame = TrivialContract(),
)

}

0 comments on commit 19bbf72

Please sign in to comment.