Skip to content

Commit

Permalink
Merge pull request #976 from utwente-fmt/gpgpu-optimizations
Browse files Browse the repository at this point in the history
Gpgpu optimizations
  • Loading branch information
pieter-bos authored Mar 2, 2023
2 parents 375e728 + 7bf142a commit 7158a75
Show file tree
Hide file tree
Showing 56 changed files with 4,648 additions and 161 deletions.
4 changes: 2 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ lazy val vercors: Project = (project in file("."))
scriptClasspath := Seq("*", "../res"),

// Force the main classes, as we have some extra main classes that we don't want to generate run scripts for.
Compile / discoveredMainClasses := Seq(),
Compile / mainClass := Some("vct.main.Main"),
Compile / discoveredMainClasses := Seq("vct.main.Vercors","vct.main.Alpinist"),
// Compile / mainClass := Some("vct.main.Main"),

// Add options to run scripts produced by sbt-native-packager. See: https://www.scala-sbt.org/sbt-native-packager/archetypes/java_app/customize.html#via-build-sbt
Universal / javaOptions ++= Seq (
Expand Down
248 changes: 209 additions & 39 deletions col/src/main/java/vct/col/ast/print/PVLPrinter.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
import vct.col.ast.stmt.terminal.ReturnStatement;
import vct.col.ast.syntax.PVLSyntax;
import vct.col.ast.type.*;
import vct.col.ast.util.ASTUtils;
import vct.col.ast.util.ClassName;
import vct.col.ast.util.SequenceUtils;

import java.util.*;

Expand All @@ -31,6 +33,12 @@ public PVLPrinter(TrackingOutput out) {
super(PVLSyntax.get(),out);
}

/**
* Flag set before visiting loop invariants.
*/
private boolean loopcontract = false;


public void visit(TypeVariable v){
out.print(v.name());
}
Expand Down Expand Up @@ -570,6 +578,75 @@ public void visit(ASTClass cl){
out.println("");
}

@Override
public void visit(Contract contract) {
if (contract!=null){
//out.incrIndent();
for (DeclarationStatement d:contract.given){
out.printf("given ");
d.accept(this);
out.lnprintf("");
}
for(ASTNode e:ASTUtils.conjuncts(contract.invariant,StandardOperator.Star, StandardOperator.And, StandardOperator.Wrap)){
out.printf((loopcontract) ? "loop_invariant ": "context_everywhere ");
nextExpr();
e.accept(this);
out.lnprintf(";");
}
for(ASTNode e: ASTUtils.conjuncts(contract.pre_condition,StandardOperator.Star)){
out.printf("requires ");
nextExpr();
e.accept(this);
out.lnprintf(";");
}
for (DeclarationStatement d:contract.yields){
out.printf("yields ");
d.accept(this);
out.lnprintf("");
}
for(ASTNode e:ASTUtils.conjuncts(contract.post_condition,StandardOperator.Star)){
out.printf("ensures ");
nextExpr();
e.accept(this);
out.lnprintf(";");
}
for (SignalsClause sc : contract.signals){
sc.accept(this);
}
if (contract.modifies!=null){
out.printf("modifies ");
if (contract.modifies.length==0){
out.lnprintf("\\nothing;");
} else {
nextExpr();
contract.modifies[0].accept(this);
for(int i=1;i<contract.modifies.length;i++){
out.printf(", ");
nextExpr();
contract.modifies[i].accept(this);
}
out.lnprintf(";");
}
}
if (contract.accesses!=null){
out.printf("accessible ");
if (contract.accesses.length==0){
out.lnprintf("\\nothing;");
} else {
nextExpr();
contract.accesses[0].accept(this);
for(int i=1;i<contract.accesses.length;i++){
out.printf(", ");
nextExpr();
contract.accesses[i].accept(this);
}
out.lnprintf(";");
}
}
//out.decrIndent();
}
}

public void visit(SignalsClause sc) {
out.printf("signals (");
sc.type().accept(this);
Expand Down Expand Up @@ -601,6 +678,10 @@ public void visit(Method m){
if (predicate && contract!=null){
Debug("ignoring contract of predicate");
}
if (!m.getGpuOpts().isEmpty()) {
m.getGpuOpts().forEach(opt -> visit(opt));
}

if (contract!=null && !predicate){
visit(contract);
}
Expand Down Expand Up @@ -684,7 +765,7 @@ public void visit(Method m){
body.accept(this);
out.lnprintf("");
} else {
out.printf("=");
out.printf(" = ");
nextExpr();
body.accept(this);
out.lnprintf(";");
Expand Down Expand Up @@ -726,7 +807,11 @@ private boolean self_terminating(ASTNode s) {
|| (s instanceof LoopStatement)
|| (s instanceof ASTSpecial)
|| (s instanceof DeclarationStatement)
|| (s instanceof ParallelRegion);
|| (s instanceof ParallelRegion)
|| (s instanceof ParallelBarrier)
|| (s instanceof ParallelAtomic)
|| (s instanceof ParallelInvariant)
;
}

public void visit(AssignmentStatement s){
Expand Down Expand Up @@ -757,7 +842,34 @@ public void visit(Lemma lemma){
}

public void visit(OperatorExpression e){
visitVerCors(e);
if (e.isa(StandardOperator.NewArray)) {
String[] op_syntax =syntax.getSyntax(e.operator());

out.print(op_syntax[0]);

SequenceUtils.SequenceInfo info = SequenceUtils.getTypeInfo((Type) e.arg(0));
while (info != null && !info.isCell()) {
info = SequenceUtils.getTypeInfo(info.getSequenceTypeArgument());
}
if (info == null || info.getElementType() == null) {
super.visit(e);
return;
}
info.getElementType().accept(this);


for(int i=1;i<e.args().size();i++){
out.print(op_syntax[1]);
boolean tmp = in_expr;
in_expr = true;
e.arg(i).accept(this);
in_expr = tmp;
out.print(op_syntax[2]);
}
} else {
visitVerCors(e);
}

}

private void visitVerCors(OperatorExpression e){
Expand Down Expand Up @@ -892,8 +1004,53 @@ private void visitForStatementList(BlockStatement s) {
}
}

public void visit(GPUOpt o) {
if (o == null) return;
out.printf("gpuopt ");
if (o instanceof LoopUnrolling) {
out.printf("loop_unroll ");
} else if (o instanceof MatrixLinearization) {
out.printf("matrix_lin ");
} else if (o instanceof Tiling) {
out.printf("tile ");
} else if (o instanceof IterationMerging) {
out.printf("iter_merge ");
} else if (o instanceof DataLocation) {
out.printf("glob_to_reg ");
}
else {
Warning("Could not find name of " + o.getClass());
}
if (o instanceof MatrixLinearization) {
nextExpr();
((MatrixLinearization) o).matrixName().accept(this);
out.print(" ");
nextExpr();
out.print(((MatrixLinearization) o).rowOrColumn().equals(Major.Row()) ?"R":"C");
out.print(" ");
nextExpr();
((MatrixLinearization) o).dimX().accept(this);
out.print(" ");
nextExpr();
((MatrixLinearization) o).dimY().accept(this);
} else if (o instanceof Tiling) {
nextExpr();
out.print(((Tiling) o).interOrIntra().equals(TilingConfig.Inter()) ?"inter":"intra");
out.print(" ");
nextExpr();
((Tiling) o).tileSize().accept(this);
} else {
Iterator<ASTNode> argsit = o.argsJava().iterator();
print_tuple(" ", "", "", o.argsJava().toArray(new ASTNode[0]));
}
out.lnprintf(";");
}

public void visit(LoopStatement s){
visit(s.getGpuopt());
loopcontract = true;
visit(s.getContract());
loopcontract = false;
ASTNode tmp;
if (s.getInitBlock()!=null || s.getUpdateBlock()!=null){
out.printf("for(");
Expand Down Expand Up @@ -958,16 +1115,20 @@ public void visit(LoopStatement s){
}
}

private void print_tuple(ASTNode ... args){
out.print("(");
private void print_tuple(ASTNode ... args) {
print_tuple(",", "(", ")", args);
}

private void print_tuple(String delimiter, String prefix, String suffix, ASTNode ... args) {
out.print(prefix);
String sep="";
for(ASTNode n:args){
out.print(sep);
nextExpr();
n.accept(this);
sep=",";
sep=delimiter;
}
out.print(")");
out.print(suffix);
}

public void visit(MethodInvokation s){
Expand Down Expand Up @@ -996,7 +1157,12 @@ public void visit(PrimitiveType t){
break;
}
case Array:
t.firstarg().accept(this);
SequenceUtils.SequenceInfo info = SequenceUtils.getTypeInfo(t);
if (info != null && info.isCell()) {
info.getElementType().accept(this);
} else {
t.firstarg().accept(this);
}
switch(nrofargs){
case 1:
out.printf("[]");
Expand Down Expand Up @@ -1027,7 +1193,19 @@ public void visit(PrimitiveType t){
if (nrofargs!=1){
Fail("Option type constructor with %d arguments instead of 1",nrofargs);
}
t.firstarg().accept(this);

SequenceUtils.SequenceInfo info1 = SequenceUtils.getTypeInfo(t);
if (info1 != null && info1.getSequenceSort() == PrimitiveSort.Array){

info1.getElementType().apply(this);
out.printf("[]");
break;
} else {
out.printf("option<");
t.firstarg().accept(this);
out.printf(">");
}

break;
case Map:
if (nrofargs!=2){
Expand Down Expand Up @@ -1101,28 +1279,18 @@ public void visit(ParallelAtomic pa){
public void visit(ParallelBlock pb){

int j = 0;
out.printf(pb.label());
out.printf("(");
for (DeclarationStatement iter : pb.itersJava()) {
out.print("(");
setExpr();
if (j > 0) out.printf(",");
in_expr = true;
iter.accept(this);
in_expr = false;

j++;
ASTNode expr = iter.initJava();
nextExpr();
iter.getType().accept(this);
out.printf(" %s", iter.name());
if (expr!=null){
out.printf(" = ");
nextExpr();
if(expr instanceof OperatorExpression && ((OperatorExpression) expr).operator() == StandardOperator.RangeSeq) {
((OperatorExpression) expr).arg(0).accept(this);
out.print(" .. ");
((OperatorExpression) expr).arg(1).accept(this);
} else {
Fail("Unexpected DeclarationStatement in iters of ParallelBlock");
}
}
out.println(")");
}
out.printf(")");
out.newline();

if (pb.depslength() > 0){
out.printf(";");
Expand All @@ -1143,34 +1311,36 @@ public void visit(ParallelBarrier pb){
if (pb.contract() == null) {
Fail("parallel barrier with null contract!");
} else {
out.printf("barrier(%s;%s){", pb.label(), pb.invs());
out.println("");
out.incrIndent();
visit(pb.contract());
out.decrIndent();
if (pb.body() == null ) {
out.println("{ }");
out.printf("barrier(%s)", pb.label());
if (pb.body() == null) {
out.println(" { ");
out.incrIndent();
visit(pb.contract());
out.decrIndent();
out.println("}");
} else {
out.newline();
out.incrIndent();
visit(pb.contract());
out.decrIndent();
pb.body().accept(this);
}

}
}
@Override
public void visit(ParallelInvariant pb) {
out.printf("invariants %s (", pb.label());
out.printf("invariant %s (", pb.label());
nextExpr();
pb.inv().accept(this);
out.printf(")");
pb.block().accept(this);
}
@Override
public void visit(ParallelRegion region){
out.print("par ");

if (region.contract() != null) {
out.println("par");
region.contract().accept(this);
} else {
out.println("par");
}
for (Iterator<ParallelBlock> it = region.blocksJava().iterator(); it.hasNext();) {
out.incrIndent();
Expand Down
Loading

0 comments on commit 7158a75

Please sign in to comment.