Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Choice-else branch support #232

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,43 @@ public SimplifyResult visit(final NonDetStmt stmt, final MutableValuation valuat
successfulSubStmts.add(result.stmt);
}
}
if (successfulSubStmts.size() == 0) {
MutableValuation elzeValuation = null;
Stmt elzeSubStmt = null;
if (stmt.getElze() != null) {
final MutableValuation subVal = MutableValuation.copyOf(valuation);
final SimplifyResult result = stmt.getElze().accept(this, subVal);
if (result.status == SimplifyStatus.SUCCESS) {
elzeValuation = subVal;
elzeSubStmt = result.stmt;
}
}

if (elzeSubStmt != null) {
if (successfulSubStmts.isEmpty()) {
elzeSubStmt.accept(this, valuation);
return SimplifyResult.of(elzeSubStmt, SimplifyStatus.SUCCESS);
} else {
successfulSubStmts.get(0).accept(this, valuation);
List<Decl<?>> toRemove = new ArrayList<>();
for (Decl<?> decl : valuation.getDecls()) {
for (MutableValuation subVal : valuations) {
if (!valuation.eval(decl).equals(subVal.eval(decl))) {
toRemove.add(decl);
break;
}
}
if (!valuation.eval(decl).equals(elzeValuation.eval(decl))) {
toRemove.add(decl);
}
}
for (Decl<?> decl : toRemove) {
valuation.remove(decl);
}
return SimplifyResult.of(NonDetStmt.of(successfulSubStmts, elzeSubStmt), SimplifyStatus.SUCCESS);
}
}

if (successfulSubStmts.isEmpty()) {
return SimplifyResult.of(AssumeStmt.of(False()), SimplifyStatus.BOTTOM);
} else if (successfulSubStmts.size() == 1) {
successfulSubStmts.get(0).accept(this, valuation);
Expand All @@ -168,7 +204,6 @@ public SimplifyResult visit(final NonDetStmt stmt, final MutableValuation valuat
}
return SimplifyResult.of(NonDetStmt.of(successfulSubStmts), SimplifyStatus.SUCCESS);
}

}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,38 @@ public final class NonDetStmt implements Stmt {

private final List<Stmt> stmts;

private final Stmt elze;

private static final int HASH_SEED = 361;
private static final String STMT_LABEL = "nondet";

private volatile int hashCode = 0;

private NonDetStmt(final List<Stmt> stmts) {
private NonDetStmt(final List<Stmt> stmts, final Stmt elze) {
if (stmts.isEmpty()) {
this.stmts = ImmutableList.of(SkipStmt.getInstance());
} else {
this.stmts = stmts;
}
this.elze = elze; // Null if does not exist
}

public static NonDetStmt of(final List<Stmt> stmts) {
return new NonDetStmt(stmts);
return of(stmts, null);
}

public static NonDetStmt of(final List<Stmt> stmts, Stmt elze) {
return new NonDetStmt(stmts, elze);
}

public List<Stmt> getStmts() {
return stmts;
}

public Stmt getElze() {
return elze;
}

@Override
public <P, R> R accept(final StmtVisitor<? super P, ? extends R> visitor, final P param) {
return visitor.visit(this, param);
Expand All @@ -57,6 +68,9 @@ public int hashCode() {
if (result == 0) {
result = HASH_SEED;
result = 31 * result + stmts.hashCode();
if (getElze() != null) {
result = 31 * result + elze.hashCode();
}
hashCode = result;
}
return result;
Expand All @@ -68,15 +82,27 @@ public boolean equals(final Object obj) {
return true;
} else if (obj instanceof NonDetStmt) {
final NonDetStmt that = (NonDetStmt) obj;
return this.getStmts().equals(that.getStmts());
boolean equalsValue = this.getStmts().equals(that.getStmts());
if (this.getElze() == null) {
equalsValue &= that.getElze() == null;
} else {
equalsValue &= this.getElze().equals(that.getElze());
}
return equalsValue;
} else {
return false;
}
}

@Override
public String toString() {
return Utils.lispStringBuilder(STMT_LABEL).addAll(stmts).toString();
final var str = Utils.lispStringBuilder(STMT_LABEL).addAll(stmts);

if (getElze() != null) {
str.add("elze").add(getElze());
}

return str.toString();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package hu.bme.mit.theta.core.utils;

import com.google.common.collect.ImmutableList;
import hu.bme.mit.theta.core.decl.Decls;
import hu.bme.mit.theta.core.decl.VarDecl;
import hu.bme.mit.theta.core.stmt.AssignStmt;
import hu.bme.mit.theta.core.stmt.AssumeStmt;
Expand All @@ -30,6 +31,8 @@
import hu.bme.mit.theta.core.stmt.StmtVisitor;
import hu.bme.mit.theta.core.type.Expr;
import hu.bme.mit.theta.core.type.Type;
import hu.bme.mit.theta.core.type.anytype.Exprs;
import hu.bme.mit.theta.core.type.anytype.IteExpr;
import hu.bme.mit.theta.core.type.booltype.BoolType;
import hu.bme.mit.theta.core.type.booltype.SmartBoolExprs;
import hu.bme.mit.theta.core.type.fptype.FpType;
Expand All @@ -42,14 +45,18 @@
import java.util.Set;

import static hu.bme.mit.theta.core.type.abstracttype.AbstractExprs.Eq;
import static hu.bme.mit.theta.core.type.abstracttype.AbstractExprs.Geq;
import static hu.bme.mit.theta.core.type.abstracttype.AbstractExprs.Ite;
import static hu.bme.mit.theta.core.type.anytype.Exprs.Prime;
import static hu.bme.mit.theta.core.type.anytype.Exprs.Ref;
import static hu.bme.mit.theta.core.type.booltype.BoolExprs.And;
import static hu.bme.mit.theta.core.type.booltype.BoolExprs.Bool;
import static hu.bme.mit.theta.core.type.booltype.BoolExprs.Iff;
import static hu.bme.mit.theta.core.type.booltype.BoolExprs.Or;
import static hu.bme.mit.theta.core.type.booltype.BoolExprs.True;
import static hu.bme.mit.theta.core.type.fptype.FpExprs.FpAssign;
import static hu.bme.mit.theta.core.type.inttype.IntExprs.Int;
import static hu.bme.mit.theta.core.type.inttype.IntExprs.Leq;
import static hu.bme.mit.theta.core.utils.TypeUtils.cast;

final class StmtToExprTransformer {
Expand Down Expand Up @@ -130,42 +137,88 @@ public StmtUnfoldResult visit(SequenceStmt sequenceStmt, VarIndexing indexing) {

@Override
public StmtUnfoldResult visit(NonDetStmt nonDetStmt, VarIndexing indexing) {

final List<Expr<BoolType>> choices = new ArrayList<>();

final List<VarIndexing> indexings = new ArrayList<>();
VarIndexing jointIndexing = indexing;
int count = 0;
VarDecl<IntType> tempVar = VarPoolUtil.requestInt();
for (Stmt stmt : nonDetStmt.getStmts()) {
var tempVar = VarPoolUtil.requestInt();
for (var stmt : nonDetStmt.getStmts()) {
final Expr<BoolType> tempExpr = Eq(
ExprUtils.applyPrimes(tempVar.getRef(), indexing), Int(count++));
ExprUtils.applyPrimes(tempVar.getRef(), indexing),
Int(count++)
);
final StmtUnfoldResult result = toExpr(stmt, indexing.inc(tempVar));
choices.add(And(tempExpr, And(result.exprs)));
indexings.add(result.indexing);
jointIndexing = jointIndexing.join(result.indexing);
}
final Set<VarDecl<?>> vars = ExprUtils.getVars(choices);
final List<Expr<BoolType>> branchExprs = new ArrayList<>();
for (int i = 0; i < choices.size(); i++) {
final List<Expr<BoolType>> exprs = new ArrayList<>();
exprs.add(choices.get(i));
for (VarDecl<?> decl : vars) {
int currentBranchIndex = indexings.get(i).get(decl);
int jointIndex = jointIndexing.get(decl);

final var branchExprs = fixVariablePrimes(choices, indexings, jointIndexing);

final var choiceExpr = Or(branchExprs);
final Expr<BoolType> expr;

if (nonDetStmt.getElze() != null) {
final var tempExpr = Eq(
ExprUtils.applyPrimes(tempVar.getRef(), indexing), Int(count)
);
final var result = toExpr(nonDetStmt.getElze(), indexing.inc(tempVar));
final var elzeIndexing = result.indexing;
final var elzeExpr = And(tempExpr, And(result.exprs));
final var elzeJointIndexing = jointIndexing.join(elzeIndexing);

final var expressions = fixVariablePrimes(
ImmutableList.of(choiceExpr, elzeExpr),
ImmutableList.of(jointIndexing, elzeIndexing),
elzeJointIndexing
);

final var choiceExprExtended = expressions.get(0);
final var elzeExprExtended = expressions.get(1);

final var choiceVar = VarPoolUtil.requestBool();
final var iffExpr = Iff(choiceExprExtended, Ref(choiceVar));
final var iteExpr = IteExpr.of(Ref(choiceVar), Ref(choiceVar), elzeExprExtended);
final var tempValidExpr = And(
Leq(Ref(tempVar), Int(count)),
Geq(Ref(tempVar), Int(0))
);

expr = And(iffExpr, iteExpr, tempValidExpr);

VarPoolUtil.returnBool(choiceVar);
} else {
expr = choiceExpr;
}

VarPoolUtil.returnInt(tempVar);
return StmtUnfoldResult.of(ImmutableList.of(expr), jointIndexing);
}

private static List<Expr<BoolType>> fixVariablePrimes(List<Expr<BoolType>> branches, List<VarIndexing> indexings, VarIndexing jointIndexing) {
final var vars = ExprUtils.getVars(branches);
final var branchExprs = new ArrayList<Expr<BoolType>>();
for (int i = 0; i < branches.size(); i++) {
final var exprs = new ArrayList<Expr<BoolType>>();
exprs.add(branches.get(i));
for (var declaration : vars) {
int currentBranchIndex = indexings.get(i).get(declaration);
int jointIndex = jointIndexing.get(declaration);
if (currentBranchIndex < jointIndex) {
if (currentBranchIndex > 0) {
exprs.add(Eq(Prime(decl.getRef(), currentBranchIndex),
Prime(decl.getRef(), jointIndex)));
exprs.add(Eq(
Prime(declaration.getRef(), currentBranchIndex),
Prime(declaration.getRef(), jointIndex)
));
} else {
exprs.add(Eq(decl.getRef(), Prime(decl.getRef(), jointIndex)));
exprs.add(Eq(declaration.getRef(), Prime(declaration.getRef(), jointIndex)));
}
}
}
branchExprs.add(And(exprs));
}
final Expr<BoolType> expr = Or(branchExprs);
VarPoolUtil.returnInt(tempVar);
return StmtUnfoldResult.of(ImmutableList.of(expr), jointIndexing);
return branchExprs;
}

@Override
Expand All @@ -183,40 +236,19 @@ public StmtUnfoldResult visit(IfStmt ifStmt, VarIndexing indexing) {

final Expr<BoolType> thenExpr = And(thenResult.getExprs());
final Expr<BoolType> elzeExpr = And(elzeResult.getExprs());
final Set<VarDecl<?>> vars = ExprUtils.getVars(ImmutableList.of(thenExpr, elzeExpr));

VarIndexing jointIndexing = thenIndexing.join(elzeIndexing);
final List<Expr<BoolType>> thenAdditions = new ArrayList<>();
final List<Expr<BoolType>> elzeAdditions = new ArrayList<>();
for (VarDecl<?> decl : vars) {
final int thenIndex = thenIndexing.get(decl);
final int elzeIndex = elzeIndexing.get(decl);
if (thenIndex < elzeIndex) {
if (thenIndex > 0) {
thenAdditions.add(
Eq(Prime(decl.getRef(), thenIndex), Prime(decl.getRef(), elzeIndex)));
} else {
thenAdditions.add(Eq(decl.getRef(), Prime(decl.getRef(), elzeIndex)));
}
} else if (elzeIndex < thenIndex) {
if (elzeIndex > 0) {
elzeAdditions.add(
Eq(Prime(decl.getRef(), elzeIndex), Prime(decl.getRef(), thenIndex)));
} else {
elzeAdditions.add(Eq(decl.getRef(), Prime(decl.getRef(), thenIndex)));
}
}
}

final Expr<BoolType> thenExprExtended =
thenAdditions.size() > 0 ? SmartBoolExprs.And(thenExpr, And(thenAdditions))
: thenExpr;
final Expr<BoolType> elzeExprExtended =
elzeAdditions.size() > 0 ? SmartBoolExprs.And(elzeExpr, And(elzeAdditions))
: elzeExpr;
final var expressions = fixVariablePrimes(
ImmutableList.of(thenExpr, elzeExpr),
ImmutableList.of(thenIndexing, elzeIndexing),
jointIndexing
);

final Expr<BoolType> thenExprExtended = expressions.get(0);
final Expr<BoolType> elzeExprExtended = expressions.get(1);

final Expr<BoolType> ite = cast(Ite(condExpr, thenExprExtended, elzeExprExtended),
Bool());
final Expr<BoolType> ite = cast(Ite(condExpr, thenExprExtended, elzeExprExtended), Bool());
return StmtUnfoldResult.of(ImmutableList.of(ite), jointIndexing);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,21 @@

import hu.bme.mit.theta.core.decl.Decls;
import hu.bme.mit.theta.core.decl.VarDecl;
import hu.bme.mit.theta.core.type.booltype.BoolType;
import hu.bme.mit.theta.core.type.inttype.IntType;

import java.util.ArrayDeque;

import static hu.bme.mit.theta.core.type.booltype.BoolExprs.Bool;
import static hu.bme.mit.theta.core.type.inttype.IntExprs.Int;

public class VarPoolUtil {

private VarPoolUtil() {
}

private static ArrayDeque<VarDecl<IntType>> intPool = new ArrayDeque<VarDecl<IntType>>();
private static ArrayDeque<VarDecl<IntType>> intPool = new ArrayDeque<>();
private static ArrayDeque<VarDecl<BoolType>> boolPool = new ArrayDeque<>();
private static int counter = 0;

public static VarDecl<IntType> requestInt() {
Expand All @@ -39,10 +42,24 @@ public static VarDecl<IntType> requestInt() {
}
}

public static VarDecl<BoolType> requestBool() {
if (boolPool.isEmpty()) {
return Decls.Var("temp" + counter++, Bool());
} else {
return boolPool.remove();
}
}

public static void returnInt(VarDecl<IntType> var) {
if (!intPool.contains(var)) {
intPool.addFirst(var);
}
}

public static void returnBool(VarDecl<BoolType> var) {
if (!boolPool.contains(var)) {
boolPool.addFirst(var);
}
}

}
Loading