Skip to content

Commit

Permalink
Add support for pattern alternatives (#1627)
Browse files Browse the repository at this point in the history
Addresses issues #1128 and #617 by adding support for pattern
alternatives in Stainless.

This was done in the context of our project in the Formal Verification
CS-550 course. This version adds a new Alternative pattern and handles
it throughout the pipeline and therefore allows disjunctions of patterns
to be used.

Note: This does not yet add support for pattern alternatives in the Coq
encoder or GenC, but does include support for pattern alternatives when
using the optional simplifiers OL and OCBSL.

---------

Co-authored-by: Mai-LinhC <mai-linh.cordonnier@epfl.ch>
  • Loading branch information
SidonieBouthors and Mai-LinhC authored Jan 10, 2025
1 parent 8060b33 commit c92fee2
Show file tree
Hide file tree
Showing 19 changed files with 174 additions and 2 deletions.
4 changes: 4 additions & 0 deletions core/src/main/scala/stainless/ast/Deconstructors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ trait TreeDeconstructor extends inox.ast.TreeDeconstructor {
(Seq(id), binder.map(_.toVariable).toSeq, recs, tps, subs, (ids, vs, es, tps, pats) => {
t.UnapplyPattern(vs.headOption.map(_.toVal), es, ids.head, tps, pats)
})
case s.AlternativePattern(binder, subs) =>
(Seq(), binder.map(_.toVariable).toSeq, Seq(), Seq(), subs, (_, vs, _, _, pats) => {
t.AlternativePattern(vs.headOption.map(_.toVal), pats)
})
}

/** Rebuild a match case from the given set of identifiers, variables, expressions and types */
Expand Down
8 changes: 8 additions & 0 deletions core/src/main/scala/stainless/ast/ExprOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,14 @@ class ExprOps(override val trees: Trees) extends inox.ast.ExprOps(trees) { self
(_ ++ _)
)

case AlternativePattern(vdOpt, subPatterns) =>
val freshVdOpt = vdOpt.map(vd => transform(vd.freshen, env))
// We don't need to freshen the subPatterns here, as they are not bound
(
AlternativePattern(freshVdOpt, subPatterns),
env ++ freshVdOpt.map(freshVd => vdOpt.get.id -> freshVd.id)
)

case LiteralPattern(vdOpt, lit) =>
val freshVdOpt = vdOpt.map(vd => transform(vd.freshen, env))
val newEnv = env ++ freshVdOpt.map(freshVd => vdOpt.get.id -> freshVd.id)
Expand Down
7 changes: 7 additions & 0 deletions core/src/main/scala/stainless/ast/Expressions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,13 @@ trait Expressions extends inox.ast.Expressions with Types { self: Trees =>
val subPatterns = Seq()
}

/**
* Pattern encoding like `case binder @ (subPattern1 | subPattern2 | ...) => ...`
*
* If [[binder]] is empty, consider a wildcard `_` in its place.
*/
sealed case class AlternativePattern(binder: Option[ValDef], subPatterns: Seq[Pattern]) extends Pattern

protected def unapplyScrut(scrut: Expr, up: UnapplyPattern)(using s: Symbols): Expr = {
FunctionInvocation(up.id, up.tps, up.recs :+ scrut)
}
Expand Down
4 changes: 4 additions & 0 deletions core/src/main/scala/stainless/ast/Printers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ trait Printer extends inox.ast.Printer {
printNameWithPath(id)
p"(${nary(subs)})"

case AlternativePattern(ovd, subs) =>
ovd foreach (vd => p"${vd.toVariable} : ")
p"(${nary(subs, " | ")})"

case Passes(in, out, cases) =>
optP {
p"""|($in, $out) passes {
Expand Down
6 changes: 6 additions & 0 deletions core/src/main/scala/stainless/ast/SymbolOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ trait SymbolOps extends inox.ast.SymbolOps with TypeOps { self =>
val subTests = subps.zipWithIndex.map { case (p, i) => apply(tupleSelect(in, i+1, subps.size), p) }
bind(ob, in) `merge` subTests

case AlternativePattern(ob, subps) =>
// one of the alternatives must hold (disjunction)
// we use A \/ B = ~ (~A /\ ~B)
val disjunction = subps.map(p => apply(in, p).negate).reduce(_ `merge` _).negate
bind(ob, in) `merge` disjunction

case up @ UnapplyPattern(ob, _, _, _, subps) =>
val subs = unwrapTuple(up.get(in), subps.size).zip(subps) map (apply).tupled
bind(ob, in) `withCond` Not(up.isEmpty(in)) `merge` subs
Expand Down
4 changes: 4 additions & 0 deletions core/src/main/scala/stainless/ast/TypeOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ trait TypeOps extends inox.ast.TypeOps {
case _ => false
}

case AlternativePattern(ob, subs) =>
ob.forall(vd => isSubtypeOf(vd.getType, in)) &&
(subs exists (patternIsTyped(in, _)))

case up @ UnapplyPattern(ob, recs, id, tps, subs) =>
ob.forall(vd => isSubtypeOf(vd.getType, in)) &&
lookupFunction(id).exists(_.tparams.size == tps.size) && {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,12 @@ abstract class RecursiveEvaluator(override val program: Program,
None
}

case (AlternativePattern(ob, subs), scrut) =>
subs.map(matchesPattern(_, scrut)).find(_.isDefined) match {
case Some(_) => Some(obind(ob, expr)) // There should be no mapping nested in the alternative
case _ => None
}

case (up @ UnapplyPattern(ob, rec, id, tps, subs), scrut) =>
val eRec = rec map e
val unapp = e(FunctionInvocation(id, tps, eRec :+ scrut))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ trait TransformerWithType extends TreeTransformer {
val rsubs = (subs zip tps).map(p => transform(p._1, p._2))
t.TuplePattern(ob map transform, rsubs).copiedFrom(pat)

case s.AlternativePattern(ob, subs) =>
val rsubs = subs map (transform(_, tpe))
t.AlternativePattern(ob map transform, rsubs).copiedFrom(pat)

case up @ s.UnapplyPattern(ob, recs, id, tps, subs) =>
val rsubs = (subs zip up.subTypes(tpe)).map(p => transform(p._1, p._2))
val rrecs = (recs zip getFunction(id, tps).params.init).map(p => transform(p._1, p._2.getType))
Expand Down
10 changes: 10 additions & 0 deletions core/src/main/scala/stainless/extraction/oo/TypeEncoding.scala
Original file line number Diff line number Diff line change
Expand Up @@ -920,6 +920,12 @@ class TypeEncoding(override val s: Trees, override val t: Trees)
case s.TuplePattern(Some(vd), _) =>
instanceOfPattern(super.transform(pat, vd.tpe), tpe, vd.tpe)

case s.AlternativePattern(None, subs) =>
t.AlternativePattern(None, subs.map(transform)).copiedFrom(pat)

case s.AlternativePattern(Some(vd), subs) =>
instanceOfPattern(t.AlternativePattern(Some(transform(vd)), subs.map(transform)).copiedFrom(pat), tpe, vd.tpe)

case up @ s.UnapplyPattern(ob, recs, id, tps, subs) =>
val funScope = this `in` id
val FunInfo(fun, tparams) = functions(id)
Expand Down Expand Up @@ -1028,6 +1034,10 @@ class TypeEncoding(override val s: Trees, override val t: Trees)
super.transform(pat, in)
}

case s.AlternativePattern(ob, subs) =>
simple --= s.typeOps.typeParamsOf(in.getType)
super.transform(pat, in)

case up @ s.UnapplyPattern(ob, recs, id, tps, subs) =>
val tparams = infos(id).tparams
simple --= tps.zipWithIndex.flatMap { case (tp, i) =>
Expand Down
1 change: 1 addition & 0 deletions core/src/main/scala/stainless/extraction/oo/TypeOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ trait TypeOps extends innerfuns.TypeOps { self =>
.map(cons => ADTType(cons.sort, tps))
.getOrElse(Untyped)
case TuplePattern(_, subs) => TupleType(subs map patternInType)
case AlternativePattern(_, subs) => leastUpperBound(subs map patternInType)
case ClassPattern(_, ct, subs) => ct
case UnapplyPattern(_, recs, id, tps, _) =>
lookupFunction(id)
Expand Down
3 changes: 3 additions & 0 deletions core/src/main/scala/stainless/genc/phases/Scala2IRPhase.scala
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,9 @@ private class S2IRImpl(override val s: tt.type,
update(b, scrutinee)
buildBinOp(scrutinee, O.Equals, lit)(pat.getPos)

case AlternativePattern(_, _) =>
reporter.fatalError(pat.getPos, s"Alternative Pattern, a.k.a pattern disjunction, is not yet supported by GenC")

case UnapplyPattern(_, _, _, _, _) =>
reporter.fatalError(pat.getPos, s"Unapply Pattern, a.k.a. Extractor Objects, is not supported by GenC")
}
Expand Down
19 changes: 19 additions & 0 deletions core/src/main/scala/stainless/transformers/lattices/Core.scala
Original file line number Diff line number Diff line change
Expand Up @@ -938,6 +938,13 @@ trait Core extends Definitions { ocbsl =>
val subScruts = tupleSubscrutinees(scrut, tt)
val (rsubs, subst2) = recHelper(subScruts, subps)
(LabelledPattern.TuplePattern(rsubs), subst2)
case AlternativePattern(_, subps) =>
val (rsubs, subst2) = subps.foldLeft((Seq.empty[LabelledPattern], subst1)) {
case ((acc, subst), subp) =>
val (rsub, subst2) = convertPattern(scrut, subp, subst)
(acc :+ rsub, subst2)
}
(LabelledPattern.Alternative(rsubs), subst2)
case UnapplyPattern(_, recs, id, tps, subps) =>
if (recs.nonEmpty) throw UnsupportedOperationException("recs is not empty")
val unapp = unapplySubScrutinees(scrut, id, tps)
Expand Down Expand Up @@ -2644,6 +2651,7 @@ trait Core extends Definitions { ocbsl =>
assert(bases.size == subps.size)
val rsubs = recHelper(tupleSubscrutinees(scrut, tt), subps)
TuplePattern(bdg, rsubs)
case LabelledPattern.Alternative(subs) => AlternativePattern(bdg, subs.map(sub => convertPattern(scrut, sub, vds)))
case LabelledPattern.Lit(lit) => LiteralPattern(bdg, lit)
case LabelledPattern.Unapply(recs, id, tps, subps) =>
assert(recs.isEmpty)
Expand Down Expand Up @@ -3286,6 +3294,17 @@ trait Core extends Definitions { ocbsl =>
assert(ctxs.isPrefixOf(newCtxs))
PatBdgsAndConds(newCtxs, subscruts ++ recBdgs, recPatConds)

case LabelledPattern.Alternative(sub) =>
val PatBdgsAndConds(newCtxs, _, subPatConds) =
sub.foldLeft(PatBdgsAndConds(ctxs, Seq.empty, Seq.empty)) {
case (PatBdgsAndConds(ctxs, _, condsAcc), subpat) =>
val PatBdgsAndConds(ctxs2, _, conds2) = addPatternBindingsAndConds(ctxs, scrut, subpat)
PatBdgsAndConds(ctxs2, Seq.empty, condsAcc ++ conds2)
}
val cond = codeOfSig(mkOr(subPatConds), BoolTy)
assert(ctxs.isPrefixOf(newCtxs))
PatBdgsAndConds(newCtxs.withCond(cond), Seq.empty, Seq(cond))

case LabelledPattern.Unapply(recs, id, tps, subps) =>
assert(recs.isEmpty)
val unapp = unapplySubScrutinees(scrut, id, tps)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ trait Definitions {
case Wildcard extends LabelledPattern
case ADT(id: Identifier, tps: Seq[Type], sub: Seq[LabelledPattern]) extends LabelledPattern
case TuplePattern(sub: Seq[LabelledPattern]) extends LabelledPattern
case Alternative(sub: Seq[LabelledPattern]) extends LabelledPattern
case Lit[T](lit: Literal[T]) extends LabelledPattern
case Unapply(recs: Seq[Code], id: Identifier, tps: Seq[Type], sub: Seq[LabelledPattern]) extends LabelledPattern

Expand All @@ -231,6 +232,7 @@ trait Definitions {
case Wildcard => Seq.empty
case ADT(_, _, sub) => sub.flatMap(_.allPatterns)
case TuplePattern(sub) => sub.flatMap(_.allPatterns)
case Alternative(sub) => sub.flatMap(_.allPatterns)
case Lit(_) => Seq.empty
case Unapply(_, _, _, sub) => sub.flatMap(_.allPatterns)
})
Expand Down
5 changes: 3 additions & 2 deletions core/src/main/scala/stainless/utils/Serialization.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ class StainlessSerializer(override val trees: ast.Trees, serializeProducts: Bool
/** An extension to the set of registered classes in the `InoxSerializer`.
* occur within Stainless programs.
*
* The new identifiers in the mapping range from 120 to 172.
* The new identifiers in the mapping range from 120 to 173.
*
* NEXT ID: 173
* NEXT ID: 174
*/
override protected def classSerializers: Map[Class[?], Serializer[?]] =
super.classSerializers ++ Map(
Expand All @@ -40,6 +40,7 @@ class StainlessSerializer(override val trees: ast.Trees, serializeProducts: Bool
stainlessClassSerializer[TuplePattern] (130),
stainlessClassSerializer[LiteralPattern[Any]](131),
stainlessClassSerializer[UnapplyPattern] (132),
stainlessClassSerializer[AlternativePattern] (173),
stainlessClassSerializer[FiniteArray] (133),
stainlessClassSerializer[LargeArray] (134),
stainlessClassSerializer[ArraySelect] (135),
Expand Down
2 changes: 2 additions & 0 deletions core/src/main/scala/stainless/verification/CoqEncoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,8 @@ trait CoqEncoder {
ctx.reporter.warning(s"Ignoring type $tpe in the wildcard pattern $p.")
//TODO not tested
CoqTuplePatternVd(ps.map(transformPattern), VariablePattern(Some(makeFresh(id))))
case AlternativePattern(_, _) =>
ctx.reporter.fatalError(s"The translation to Coq does not support disjunctive patterns such as `$p` (${p.getClass}) yet.")
case _ => ctx.reporter.fatalError(s"Coq does not support patterns such as `$p` (${p.getClass}) yet.")
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
object PatternAlternative1 {
sealed trait SignSet
case object None extends SignSet
case object Any extends SignSet
case object Neg extends SignSet
case object Zer extends SignSet
case object Pos extends SignSet
case object NegZer extends SignSet
case object NotZer extends SignSet
case object PosZer extends SignSet

def subsetOf(a: SignSet, b: SignSet): Boolean = (a, b) match {
case (None, _) => true
case (_, Any) => true
case (Neg, NegZer | NotZer) => true
case (Zer, NegZer | PosZer) => true
case (Pos, NotZer | PosZer) => true
case _ => false
}
}
33 changes: 33 additions & 0 deletions frontends/benchmarks/extraction/valid/PatternAlternative2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
object PatternAlternative2 {
sealed trait Tree
case class Node(left: Tree, right: Tree) extends Tree
case class IntLeaf(value: Int) extends Tree
case class StringLeaf(value: String) extends Tree
case class NoneLeaf() extends Tree

def containsNoneLeaf(tree: Tree): Boolean = {
tree match {
case Node(left, right) => containsNoneLeaf(left) || containsNoneLeaf(right)
case NoneLeaf() => true
case _ => false
}
}

def containsOnlyBinaryLeaves(tree: Tree): Boolean = {
tree match {
case Node(left, right) => containsOnlyBinaryLeaves(left) && containsOnlyBinaryLeaves(right)
case IntLeaf(v) => v == 0 || v == 1
case StringLeaf(v) => v == "0" || v == "1"
case _ => true
}
}

def hasBinaryLeaves(tree: Tree): Boolean = {
require(!containsNoneLeaf(tree) && containsOnlyBinaryLeaves(tree))
tree match {
case a @ Node(left: (IntLeaf | StringLeaf), right: (IntLeaf | StringLeaf)) => hasBinaryLeaves(left) && hasBinaryLeaves(right)
case b @ (IntLeaf(0 | 1) | StringLeaf("0" | "1")) => true
case _ => false
}
} ensuring { res => res }
}
33 changes: 33 additions & 0 deletions frontends/benchmarks/verification/valid/PatternAlternative.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
object PatternAlternative {
sealed trait Tree
case class Node(left: Tree, right: Tree) extends Tree
case class IntLeaf(value: Int) extends Tree
case class StringLeaf(value: String) extends Tree
case class NoneLeaf() extends Tree

def containsNoneLeaf(tree: Tree): Boolean = {
tree match {
case Node(left, right) => containsNoneLeaf(left) || containsNoneLeaf(right)
case NoneLeaf() => true
case _ => false
}
}

def containsOnlyBinaryLeaves(tree: Tree): Boolean = {
tree match {
case Node(left, right) => containsOnlyBinaryLeaves(left) && containsOnlyBinaryLeaves(right)
case IntLeaf(v) => v == 0 || v == 1
case StringLeaf(v) => v == "0" || v == "1"
case _ => true
}
}

def hasBinaryLeaves(tree: Tree): Boolean = {
require(!containsNoneLeaf(tree) && containsOnlyBinaryLeaves(tree))
tree match {
case a @ Node(left: (IntLeaf | StringLeaf), right: (IntLeaf | StringLeaf)) => hasBinaryLeaves(left) && hasBinaryLeaves(right)
case b @ (IntLeaf(0 | 1) | StringLeaf("0" | "1")) => true
case _ => false
}
} ensuring { res => res }
}
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,11 @@ class CodeExtraction(inoxCtx: inox.Context,
// Note that this pattern will be correctly rejected as "Unsupported pattern" (in fact, it cannot be even tested at runtime):
// val (aa: B, bb: B) = (a, b)
private def extractPattern(p: tpd.Tree, expectedTpe: Option[xt.Type], binder: Option[xt.ValDef] = None)(using dctx: DefContext): (xt.Pattern, DefContext) = p match {

case a @ Alternative(subpatterns) =>
val (patterns, nctx) = subpatterns.map(extractPattern(_, expectedTpe)).unzip
(xt.AlternativePattern(binder, patterns), dctx)

case b @ Bind(name, t @ Typed(pat, tpt)) =>
val vd = xt.ValDef(FreshIdentifier(name.toString), extractType(tpt), annotationsOf(b.symbol, ignoreOwner = true)).setPos(b.sourcePos)
val pctx = dctx.withNewVar(b.symbol -> (() => vd.toVariable))
Expand Down

0 comments on commit c92fee2

Please sign in to comment.