diff --git a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala index 51d239e61078..6a7e5f0f1337 100644 --- a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala +++ b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala @@ -130,7 +130,7 @@ object LocalBackend extends Backend { Validate(ir) val queryID = Backend.nextID() log.info(s"starting execution of query $queryID of initial size ${IRSize(ir)}") - ctx.irMetadata.semhash = SemanticHash(ctx)(ir) + ctx.irMetadata.semhash = SemanticHash(ctx, ir) val res = _jvmLowerAndExecute(ctx, ir) log.info(s"finished execution of query $queryID") res diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala index c102f2ab63ef..5fc9e7f4e096 100644 --- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala +++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala @@ -247,7 +247,7 @@ class ServiceBackend( Validate(ir) val queryID = Backend.nextID() log.info(s"starting execution of query $queryID of initial size ${IRSize(ir)}") - ctx.irMetadata.semhash = SemanticHash(ctx)(ir) + ctx.irMetadata.semhash = SemanticHash(ctx, ir) val res = _jvmLowerAndExecute(ctx, ir) log.info(s"finished execution of query $queryID") res diff --git a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala index 560913b2fae8..da36fcee6703 100644 --- a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala +++ b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala @@ -424,7 +424,7 @@ class SparkBackend(val sc: SparkContext) extends Backend { ctx.time { TypeCheck(ctx, ir) Validate(ir) - ctx.irMetadata.semhash = SemanticHash(ctx)(ir) + ctx.irMetadata.semhash = SemanticHash(ctx, ir) try { val lowerTable = ctx.flags.get("lower") != null val lowerBM = ctx.flags.get("lower_bm") != null diff --git a/hail/src/main/scala/is/hail/expr/ir/BaseIR.scala b/hail/src/main/scala/is/hail/expr/ir/BaseIR.scala index 2778bcc0abf2..f2fc955e22df 100644 --- a/hail/src/main/scala/is/hail/expr/ir/BaseIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/BaseIR.scala @@ -33,11 +33,11 @@ abstract class BaseIR { // New sentinel values can be obtained by `nextFlag` on `IRMetadata`. 
var mark: Int = 0 - def isAlphaEquiv(ctx: ExecuteContext, other: BaseIR): Boolean = - /* FIXME: rewrite to not rebuild the irs, by maintaining an env mapping left names to right - * names */ - NormalizeNames(ctx, this, allowFreeVariables = true) == - NormalizeNames(ctx, other, allowFreeVariables = true) + def isAlphaEquiv(ctx: ExecuteContext, other: BaseIR): Boolean = { + // FIXME: rewrite to not rebuild the irs by maintaining an env mapping left to right names + val normalize: (ExecuteContext, BaseIR) => BaseIR = NormalizeNames(allowFreeVariables = true) + normalize(ctx, this) == normalize(ctx, other) + } def mapChildrenWithIndex(f: (BaseIR, Int) => BaseIR): BaseIR = { val newChildren = childrenSeq.view.zipWithIndex.map(f.tupled).toArray diff --git a/hail/src/main/scala/is/hail/expr/ir/Compile.scala b/hail/src/main/scala/is/hail/expr/ir/Compile.scala index dbe02dd5e307..ecb4e0d06d73 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Compile.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Compile.scala @@ -93,7 +93,7 @@ object compile { N: sourcecode.Name, ): (Option[SingleCodeType], (HailClassLoader, FS, HailTaskContext, Region) => F with Mixin) = ctx.time { - val normalizedBody = NormalizeNames(ctx, body, allowFreeVariables = true) + val normalizedBody = NormalizeNames(allowFreeVariables = true)(ctx, body) ctx.CodeCache.getOrElseUpdate( CodeCacheKey(aggSigs.getOrElse(Array.empty).toFastSeq, params, normalizedBody), { var ir = Subst( diff --git a/hail/src/main/scala/is/hail/expr/ir/ExtractIntervalFilters.scala b/hail/src/main/scala/is/hail/expr/ir/ExtractIntervalFilters.scala index 446c11868d84..d21c6e8f826e 100644 --- a/hail/src/main/scala/is/hail/expr/ir/ExtractIntervalFilters.scala +++ b/hail/src/main/scala/is/hail/expr/ir/ExtractIntervalFilters.scala @@ -24,52 +24,53 @@ object ExtractIntervalFilters { val MAX_LITERAL_SIZE = 4096 - def apply(ctx: ExecuteContext, ir0: BaseIR): BaseIR = { - MapIR.mapBaseIR( - ir0, - (ir: BaseIR) => { - ( - ir match { - case TableFilter(child, pred) => - extractPartitionFilters( - ctx, - pred, - Ref(TableIR.rowName, child.typ.rowType), - child.typ.key, - ) - .map { case (newCond, intervals) => + def apply(ctx: ExecuteContext, ir0: BaseIR): BaseIR = + ctx.time { + MapIR.mapBaseIR( + ir0, + ir => + ( + ir match { + case TableFilter(child, pred) => + extractPartitionFilters( + ctx, + pred, + Ref(TableIR.rowName, child.typ.rowType), + child.typ.key, + ) + .map { case (newCond, intervals) => + log.info( + s"generated TableFilterIntervals node with ${intervals.length} intervals:\n " + + s"Intervals: ${intervals.mkString(", ")}\n " + + s"Predicate: ${Pretty(ctx, pred)}\n " + s"Post: ${Pretty(ctx, newCond)}" + ) + TableFilter(TableFilterIntervals(child, intervals, keep = true), newCond) + } + case MatrixFilterRows(child, pred) => + extractPartitionFilters( + ctx, + pred, + Ref(MatrixIR.rowName, child.typ.rowType), + child.typ.rowKey, + ).map { case (newCond, intervals) => log.info( - s"generated TableFilterIntervals node with ${intervals.length} intervals:\n " + + s"generated MatrixFilterIntervals node with ${intervals.length} intervals:\n " + s"Intervals: ${intervals.mkString(", ")}\n " + s"Predicate: ${Pretty(ctx, pred)}\n " + s"Post: ${Pretty(ctx, newCond)}" ) - TableFilter(TableFilterIntervals(child, intervals, keep = true), newCond) + MatrixFilterRows(MatrixFilterIntervals(child, intervals, keep = true), newCond) } - case MatrixFilterRows(child, pred) => extractPartitionFilters( - ctx, - pred, - Ref(MatrixIR.rowName, child.typ.rowType), - child.typ.rowKey, - 
).map { case (newCond, intervals) => - log.info( - s"generated MatrixFilterIntervals node with ${intervals.length} intervals:\n " + - s"Intervals: ${intervals.mkString(", ")}\n " + - s"Predicate: ${Pretty(ctx, pred)}\n " + s"Post: ${Pretty(ctx, newCond)}" - ) - MatrixFilterRows(MatrixFilterIntervals(child, intervals, keep = true), newCond) - } - case _ => None - } - ).getOrElse(ir) - }, - ) - } + case _ => None + } + ).getOrElse(ir), + ) + } def extractPartitionFilters(ctx: ExecuteContext, cond: IR, ref: Ref, key: IndexedSeq[String]) : Option[(IR, IndexedSeq[Interval])] = { if (key.isEmpty) None - else { + else ctx.time { val extract = new ExtractIntervalFilters(ctx, ref.typ.asInstanceOf[TStruct].typeAfterSelectNames(key)) val trueSet = extract.analyze(cond, ref.name) diff --git a/hail/src/main/scala/is/hail/expr/ir/FoldConstants.scala b/hail/src/main/scala/is/hail/expr/ir/FoldConstants.scala index 2c2cedd215e7..bf604726bfbc 100644 --- a/hail/src/main/scala/is/hail/expr/ir/FoldConstants.scala +++ b/hail/src/main/scala/is/hail/expr/ir/FoldConstants.scala @@ -6,7 +6,9 @@ import is.hail.utils.HailException object FoldConstants { def apply(ctx: ExecuteContext, ir: BaseIR): BaseIR = - ctx.r.pool.scopedRegion(region => ctx.local(r = region)(foldConstants(_, ir))) + ctx.time { + ctx.r.pool.scopedRegion(r => ctx.local(r = r)(foldConstants(_, ir))) + } private def foldConstants(ctx: ExecuteContext, ir: BaseIR): BaseIR = RewriteBottomUp( diff --git a/hail/src/main/scala/is/hail/expr/ir/ForwardLets.scala b/hail/src/main/scala/is/hail/expr/ir/ForwardLets.scala index e9f137c80c74..4de2609ce92f 100644 --- a/hail/src/main/scala/is/hail/expr/ir/ForwardLets.scala +++ b/hail/src/main/scala/is/hail/expr/ir/ForwardLets.scala @@ -2,20 +2,20 @@ package is.hail.expr.ir import is.hail.backend.ExecuteContext import is.hail.types.virtual.TVoid -import is.hail.utils.BoxedArrayBuilder +import is.hail.utils.{fatal, BoxedArrayBuilder} import scala.collection.Set +import scala.util.control.NonFatal object ForwardLets { - def apply[T <: BaseIR](ctx: ExecuteContext)(ir0: T): T = { - val ir1 = NormalizeNames(ctx, ir0, allowFreeVariables = true) - val UsesAndDefs(uses, defs, _) = ComputeUsesAndDefs(ir1, errorIfFreeVariables = false) - val nestingDepth = NestingDepth(ir1) - - def rewrite(ir: BaseIR, env: BindingEnv[IR]): BaseIR = { + def apply[T <: BaseIR](ctx: ExecuteContext, ir0: T): T = + ctx.time { + val ir1 = NormalizeNames(allowFreeVariables = true)(ctx, ir0) + val UsesAndDefs(uses, defs, _) = ComputeUsesAndDefs(ir1, errorIfFreeVariables = false) + val nestingDepth = NestingDepth(ctx, ir1) def shouldForward(value: IR, refs: Set[RefEquality[BaseRef]], base: Block, scope: Int) - : Boolean = { + : Boolean = IsPure(value) && ( value.isInstanceOf[Ref] || value.isInstanceOf[In] || @@ -27,45 +27,56 @@ object ForwardLets { !ContainsAgg(value)) && !ContainsAggIntermediate(value) ) - } - ir match { - case l: Block => - val keep = new BoxedArrayBuilder[Binding] - val refs = uses(l) - val newEnv = l.bindings.foldLeft(env) { - case (env, Binding(name, value, scope)) => - val rewriteValue = rewrite(value, env.promoteScope(scope)).asInstanceOf[IR] - if ( - rewriteValue.typ != TVoid - && shouldForward(rewriteValue, refs.filter(_.t.name == name), l, scope) - ) { - env.bindInScope(name, rewriteValue, scope) - } else { - keep += Binding(name, rewriteValue, scope) - env + def rewrite(ir: BaseIR, env: BindingEnv[IR]): BaseIR = + ir match { + case l: Block => + val keep = new BoxedArrayBuilder[Binding] + val refs = uses(l) + val newEnv = 
l.bindings.foldLeft(env) { + case (env, Binding(name, value, scope)) => + val rewriteValue = rewrite(value, env.promoteScope(scope)).asInstanceOf[IR] + if ( + rewriteValue.typ != TVoid + && shouldForward(rewriteValue, refs.filter(_.t.name == name), l, scope) + ) { + env.bindInScope(name, rewriteValue, scope) + } else { + keep += Binding(name, rewriteValue, scope) + env + } + } + + val newBody = rewrite(l.body, newEnv).asInstanceOf[IR] + if (keep.isEmpty) newBody + else Block(keep.result(), newBody) + + case x @ Ref(name, _) => + env.eval + .lookupOption(name) + .map { forwarded => + if (uses.lookup(defs.lookup(x)).count(_.t.name == name) > 1) forwarded.deepCopy() + else forwarded } - } + .getOrElse(x) + case _ => + ir.mapChildrenWithIndex((ir1, i) => + rewrite(ir1, env.extend(Bindings.get(ir, i).dropBindings)) + ) + } - val newBody = rewrite(l.body, newEnv).asInstanceOf[IR] - if (keep.isEmpty) newBody - else Block(keep.result(), newBody) + val ir = rewrite(ir1, BindingEnv(Env.empty, Some(Env.empty), Some(Env.empty))) - case x @ Ref(name, _) => - env.eval - .lookupOption(name) - .map { forwarded => - if (uses.lookup(defs.lookup(x)).count(_.t.name == name) > 1) forwarded.deepCopy() - else forwarded - } - .getOrElse(x) - case _ => - ir.mapChildrenWithIndex((ir1, i) => - rewrite(ir1, env.extend(Bindings.get(ir, i).dropBindings)) + try + TypeCheck(ctx, ir) + catch { + case NonFatal(e) => + fatal( + s"bad ir from ForwardLets, started as\n${Pretty(ctx, ir0, preserveNames = true)}", + e, ) } - } - rewrite(ir1, BindingEnv(Env.empty, Some(Env.empty), Some(Env.empty))).asInstanceOf[T] - } + ir.asInstanceOf[T] + } } diff --git a/hail/src/main/scala/is/hail/expr/ir/ForwardRelationalLets.scala b/hail/src/main/scala/is/hail/expr/ir/ForwardRelationalLets.scala index ae600985d8a1..0853b8519a35 100644 --- a/hail/src/main/scala/is/hail/expr/ir/ForwardRelationalLets.scala +++ b/hail/src/main/scala/is/hail/expr/ir/ForwardRelationalLets.scala @@ -1,81 +1,83 @@ package is.hail.expr.ir +import is.hail.backend.ExecuteContext +import is.hail.utils.fatal + import scala.collection.mutable +import scala.util.control.NonFatal object ForwardRelationalLets { - def apply(ir0: BaseIR): BaseIR = { - - val usages = mutable.HashMap.empty[Name, (Int, Int)] - - val nestingDepth = NestingDepth(ir0) - - def visit(ir1: BaseIR): Unit = { - ir1 match { - case RelationalLet(name, _, _) => - usages(name) = (0, 0) - case RelationalLetTable(name, _, _) => - usages(name) = (0, 0) - case RelationalLetMatrixTable(name, _, _) => - usages(name) = (0, 0) - case RelationalLetBlockMatrix(name, _, _) => - usages(name) = (0, 0) + def apply(ctx: ExecuteContext, ir0: BaseIR): BaseIR = + ctx.time { + val uses = mutable.HashMap.empty[Name, (Int, Int)] + val nestingDepth = NestingDepth(ctx, ir0) + IRTraversal.levelOrder(ir0).foreach { case x @ RelationalRef(name, _) => - val (n, nd) = usages(name) - usages(name) = (n + 1, math.max(nd, nestingDepth.lookupRef(x))) + val (n, nd) = uses.getOrElseUpdate(name, (0, 0)) + uses(name) = (n + 1, math.max(nd, nestingDepth.lookupRef(x))) case _ => } - ir1.children.foreach(visit) - } - visit(ir0) + def shouldForward(name: Name): Boolean = + uses.get(name).forall(t => t._1 < 2 && t._2 < 1) - def shouldForward(t: (Int, Int)): Boolean = t._1 < 2 && t._2 < 1 + // short circuit if possible + if (!uses.keys.exists(shouldForward)) ir0 + else { + val env = mutable.HashMap.empty[Name, IR] - // short circuit if possible - if (!usages.valuesIterator.exists(shouldForward)) - ir0 - else { - val m = mutable.HashMap.empty[Name, 
IR] + def rewrite(ir1: BaseIR): BaseIR = ir1 match { + case RelationalLet(name, value, body) => + if (shouldForward(name)) { + env(name) = rewrite(value).asInstanceOf[IR] + rewrite(body) + } else + RelationalLet(name, rewrite(value).asInstanceOf[IR], rewrite(body).asInstanceOf[IR]) + case RelationalLetTable(name, value, body) => + if (shouldForward(name)) { + env(name) = rewrite(value).asInstanceOf[IR] + rewrite(body) + } else RelationalLetTable( + name, + rewrite(value).asInstanceOf[IR], + rewrite(body).asInstanceOf[TableIR], + ) + case RelationalLetMatrixTable(name, value, body) => + if (shouldForward(name)) { + env(name) = rewrite(value).asInstanceOf[IR] + rewrite(body) + } else RelationalLetMatrixTable( + name, + rewrite(value).asInstanceOf[IR], + rewrite(body).asInstanceOf[MatrixIR], + ) + case RelationalLetBlockMatrix(name, value, body) => + if (shouldForward(name)) { + env(name) = rewrite(value).asInstanceOf[IR] + rewrite(body) + } else RelationalLetBlockMatrix( + name, + rewrite(value).asInstanceOf[IR], + rewrite(body).asInstanceOf[BlockMatrixIR], + ) + case x @ RelationalRef(name, _) => + env.getOrElse(name, x) + case _ => ir1.mapChildren(rewrite) + } - def recur(ir1: BaseIR): BaseIR = ir1 match { - case RelationalLet(name, value, body) => - if (shouldForward(usages(name))) { - m(name) = recur(value).asInstanceOf[IR] - recur(body) - } else RelationalLet(name, recur(value).asInstanceOf[IR], recur(body).asInstanceOf[IR]) - case RelationalLetTable(name, value, body) => - if (shouldForward(usages(name))) { - m(name) = recur(value).asInstanceOf[IR] - recur(body) - } else RelationalLetTable( - name, - recur(value).asInstanceOf[IR], - recur(body).asInstanceOf[TableIR], - ) - case RelationalLetMatrixTable(name, value, body) => - if (shouldForward(usages(name))) { - m(name) = recur(value).asInstanceOf[IR] - recur(body) - } else RelationalLetMatrixTable( - name, - recur(value).asInstanceOf[IR], - recur(body).asInstanceOf[MatrixIR], - ) - case RelationalLetBlockMatrix(name, value, body) => - if (shouldForward(usages(name))) { - m(name) = recur(value).asInstanceOf[IR] - recur(body) - } else RelationalLetBlockMatrix( - name, - recur(value).asInstanceOf[IR], - recur(body).asInstanceOf[BlockMatrixIR], - ) - case x @ RelationalRef(name, _) => - m.getOrElse(name, x) - case _ => ir1.mapChildren(recur) - } + val ir = rewrite(ir0) - recur(ir0) + try + TypeCheck(ctx, ir) + catch { + case NonFatal(e) => + fatal( + s"bad ir from ForwardRelationalLets, started as\n${Pretty(ctx, ir0, preserveNames = true)}", + e, + ) + } + + ir + } } - } } diff --git a/hail/src/main/scala/is/hail/expr/ir/NestingDepth.scala b/hail/src/main/scala/is/hail/expr/ir/NestingDepth.scala index 59a7b298b612..cc60a7a6c2e3 100644 --- a/hail/src/main/scala/is/hail/expr/ir/NestingDepth.scala +++ b/hail/src/main/scala/is/hail/expr/ir/NestingDepth.scala @@ -1,5 +1,7 @@ package is.hail.expr.ir +import is.hail.backend.ExecuteContext + case class ScopedDepth(eval: Int, agg: Int, scan: Int) { def incrementEval: ScopedDepth = ScopedDepth(eval + 1, agg, scan) @@ -28,136 +30,137 @@ final class NestingDepth(private val memo: Memo[ScopedDepth]) { } object NestingDepth { - def apply(ir0: BaseIR): NestingDepth = { - - val memo = Memo.empty[ScopedDepth] - - def computeChildren(ir: BaseIR): Unit = { - ir.children - .zipWithIndex - .foreach { - case (child: IR, _) => computeIR(child, ScopedDepth(0, 0, 0)) - case (tir: TableIR, _) => computeTable(tir) - case (mir: MatrixIR, _) => computeMatrix(mir) - case (bmir: BlockMatrixIR, _) => 
computeBlockMatrix(bmir) - } - } + def apply(ctx: ExecuteContext, ir0: BaseIR): NestingDepth = + ctx.time { + + val memo = Memo.empty[ScopedDepth] + + def computeChildren(ir: BaseIR): Unit = { + ir.children + .zipWithIndex + .foreach { + case (child: IR, _) => computeIR(child, ScopedDepth(0, 0, 0)) + case (tir: TableIR, _) => computeTable(tir) + case (mir: MatrixIR, _) => computeMatrix(mir) + case (bmir: BlockMatrixIR, _) => computeBlockMatrix(bmir) + } + } - def computeTable(tir: TableIR): Unit = computeChildren(tir) + def computeTable(tir: TableIR): Unit = computeChildren(tir) - def computeMatrix(mir: MatrixIR): Unit = computeChildren(mir) + def computeMatrix(mir: MatrixIR): Unit = computeChildren(mir) - def computeBlockMatrix(bmir: BlockMatrixIR): Unit = computeChildren(bmir) + def computeBlockMatrix(bmir: BlockMatrixIR): Unit = computeChildren(bmir) - def computeIR(ir: IR, depth: ScopedDepth): Unit = { - ir match { - case _: Block | _: BaseRef => - memo.bind(ir, depth) - case _ => + def computeIR(ir: IR, depth: ScopedDepth): Unit = { + ir match { + case _: Block | _: BaseRef => + memo.bind(ir, depth) + case _ => + } + ir match { + case StreamMap(a, _, body) => + computeIR(a, depth) + computeIR(body, depth.incrementEval) + case StreamAgg(a, _, body) => + computeIR(a, depth) + computeIR(body, ScopedDepth(depth.eval, depth.eval + 1, depth.scan)) + case StreamAggScan(a, _, body) => + computeIR(a, depth) + computeIR(body, ScopedDepth(depth.eval, depth.agg, depth.eval + 1)) + case StreamZip(as, _, body, _, _) => + as.foreach(computeIR(_, depth)) + computeIR(body, depth.incrementEval) + case StreamZipJoin(as, _, _, _, joinF) => + as.foreach(computeIR(_, depth)) + computeIR(joinF, depth.incrementEval) + case StreamZipJoinProducers(contexts, _, makeProducer, _, _, _, joinF) => + computeIR(contexts, depth) + computeIR(makeProducer, depth.incrementEval) + computeIR(joinF, depth.incrementEval) + case StreamFor(a, _, body) => + computeIR(a, depth) + computeIR(body, depth.incrementEval) + case StreamFlatMap(a, _, body) => + computeIR(a, depth) + computeIR(body, depth.incrementEval) + case StreamFilter(a, _, cond) => + computeIR(a, depth) + computeIR(cond, depth.incrementEval) + case StreamTakeWhile(a, _, cond) => + computeIR(a, depth) + computeIR(cond, depth.incrementEval) + case StreamDropWhile(a, _, cond) => + computeIR(a, depth) + computeIR(cond, depth.incrementEval) + case StreamFold(a, zero, _, _, body) => + computeIR(a, depth) + computeIR(zero, depth) + computeIR(body, depth.incrementEval) + case StreamFold2(a, accum, _, seq, result) => + computeIR(a, depth) + accum.foreach { case (_, value) => computeIR(value, depth) } + seq.foreach(computeIR(_, depth.incrementEval)) + computeIR(result, depth) + case StreamScan(a, zero, _, _, body) => + computeIR(a, depth) + computeIR(zero, depth) + computeIR(body, depth.incrementEval) + case StreamJoinRightDistinct(left, right, _, _, _, _, joinF, _) => + computeIR(left, depth) + computeIR(right, depth) + computeIR(joinF, depth.incrementEval) + case StreamLeftIntervalJoin(left, right, _, _, _, _, body) => + computeIR(left, depth) + computeIR(right, depth) + computeIR(body, depth.incrementEval) + case TailLoop(_, params, _, body) => + params.foreach { case (_, p) => computeIR(p, depth) } + computeIR(body, depth.incrementEval) + case NDArrayMap(nd, _, body) => + computeIR(nd, depth) + computeIR(body, depth.incrementEval) + case NDArrayMap2(nd1, nd2, _, _, body, _) => + computeIR(nd1, depth) + computeIR(nd2, depth) + computeIR(body, depth.incrementEval) + case 
AggExplode(array, _, aggBody, isScan) => + computeIR(array, depth.promoteScanOrAgg(isScan)) + computeIR(aggBody, depth.incrementScanOrAgg(isScan)) + case AggArrayPerElement(a, _, _, aggBody, knownLength, isScan) => + computeIR(a, depth.promoteScanOrAgg(isScan)) + computeIR(aggBody, depth.incrementScanOrAgg(isScan)) + knownLength.foreach(computeIR(_, depth)) + case TableAggregate(child, query) => + computeTable(child) + computeIR(query, ScopedDepth(0, 0, 0)) + case MatrixAggregate(child, query) => + computeMatrix(child) + computeIR(query, ScopedDepth(0, 0, 0)) + case _ => + ir.children + .zipWithIndex + .foreach { + case (child: IR, i) => if (UsesAggEnv(ir, i)) + computeIR(child, depth.promoteAgg) + else if (UsesScanEnv(ir, i)) + computeIR(child, depth.promoteScan) + else + computeIR(child, depth) + case (child: TableIR, _) => computeTable(child) + case (child: MatrixIR, _) => computeMatrix(child) + case (child: BlockMatrixIR, _) => computeBlockMatrix(child) + } + } } - ir match { - case StreamMap(a, _, body) => - computeIR(a, depth) - computeIR(body, depth.incrementEval) - case StreamAgg(a, _, body) => - computeIR(a, depth) - computeIR(body, ScopedDepth(depth.eval, depth.eval + 1, depth.scan)) - case StreamAggScan(a, _, body) => - computeIR(a, depth) - computeIR(body, ScopedDepth(depth.eval, depth.agg, depth.eval + 1)) - case StreamZip(as, _, body, _, _) => - as.foreach(computeIR(_, depth)) - computeIR(body, depth.incrementEval) - case StreamZipJoin(as, _, _, _, joinF) => - as.foreach(computeIR(_, depth)) - computeIR(joinF, depth.incrementEval) - case StreamZipJoinProducers(contexts, _, makeProducer, _, _, _, joinF) => - computeIR(contexts, depth) - computeIR(makeProducer, depth.incrementEval) - computeIR(joinF, depth.incrementEval) - case StreamFor(a, _, body) => - computeIR(a, depth) - computeIR(body, depth.incrementEval) - case StreamFlatMap(a, _, body) => - computeIR(a, depth) - computeIR(body, depth.incrementEval) - case StreamFilter(a, _, cond) => - computeIR(a, depth) - computeIR(cond, depth.incrementEval) - case StreamTakeWhile(a, _, cond) => - computeIR(a, depth) - computeIR(cond, depth.incrementEval) - case StreamDropWhile(a, _, cond) => - computeIR(a, depth) - computeIR(cond, depth.incrementEval) - case StreamFold(a, zero, _, _, body) => - computeIR(a, depth) - computeIR(zero, depth) - computeIR(body, depth.incrementEval) - case StreamFold2(a, accum, _, seq, result) => - computeIR(a, depth) - accum.foreach { case (_, value) => computeIR(value, depth) } - seq.foreach(computeIR(_, depth.incrementEval)) - computeIR(result, depth) - case StreamScan(a, zero, _, _, body) => - computeIR(a, depth) - computeIR(zero, depth) - computeIR(body, depth.incrementEval) - case StreamJoinRightDistinct(left, right, _, _, _, _, joinF, _) => - computeIR(left, depth) - computeIR(right, depth) - computeIR(joinF, depth.incrementEval) - case StreamLeftIntervalJoin(left, right, _, _, _, _, body) => - computeIR(left, depth) - computeIR(right, depth) - computeIR(body, depth.incrementEval) - case TailLoop(_, params, _, body) => - params.foreach { case (_, p) => computeIR(p, depth) } - computeIR(body, depth.incrementEval) - case NDArrayMap(nd, _, body) => - computeIR(nd, depth) - computeIR(body, depth.incrementEval) - case NDArrayMap2(nd1, nd2, _, _, body, _) => - computeIR(nd1, depth) - computeIR(nd2, depth) - computeIR(body, depth.incrementEval) - case AggExplode(array, _, aggBody, isScan) => - computeIR(array, depth.promoteScanOrAgg(isScan)) - computeIR(aggBody, depth.incrementScanOrAgg(isScan)) - case 
AggArrayPerElement(a, _, _, aggBody, knownLength, isScan) => - computeIR(a, depth.promoteScanOrAgg(isScan)) - computeIR(aggBody, depth.incrementScanOrAgg(isScan)) - knownLength.foreach(computeIR(_, depth)) - case TableAggregate(child, query) => - computeTable(child) - computeIR(query, ScopedDepth(0, 0, 0)) - case MatrixAggregate(child, query) => - computeMatrix(child) - computeIR(query, ScopedDepth(0, 0, 0)) - case _ => - ir.children - .zipWithIndex - .foreach { - case (child: IR, i) => if (UsesAggEnv(ir, i)) - computeIR(child, depth.promoteAgg) - else if (UsesScanEnv(ir, i)) - computeIR(child, depth.promoteScan) - else - computeIR(child, depth) - case (child: TableIR, _) => computeTable(child) - case (child: MatrixIR, _) => computeMatrix(child) - case (child: BlockMatrixIR, _) => computeBlockMatrix(child) - } + + ir0 match { + case ir: IR => computeIR(ir, ScopedDepth(0, 0, 0)) + case tir: TableIR => computeTable(tir) + case mir: MatrixIR => computeMatrix(mir) + case bmir: BlockMatrixIR => computeBlockMatrix(bmir) } - } - ir0 match { - case ir: IR => computeIR(ir, ScopedDepth(0, 0, 0)) - case tir: TableIR => computeTable(tir) - case mir: MatrixIR => computeMatrix(mir) - case bmir: BlockMatrixIR => computeBlockMatrix(bmir) + new NestingDepth(memo) } - - new NestingDepth(memo) - } } diff --git a/hail/src/main/scala/is/hail/expr/ir/NormalizeNames.scala b/hail/src/main/scala/is/hail/expr/ir/NormalizeNames.scala index 93e456b88913..bac35f18834f 100644 --- a/hail/src/main/scala/is/hail/expr/ir/NormalizeNames.scala +++ b/hail/src/main/scala/is/hail/expr/ir/NormalizeNames.scala @@ -9,7 +9,7 @@ import scala.annotation.tailrec import scala.collection.mutable object NormalizeNames { - def apply[T <: BaseIR](ctx: ExecuteContext, ir: T, allowFreeVariables: Boolean = false): T = + def apply[T <: BaseIR](allowFreeVariables: Boolean = false)(ctx: ExecuteContext, ir: T): T = ctx.time { val freeVariables: Set[Name] = ir match { case ir: IR => diff --git a/hail/src/main/scala/is/hail/expr/ir/Optimize.scala b/hail/src/main/scala/is/hail/expr/ir/Optimize.scala index 8da55591f42f..bf75af0db6ab 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Optimize.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Optimize.scala @@ -2,56 +2,49 @@ package is.hail.expr.ir import is.hail.HailContext import is.hail.backend.ExecuteContext -import is.hail.utils.fatal import scala.util.control.Breaks.{break, breakable} object Optimize { - def apply[T <: BaseIR](ctx: ExecuteContext, ir0: T): T = { - var ir = ir0 - - def runOpt(f: BaseIR => BaseIR, iter: Int, optContext: String): Unit = - ir = ctx.timer.time(s"$optContext, iteration: $iter")(f(ir).asInstanceOf[T]) - - breakable { - for (iter <- 0 until HailContext.get.optimizerIterations) { - val last = ir - runOpt(FoldConstants(ctx, _), iter, "FoldConstants") - runOpt(ExtractIntervalFilters(ctx, _), iter, "ExtractIntervalFilters") - runOpt( - NormalizeNames(ctx, _, allowFreeVariables = true), - iter, - "NormalizeNames", - ) - runOpt(Simplify(ctx, _), iter, "Simplify") - val ircopy = ir.deepCopy() - runOpt(ForwardLets(ctx), iter, "ForwardLets") - try + + private[this] val optimizations: Array[(ExecuteContext, BaseIR) => BaseIR] = + Array( + FoldConstants.apply, + ExtractIntervalFilters.apply, + NormalizeNames(allowFreeVariables = true), + Simplify.apply, + ForwardLets.apply, + ForwardRelationalLets.apply, + PruneDeadFields.apply, + ) + + def apply[T <: BaseIR](ctx: ExecuteContext, ir0: T): T = + ctx.time { + + var ir: BaseIR = ir0 + + breakable { + for (iter <- 0 until 
HailContext.get.optimizerIterations) { + val last = ir + + for (f <- optimizations) + ir = f(ctx, ir) + TypeCheck(ctx, ir) - catch { - case e: Exception => - fatal( - s"bad ir from forward lets, started as\n${Pretty(ctx, ircopy, preserveNames = true)}", - e, + + if (ir.typ != last.typ) + throw new RuntimeException( + s"Optimize[iteration=$iter] changed type!" + + s"\n before: ${last.typ.parsableString()}" + + s"\n after: ${ir.typ.parsableString()}" + + s"\n Before IR:\n ----------\n${Pretty(ctx, last)}" + + s"\n After IR:\n ---------\n${Pretty(ctx, ir)}" ) + + if (ir == last) break } - runOpt(ForwardRelationalLets(_), iter, "ForwardRelationalLets") - TypeCheck(ctx, ir) - runOpt(PruneDeadFields(ctx, _), iter, "PruneDeadFields") - - if (ir.typ != last.typ) - throw new RuntimeException( - s"Optimize[iteration=$iter] changed type!" + - s"\n before: ${last.typ.parsableString()}" + - s"\n after: ${ir.typ.parsableString()}" + - s"\n Before IR:\n ----------\n${Pretty(ctx, last)}" + - s"\n After IR:\n ---------\n${Pretty(ctx, ir)}" - ) - - if (ir == last) break } - } - ir - } + ir.asInstanceOf[T] + } } diff --git a/hail/src/main/scala/is/hail/expr/ir/PruneDeadFields.scala b/hail/src/main/scala/is/hail/expr/ir/PruneDeadFields.scala index 319717cc689f..b48eec2817c7 100644 --- a/hail/src/main/scala/is/hail/expr/ir/PruneDeadFields.scala +++ b/hail/src/main/scala/is/hail/expr/ir/PruneDeadFields.scala @@ -112,34 +112,35 @@ object PruneDeadFields { } } - def apply(ctx: ExecuteContext, ir: BaseIR): BaseIR = { - try { - val irCopy = ir.deepCopy() - val ms = ComputeMutableState(Memo.empty[BaseType], mutable.HashMap.empty) - irCopy match { - case mir: MatrixIR => - memoizeMatrixIR(ctx, mir, mir.typ, ms) - rebuild(ctx, mir, ms.rebuildState) - case tir: TableIR => - memoizeTableIR(ctx, tir, tir.typ, ms) - rebuild(ctx, tir, ms.rebuildState) - case bmir: BlockMatrixIR => - memoizeBlockMatrixIR(ctx, bmir, bmir.typ, ms) - rebuild(ctx, bmir, ms.rebuildState) - case vir: IR => - memoizeValueIR(ctx, vir, vir.typ, ms) - rebuildIR( - ctx, - vir, - BindingEnv(Env.empty, Some(Env.empty), Some(Env.empty)), - ms.rebuildState, - ) + def apply(ctx: ExecuteContext, ir: BaseIR): BaseIR = + ctx.time { + try { + val irCopy = ir.deepCopy() + val ms = ComputeMutableState(Memo.empty[BaseType], mutable.HashMap.empty) + irCopy match { + case mir: MatrixIR => + memoizeMatrixIR(ctx, mir, mir.typ, ms) + rebuild(ctx, mir, ms.rebuildState) + case tir: TableIR => + memoizeTableIR(ctx, tir, tir.typ, ms) + rebuild(ctx, tir, ms.rebuildState) + case bmir: BlockMatrixIR => + memoizeBlockMatrixIR(ctx, bmir, bmir.typ, ms) + rebuild(ctx, bmir, ms.rebuildState) + case vir: IR => + memoizeValueIR(ctx, vir, vir.typ, ms) + rebuildIR( + ctx, + vir, + BindingEnv(Env.empty, Some(Env.empty), Some(Env.empty)), + ms.rebuildState, + ) + } + } catch { + case e: Throwable => + fatal(s"error trying to rebuild IR:\n${Pretty(ctx, ir, allowUnboundRefs = true)}", e) } - } catch { - case e: Throwable => - fatal(s"error trying to rebuild IR:\n${Pretty(ctx, ir, allowUnboundRefs = true)}", e) } - } def selectKey(t: TStruct, k: IndexedSeq[String]): TStruct = t.filterSet(k.toSet)._1 diff --git a/hail/src/main/scala/is/hail/expr/ir/Simplify.scala b/hail/src/main/scala/is/hail/expr/ir/Simplify.scala index 9d46b76be422..818638490b22 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Simplify.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Simplify.scala @@ -12,63 +12,52 @@ import scala.collection.mutable object Simplify { /** Transform 'ir' using simplification rules until 
none apply. */ - def apply(ctx: ExecuteContext, ir: BaseIR): BaseIR = ir match { - case ir: IR => simplifyValue(ctx)(ir) - case tir: TableIR => simplifyTable(ctx)(tir) - case mir: MatrixIR => simplifyMatrix(ctx)(mir) - case bmir: BlockMatrixIR => simplifyBlockMatrix(ctx)(bmir) - } + def apply(ctx: ExecuteContext, ir: BaseIR): BaseIR = + ctx.time(recur(ctx, ir)) + + private[this] def recur(ctx: ExecuteContext, ir: BaseIR): BaseIR = + ir match { + case ir: IR => simplifyValue(ctx, ir) + case tir: TableIR => simplifyTable(ctx, tir) + case mir: MatrixIR => simplifyMatrix(ctx, mir) + case bmir: BlockMatrixIR => simplifyBlockMatrix(ctx, bmir) + } private[this] def visitNode[T <: BaseIR]( - visitChildren: BaseIR => BaseIR, - transform: T => Option[T], - post: => (T => T), + transform: (ExecuteContext, T) => Option[T], + post: => (ExecuteContext, T) => T, )( - t: T + ctx: ExecuteContext, + t: T, ): T = { - val t1 = t.mapChildren(visitChildren).asInstanceOf[T] - transform(t1).map(post).getOrElse(t1) + val t1 = t.mapChildren(recur(ctx, _)).asInstanceOf[T] + transform(ctx, t1).map(post(ctx, _)).getOrElse(t1) } - private[this] def simplifyValue(ctx: ExecuteContext): IR => IR = - visitNode( - Simplify(ctx, _), - rewriteValueNode(ctx), - simplifyValue(ctx), - ) - - private[this] def simplifyTable(ctx: ExecuteContext)(tir: TableIR): TableIR = - visitNode( - Simplify(ctx, _), - rewriteTableNode(ctx), - simplifyTable(ctx), - )(tir) - - private[this] def simplifyMatrix(ctx: ExecuteContext)(mir: MatrixIR): MatrixIR = - visitNode( - Simplify(ctx, _), - rewriteMatrixNode(), - simplifyMatrix(ctx), - )(mir) - - private[this] def simplifyBlockMatrix(ctx: ExecuteContext)(bmir: BlockMatrixIR): BlockMatrixIR = - visitNode( - Simplify(ctx, _), - rewriteBlockMatrixNode, - simplifyBlockMatrix(ctx), - )(bmir) - - private[this] def rewriteValueNode(ctx: ExecuteContext)(ir: IR): Option[IR] = - valueRules(ctx).lift(ir).orElse(numericRules(ir)) - - private[this] def rewriteTableNode(ctx: ExecuteContext)(tir: TableIR): Option[TableIR] = + private[this] val simplifyValue: (ExecuteContext, IR) => IR = + visitNode(rewriteValueNode, simplifyValue) + + private[this] val simplifyTable: (ExecuteContext, TableIR) => TableIR = + visitNode(rewriteTableNode, simplifyTable) + + private[this] val simplifyMatrix: (ExecuteContext, MatrixIR) => MatrixIR = + visitNode(rewriteMatrixNode, simplifyMatrix) + + private[this] val simplifyBlockMatrix: (ExecuteContext, BlockMatrixIR) => BlockMatrixIR = + visitNode(rewriteBlockMatrixNode, simplifyBlockMatrix) + + private[this] def rewriteValueNode(_ctx: ExecuteContext, ir: IR): Option[IR] = + valueRules.lift(ir).orElse(numericRules(ir)) + + private[this] def rewriteTableNode(ctx: ExecuteContext, tir: TableIR): Option[TableIR] = tableRules(ctx).lift(tir) - private[this] def rewriteMatrixNode()(mir: MatrixIR): Option[MatrixIR] = - matrixRules().lift(mir) + private[this] def rewriteMatrixNode(_ctx: ExecuteContext, mir: MatrixIR): Option[MatrixIR] = + matrixRules.lift(mir) - private[this] def rewriteBlockMatrixNode: BlockMatrixIR => Option[BlockMatrixIR] = - blockMatrixRules.lift + private[this] def rewriteBlockMatrixNode(_ctx: ExecuteContext, bmir: BlockMatrixIR) : Option[BlockMatrixIR] = + blockMatrixRules.lift(bmir) /** Returns true if any strict child of 'x' is NA. A child is strict if 'x' evaluates to missing * whenever the child does.
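Reviewer note on the hunk above: the by-name `post` parameter is what lets each `simplify*` val pass itself as its own post-step before that val has finished initializing, and `visitNode` keeps rewriting until no rule fires. It is also why `transform` must be applied to `t1` (the node whose children are already simplified), not the original `t`. A minimal self-contained sketch of the same fixed-point pattern, using a toy `Expr` type rather than Hail's IR:

object FixedPointSketch {
  sealed trait Expr
  final case class Lit(n: Int) extends Expr
  final case class Add(l: Expr, r: Expr) extends Expr

  // Simplify children bottom-up, try one rewrite rule on the result, and, if a
  // rule fired, re-run the whole pass (the by-name `post`) to reach a fixed point.
  def visitNode(transform: Expr => Option[Expr], post: => Expr => Expr)(e: Expr): Expr = {
    val e1 = e match {
      case Add(l, r) => Add(simplify(l), simplify(r))
      case leaf => leaf
    }
    // The rule must see e1 (children already simplified), not the original e.
    transform(e1).map(post).getOrElse(e1)
  }

  // Passing `simplify` as its own `post` is legal despite the forward reference:
  // a by-name argument is not evaluated until a rule fires, by which point the
  // val is initialized -- the same trick Simplify plays with simplifyValue.
  val simplify: Expr => Expr =
    visitNode(
      {
        case Add(Lit(a), Lit(b)) => Some(Lit(a + b))
        case _ => None
      },
      simplify,
    )
}

For example, FixedPointSketch.simplify(Add(Lit(1), Add(Lit(2), Lit(3)))) folds to Lit(6) after two rule firings.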
@@ -205,7 +194,7 @@ object Simplify { ).reduce((f, g) => ir => f(ir).orElse(g(ir))) } - private[this] def valueRules(ctx: ExecuteContext): PartialFunction[IR, IR] = { + private[this] def valueRules: PartialFunction[IR, IR] = { // propagate NA case x: IR if hasMissingStrictChild(x) => NA(x.typ) @@ -1166,7 +1155,7 @@ object Simplify { ) } - private[this] def matrixRules(): PartialFunction[MatrixIR, MatrixIR] = { + private[this] def matrixRules: PartialFunction[MatrixIR, MatrixIR] = { case MatrixMapRows(child, Ref(n, _)) if n == MatrixIR.rowName => child case MatrixKeyRowsBy(MatrixKeyRowsBy(child, _, _), keys, false) => diff --git a/hail/src/main/scala/is/hail/expr/ir/analyses/SemanticHash.scala b/hail/src/main/scala/is/hail/expr/ir/analyses/SemanticHash.scala index 2d6679d5d831..4db6794b4a2b 100644 --- a/hail/src/main/scala/is/hail/expr/ir/analyses/SemanticHash.scala +++ b/hail/src/main/scala/is/hail/expr/ir/analyses/SemanticHash.scala @@ -24,12 +24,11 @@ case object SemanticHash extends Logging { def extend(x: Type, bytes: Array[Byte]): Type = MurmurHash3.hash32x86(bytes, 0, bytes.length, x) - def apply(ctx: ExecuteContext)(root: BaseIR): Option[Type] = + def apply(ctx: ExecuteContext, root: BaseIR): Option[Type] = ctx.time { - // Running the algorithm on the name-normalised IR // removes sensitivity to compiler-generated names - val nameNormalizedIR = NormalizeNames(ctx, root, allowFreeVariables = true) + val nameNormalizedIR = NormalizeNames(allowFreeVariables = true)(ctx, root) def go: Option[Int] = { var hash: Type = diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/EvalRelationalLets.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/EvalRelationalLets.scala index 86d7f8b5e8af..8a85d1b77784 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/EvalRelationalLets.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/EvalRelationalLets.scala @@ -8,30 +8,32 @@ import is.hail.expr.ir.{ object EvalRelationalLets { // need to run the rest of lowerings to eval. 
- def apply(ir: BaseIR, ctx: ExecuteContext, passesBelow: LoweringPipeline): BaseIR = { - def execute(value: BaseIR, letsAbove: Map[Name, IR]): IR = { - val compilable = passesBelow.apply(ctx, lower(value, letsAbove)) - .asInstanceOf[IR] - CompileAndEvaluate.evalToIR(ctx, compilable, true) - } + def apply(ir: BaseIR, ctx: ExecuteContext, passesBelow: LoweringPipeline): BaseIR = + ctx.time { + def execute(value: BaseIR, letsAbove: Map[Name, IR]): IR = + ctx.time { + val compilable = passesBelow.apply(ctx, lower(value, letsAbove)) + .asInstanceOf[IR] + CompileAndEvaluate.evalToIR(ctx, compilable, true) + } - def lower(ir: BaseIR, letsAbove: Map[Name, IR]): BaseIR = { - ir match { - case RelationalLet(name, value, body) => - val valueLit = execute(value, letsAbove) - lower(body, letsAbove + (name -> valueLit)) - case RelationalLetTable(name, value, body) => - val valueLit = execute(value, letsAbove) - lower(body, letsAbove + (name -> valueLit)) - case RelationalLetMatrixTable(name, value, body) => - val valueLit = execute(value, letsAbove) - lower(body, letsAbove + (name -> valueLit)) - case RelationalRef(name, _) => letsAbove(name) - case x => - x.mapChildren(lower(_, letsAbove)) + def lower(ir: BaseIR, letsAbove: Map[Name, IR]): BaseIR = { + ir match { + case RelationalLet(name, value, body) => + val valueLit = execute(value, letsAbove) + lower(body, letsAbove + (name -> valueLit)) + case RelationalLetTable(name, value, body) => + val valueLit = execute(value, letsAbove) + lower(body, letsAbove + (name -> valueLit)) + case RelationalLetMatrixTable(name, value, body) => + val valueLit = execute(value, letsAbove) + lower(body, letsAbove + (name -> valueLit)) + case RelationalRef(name, _) => letsAbove(name) + case x => + x.mapChildren(lower(_, letsAbove)) + } } - } - lower(ir, Map()) - } + lower(ir, Map()) + } } diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/IRState.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/IRState.scala index f0f55751a4a6..3d4dcfbd2643 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/IRState.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/IRState.scala @@ -1,26 +1,27 @@ package is.hail.expr.ir.lowering +import is.hail.backend.ExecuteContext import is.hail.expr.ir.{ - BaseIR, RelationalLet, RelationalRef, TableKeyBy, TableKeyByAndAggregate, TableOrderBy, + BaseIR, IRTraversal, RelationalLet, RelationalRef, TableKeyBy, TableKeyByAndAggregate, + TableOrderBy, } -trait IRState { +abstract class IRState(implicit E: sourcecode.Enclosing) { + protected val rules: Array[Rule] - val rules: Array[Rule] - - final def allows(ir: BaseIR): Boolean = rules.forall(_.allows(ir)) - - final def verify(ir: BaseIR): Unit = { - if (!rules.forall(_.allows(ir))) - throw new RuntimeException(s"lowered state ${this.getClass.getCanonicalName} forbids IR $ir") - ir.children.foreach(verify) - } - - final def permits(ir: BaseIR): Boolean = rules.forall(_.allows(ir)) && ir.children.forall(permits) + final def verify(ctx: ExecuteContext, ir: BaseIR): Unit = + ctx.time { + IRTraversal.levelOrder(ir).foreach { ir => + if (!rules.forall(_.allows(ir))) + throw new RuntimeException( + s"lowered state ${this.getClass.getCanonicalName} forbids IR $ir" + ) + } + } - def +(other: IRState): IRState = { + def +(other: IRState)(implicit E: sourcecode.Enclosing): IRState = { val newRules = rules ++ other.rules - new IRState { + new IRState()(E) { val rules: Array[Rule] = newRules } } @@ -55,26 +56,18 @@ case object EmittableStreamIRs extends IRState { } case object 
NoRelationalLetsState extends IRState { - val rules: Array[Rule] = Array( - new Rule { - def allows(ir: BaseIR): Boolean = ir match { - case _: RelationalRef => false - case _: RelationalLet => false - case _ => true - } - } - ) + val rules: Array[Rule] = Array { + case _: RelationalRef => false + case _: RelationalLet => false + case _ => true + } } case object LoweredShuffles extends IRState { - val rules: Array[Rule] = Array( - new Rule { - def allows(ir: BaseIR): Boolean = ir match { - case t: TableKeyBy => t.definitelyDoesNotShuffle - case _: TableKeyByAndAggregate => false - case t: TableOrderBy => t.definitelyDoesNotShuffle - case _ => true - } - } - ) + val rules: Array[Rule] = Array { + case t: TableKeyBy => t.definitelyDoesNotShuffle + case _: TableKeyByAndAggregate => false + case t: TableOrderBy => t.definitelyDoesNotShuffle + case _ => true + } } diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerAndExecuteShuffles.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerAndExecuteShuffles.scala index 92593364df29..62fdcf33edb3 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerAndExecuteShuffles.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerAndExecuteShuffles.scala @@ -9,124 +9,69 @@ import is.hail.utils.FastSeq object LowerAndExecuteShuffles { - def apply(ir: BaseIR, ctx: ExecuteContext, passesBelow: LoweringPipeline): BaseIR = { - RewriteBottomUp( - ir, - { - case t @ TableKeyBy(child, key, _) if !t.definitelyDoesNotShuffle => - val r = Requiredness(child, ctx) - val reader = ctx.backend.lowerDistributedSort( - ctx, - child, - key.map(k => SortField(k, Ascending)), - r.lookup(child).asInstanceOf[RTable], - ) - Some(TableRead(t.typ, false, reader)) - - case t @ TableOrderBy(child, sortFields) if !t.definitelyDoesNotShuffle => - val r = Requiredness(child, ctx) - val reader = ctx.backend.lowerDistributedSort( - ctx, - child, - sortFields, - r.lookup(child).asInstanceOf[RTable], - ) - Some(TableRead(t.typ, false, reader)) - - case t @ TableKeyByAndAggregate(child, expr, newKey, nPartitions, bufferSize) => - val newKeyType = newKey.typ.asInstanceOf[TStruct] - - val req = Requiredness(t, ctx) - - val aggs = Extract(expr, req) - val postAggIR = aggs.postAggIR - val init = aggs.init - val seq = aggs.seqPerElt - val aggSigs = aggs.aggs - - var ts = child - - val origGlobalTyp = ts.typ.globalType - ts = TableKeyBy(child, IndexedSeq()) - ts = TableMapGlobals( - ts, - MakeStruct(FastSeq( - ("oldGlobals", Ref(TableIR.globalName, origGlobalTyp)), - ( - "__initState", - RunAgg( - init, - MakeTuple.ordered(aggSigs.indices.map { aIdx => - AggStateValue(aIdx, aggSigs(aIdx).state) - }), - aggSigs.map(_.state), + def apply(ir: BaseIR, ctx: ExecuteContext, passesBelow: LoweringPipeline): BaseIR = + ctx.time { + RewriteBottomUp( + ir, + { + case t @ TableKeyBy(child, key, _) if !t.definitelyDoesNotShuffle => + val r = Requiredness(child, ctx) + val reader = ctx.backend.lowerDistributedSort( + ctx, + child, + key.map(k => SortField(k, Ascending)), + r.lookup(child).asInstanceOf[RTable], + ) + Some(TableRead(t.typ, false, reader)) + + case t @ TableOrderBy(child, sortFields) if !t.definitelyDoesNotShuffle => + val r = Requiredness(child, ctx) + val reader = ctx.backend.lowerDistributedSort( + ctx, + child, + sortFields, + r.lookup(child).asInstanceOf[RTable], + ) + Some(TableRead(t.typ, false, reader)) + + case t @ TableKeyByAndAggregate(child, expr, newKey, nPartitions, bufferSize) => + val newKeyType = newKey.typ.asInstanceOf[TStruct] + + val req = 
Requiredness(t, ctx) + + val aggs = Extract(expr, req) + val postAggIR = aggs.postAggIR + val init = aggs.init + val seq = aggs.seqPerElt + val aggSigs = aggs.aggs + + var ts = child + + val origGlobalTyp = ts.typ.globalType + ts = TableKeyBy(child, IndexedSeq()) + ts = TableMapGlobals( + ts, + MakeStruct(FastSeq( + ("oldGlobals", Ref(TableIR.globalName, origGlobalTyp)), + ( + "__initState", + RunAgg( + init, + MakeTuple.ordered(aggSigs.indices.map { aIdx => + AggStateValue(aIdx, aggSigs(aIdx).state) + }), + aggSigs.map(_.state), + ), ), - ), - )), - ) + )), + ) - val partiallyAggregated = - mapPartitions(ts) { (insGlob, partStream) => - Let( - FastSeq(TableIR.globalName -> GetField(insGlob, "oldGlobals")), - StreamBufferedAggregate( - partStream, - bindIR(GetField(insGlob, "__initState")) { states => - Begin(aggSigs.indices.map { aIdx => - InitFromSerializedValue( - aIdx, - GetTupleElement(states, aIdx), - aggSigs(aIdx).state, - ) - }) - }, - newKey, - seq, - TableIR.rowName, - aggSigs, - bufferSize, - ), - ) - }.noSharing(ctx) - - val analyses = LoweringAnalyses(partiallyAggregated, ctx) - val preShuffleStage = ctx.backend.tableToTableStage(ctx, partiallyAggregated, analyses) - // annoying but no better alternative right now - val rt = analyses.requirednessAnalysis.lookup(partiallyAggregated).asInstanceOf[RTable] - val partiallyAggregatedReader = ctx.backend.lowerDistributedSort( - ctx, - preShuffleStage, - newKeyType.fieldNames.map(k => SortField(k, Ascending)), - rt, - nPartitions, - ) - - val takeVirtualSig = - TakeStateSig(VirtualTypeWithReq(newKeyType, rt.rowType.select(newKeyType.fieldNames))) - val takeAggSig = PhysicalAggSig(Take(), takeVirtualSig) - val aggStateSigsPlusTake = aggs.states ++ Array(takeVirtualSig) - - val result = ResultOp(aggs.aggs.length, takeAggSig) - - val shuffleRead = - TableRead(partiallyAggregatedReader.fullType, false, partiallyAggregatedReader) - - val postAggUID = Ref(freshName(), postAggIR.typ) - val resultFromTakeUID = Ref(freshName(), result.typ) - val tmp = mapPartitions( - shuffleRead, - newKeyType.size, - newKeyType.size - 1, - ) { (insGlob, shuffledPartStream) => - Let( - FastSeq(TableIR.globalName -> GetField(insGlob, "oldGlobals")), - mapIR(StreamGroupByKey( - shuffledPartStream, - newKeyType.fieldNames.toIndexedSeq, - missingEqual = true, - )) { groupRef => - RunAgg( - Begin(FastSeq( + val partiallyAggregated = + mapPartitions(ts) { (insGlob, partStream) => + Let( + FastSeq(TableIR.globalName -> GetField(insGlob, "oldGlobals")), + StreamBufferedAggregate( + partStream, bindIR(GetField(insGlob, "__initState")) { states => Begin(aggSigs.indices.map { aIdx => InitFromSerializedValue( @@ -136,55 +81,111 @@ object LowerAndExecuteShuffles { ) }) }, - InitOp( - aggSigs.length, - IndexedSeq(I32(1)), - PhysicalAggSig(Take(), takeVirtualSig), - ), - forIR(groupRef) { elem => - Begin(FastSeq( - SeqOp( - aggSigs.length, - IndexedSeq(SelectFields(elem, newKeyType.fieldNames)), - PhysicalAggSig(Take(), takeVirtualSig), - ), - Begin((0 until aggSigs.length).map { aIdx => - CombOpValue( - aIdx, - GetTupleElement(GetField(elem, "agg"), aIdx), - aggSigs(aIdx), - ) - }), - )) - }, - )), - Let( - FastSeq( - aggs.resultRef.name -> ResultOp.makeTuple(aggs.aggs), - postAggUID.name -> postAggIR, - resultFromTakeUID.name -> result, - ), { - val keyIRs: IndexedSeq[(String, IR)] = - newKeyType.fieldNames.map(keyName => - keyName -> GetField(ArrayRef(resultFromTakeUID, 0), keyName) - ) - - MakeStruct(keyIRs ++ expr.typ.asInstanceOf[TStruct].fieldNames.map { f => - (f, 
GetField(postAggUID, f)) - }) - }, + newKey, + seq, + TableIR.rowName, + aggSigs, + bufferSize, ), - aggStateSigsPlusTake, ) - }, + }.noSharing(ctx) + + val analyses = LoweringAnalyses(partiallyAggregated, ctx) + val preShuffleStage = ctx.backend.tableToTableStage(ctx, partiallyAggregated, analyses) + // annoying but no better alternative right now + val rt = analyses.requirednessAnalysis.lookup(partiallyAggregated).asInstanceOf[RTable] + val partiallyAggregatedReader = ctx.backend.lowerDistributedSort( + ctx, + preShuffleStage, + newKeyType.fieldNames.map(k => SortField(k, Ascending)), + rt, + nPartitions, ) - } - Some(TableMapGlobals( - tmp, - GetField(Ref(TableIR.globalName, tmp.typ.globalType), "oldGlobals"), - )) - case _ => None - }, - ) - } + + val takeVirtualSig = + TakeStateSig(VirtualTypeWithReq(newKeyType, rt.rowType.select(newKeyType.fieldNames))) + val takeAggSig = PhysicalAggSig(Take(), takeVirtualSig) + val aggStateSigsPlusTake = aggs.states ++ Array(takeVirtualSig) + + val result = ResultOp(aggs.aggs.length, takeAggSig) + + val shuffleRead = + TableRead(partiallyAggregatedReader.fullType, false, partiallyAggregatedReader) + + val postAggUID = Ref(freshName(), postAggIR.typ) + val resultFromTakeUID = Ref(freshName(), result.typ) + val tmp = mapPartitions( + shuffleRead, + newKeyType.size, + newKeyType.size - 1, + ) { (insGlob, shuffledPartStream) => + Let( + FastSeq(TableIR.globalName -> GetField(insGlob, "oldGlobals")), + mapIR(StreamGroupByKey( + shuffledPartStream, + newKeyType.fieldNames.toIndexedSeq, + missingEqual = true, + )) { groupRef => + RunAgg( + Begin(FastSeq( + bindIR(GetField(insGlob, "__initState")) { states => + Begin(aggSigs.indices.map { aIdx => + InitFromSerializedValue( + aIdx, + GetTupleElement(states, aIdx), + aggSigs(aIdx).state, + ) + }) + }, + InitOp( + aggSigs.length, + IndexedSeq(I32(1)), + PhysicalAggSig(Take(), takeVirtualSig), + ), + forIR(groupRef) { elem => + Begin(FastSeq( + SeqOp( + aggSigs.length, + IndexedSeq(SelectFields(elem, newKeyType.fieldNames)), + PhysicalAggSig(Take(), takeVirtualSig), + ), + Begin((0 until aggSigs.length).map { aIdx => + CombOpValue( + aIdx, + GetTupleElement(GetField(elem, "agg"), aIdx), + aggSigs(aIdx), + ) + }), + )) + }, + )), + Let( + FastSeq( + aggs.resultRef.name -> ResultOp.makeTuple(aggs.aggs), + postAggUID.name -> postAggIR, + resultFromTakeUID.name -> result, + ), { + val keyIRs: IndexedSeq[(String, IR)] = + newKeyType.fieldNames.map(keyName => + keyName -> GetField(ArrayRef(resultFromTakeUID, 0), keyName) + ) + + MakeStruct(keyIRs ++ expr.typ.asInstanceOf[TStruct].fieldNames.map { f => + (f, GetField(postAggUID, f)) + }) + }, + ), + aggStateSigsPlusTake, + ) + }, + ) + } + Some(TableMapGlobals( + tmp, + GetField(Ref(TableIR.globalName, tmp.typ.globalType), "oldGlobals"), + )) + case _ => None + }, + ) + } } diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/LoweringPass.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/LoweringPass.scala index f821e58da245..dd8975ad2633 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/LoweringPass.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/LoweringPass.scala @@ -23,16 +23,16 @@ final class IrMetadata() { } } -trait LoweringPass { +abstract class LoweringPass(implicit E: sourcecode.Enclosing) { val before: IRState val after: IRState val context: String final def apply(ctx: ExecuteContext, ir: BaseIR): BaseIR = - ctx.timer.time(context) { - ctx.timer.time("Verify")(before.verify(ir)) - val result = 
ctx.timer.time("Transform")(transform(ctx, ir)) - ctx.timer.time("Verify")(after.verify(result)) + ctx.time { + before.verify(ctx, ir) + val result = transform(ctx, ir) + after.verify(ctx, result) result } @@ -96,13 +96,16 @@ case object InlineApplyIR extends LoweringPass { val after: IRState = CompilableIRNoApply val context: String = "InlineApplyIR" - override def transform(ctx: ExecuteContext, ir: BaseIR): BaseIR = RewriteBottomUp( - ir, - { - case x: ApplyIR => Some(x.explicitNode) - case _ => None - }, - ) + override def transform(ctx: ExecuteContext, ir: BaseIR): BaseIR = + ctx.time { + RewriteBottomUp( + ir, + { + case x: ApplyIR => Some(x.explicitNode) + case _ => None + }, + ) + } } case object LowerArrayAggsToRunAggsPass extends LoweringPass { @@ -110,53 +113,54 @@ case object LowerArrayAggsToRunAggsPass extends LoweringPass { val after: IRState = EmittableIR val context: String = "LowerArrayAggsToRunAggs" - def transform(ctx: ExecuteContext, ir: BaseIR): BaseIR = { - val x = ir.noSharing(ctx) - val r = Requiredness(x, ctx) - RewriteBottomUp( - x, - { - case x @ StreamAgg(a, name, query) => - val aggs = Extract(query, r) - - val newNode = aggs.rewriteFromInitBindingRoot { root => - Let( - FastSeq( - aggs.resultRef.name -> RunAgg( - Begin(FastSeq( - aggs.init, - StreamFor(a, name, aggs.seqPerElt), - )), - aggs.results, - aggs.states, - ) - ), - root, - ) - } - - if (newNode.typ != x.typ) - throw new RuntimeException(s"types differ:\n new: ${newNode.typ}\n old: ${x.typ}") - Some(newNode.noSharing(ctx)) - case x @ StreamAggScan(a, name, query) => - val aggs = Extract(query, r, isScan = true) - val newNode = aggs.rewriteFromInitBindingRoot { root => - RunAggScan( - a, - name, - aggs.init, - aggs.seqPerElt, - Let(FastSeq(aggs.resultRef.name -> aggs.results), root), - aggs.states, - ) - } - if (newNode.typ != x.typ) - throw new RuntimeException(s"types differ:\n new: ${newNode.typ}\n old: ${x.typ}") - Some(newNode.noSharing(ctx)) - case _ => None - }, - ) - } + def transform(ctx: ExecuteContext, ir: BaseIR): BaseIR = + ctx.time { + val x = ir.noSharing(ctx) + val r = Requiredness(x, ctx) + RewriteBottomUp( + x, + { + case x @ StreamAgg(a, name, query) => + val aggs = Extract(query, r) + + val newNode = aggs.rewriteFromInitBindingRoot { root => + Let( + FastSeq( + aggs.resultRef.name -> RunAgg( + Begin(FastSeq( + aggs.init, + StreamFor(a, name, aggs.seqPerElt), + )), + aggs.results, + aggs.states, + ) + ), + root, + ) + } + + if (newNode.typ != x.typ) + throw new RuntimeException(s"types differ:\n new: ${newNode.typ}\n old: ${x.typ}") + Some(newNode.noSharing(ctx)) + case x @ StreamAggScan(a, name, query) => + val aggs = Extract(query, r, isScan = true) + val newNode = aggs.rewriteFromInitBindingRoot { root => + RunAggScan( + a, + name, + aggs.init, + aggs.seqPerElt, + Let(FastSeq(aggs.resultRef.name -> aggs.results), root), + aggs.states, + ) + } + if (newNode.typ != x.typ) + throw new RuntimeException(s"types differ:\n new: ${newNode.typ}\n old: ${x.typ}") + Some(newNode.noSharing(ctx)) + case _ => None + }, + ) + } } case class EvalRelationalLetsPass(passesBelow: LoweringPipeline) extends LoweringPass { diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/LoweringPipeline.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/LoweringPipeline.scala index fd3a329abd85..2113f1f7b384 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/LoweringPipeline.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/LoweringPipeline.scala @@ -13,13 +13,13 @@ case class 
LoweringPipeline(lowerings: LoweringPass*) { def render(context: String): Unit = if (ctx.shouldLogIR()) - log.info(s"$context: IR size ${IRSize(x)}: \n" + Pretty(ctx, x, elideLiterals = true)) + log.info(s"$context: IR size ${IRSize(x)}: \n" + Pretty(ctx, x)) render(s"initial IR") - lowerings.foreach { l => + for (l <- lowerings) { try { - x = l.apply(ctx, x) + x = l(ctx, x) render(s"after ${l.context}") } catch { case e: Throwable => diff --git a/hail/src/test/scala/is/hail/expr/ir/ForwardLetsSuite.scala b/hail/src/test/scala/is/hail/expr/ir/ForwardLetsSuite.scala index f3a8b300d7ef..77bc6ff00b40 100644 --- a/hail/src/test/scala/is/hail/expr/ir/ForwardLetsSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/ForwardLetsSuite.scala @@ -1,6 +1,7 @@ package is.hail.expr.ir import is.hail.HailSuite +import is.hail.backend.ExecuteContext import is.hail.expr.Nat import is.hail.types.virtual._ import is.hail.utils._ @@ -117,41 +118,41 @@ class ForwardLetsSuite extends HailSuite { FastSeq(Binding(x.name, I32(1), Scope.AGG), Binding(y.name, x, Scope.AGG)), ApplyAggOp(Sum())(y), ) - val after: IR = ForwardLets(ctx)(ir) + val after: IR = ForwardLets(ctx, ir) val expected = ApplyAggOp(Sum())(I32(1)) - assert(NormalizeNames(ctx, after) == NormalizeNames(ctx, expected)) + assert(NormalizeNames()(ctx, after) == NormalizeNames()(ctx, expected)) } @Test(dataProvider = "nonForwardingOps") def testNonForwardingOps(ir: IR): Unit = { - val after = ForwardLets(ctx)(ir) - val normalizedBefore = NormalizeNames(ctx, ir) - val normalizedAfter = NormalizeNames(ctx, after) + val after = ForwardLets(ctx, ir) + val normalizedBefore = NormalizeNames()(ctx, ir) + val normalizedAfter = NormalizeNames()(ctx, after) assert(normalizedBefore == normalizedAfter) } @Test(dataProvider = "nonForwardingNonEvalOps") def testNonForwardingNonEvalOps(ir: IR): Unit = { - val after = ForwardLets(ctx)(ir) + val after = ForwardLets(ctx, ir) assert(after.isInstanceOf[Block]) } @Test(dataProvider = "nonForwardingAggOps") def testNonForwardingAggOps(ir: IR): Unit = { - val after = ForwardLets(ctx)(ir) + val after = ForwardLets(ctx, ir) assert(after.isInstanceOf[Block]) } @Test(dataProvider = "forwardingOps") def testForwardingOps(ir: IR): Unit = { - val after = ForwardLets(ctx)(ir) + val after = ForwardLets(ctx, ir) assert(!after.isInstanceOf[Block]) assertEvalSame(ir, args = Array(5 -> TInt32)) } @Test(dataProvider = "forwardingAggOps") def testForwardingAggOps(ir: IR): Unit = { - val after = ForwardLets(ctx)(ir) + val after = ForwardLets(ctx, ir) assert(!after.isInstanceOf[Block]) } @@ -218,10 +219,11 @@ class ForwardLetsSuite extends HailSuite { @Test(dataProvider = "TrivialIRCases") def testTrivialCases(input: IR, _expected: IR, reason: String): Unit = { - val result = NormalizeNames(ctx, ForwardLets(ctx)(input), allowFreeVariables = true) - val expected = NormalizeNames(ctx, _expected, allowFreeVariables = true) + val normalize: (ExecuteContext, BaseIR) => BaseIR = NormalizeNames(allowFreeVariables = true) + val result = normalize(ctx, ForwardLets(ctx, input)) + val expected = normalize(ctx, _expected) assert( - result == NormalizeNames(ctx, expected, allowFreeVariables = true), + result == normalize(ctx, expected), s"\ninput:\n${Pretty.sexprStyle(input)}\nexpected:\n${Pretty.sexprStyle(expected)}\ngot:\n${Pretty.sexprStyle(result)}\n$reason", ) } @@ -236,7 +238,7 @@ class ForwardLetsSuite extends HailSuite { AggSignature(Sum(), FastSeq(), FastSeq(TFloat64)), ) - TypeCheck(ctx, ForwardLets(ctx)(ir0), BindingEnv(Env.empty, agg = 
Some(aggEnv))) + TypeCheck(ctx, ForwardLets(ctx, ir0), BindingEnv(Env.empty, agg = Some(aggEnv))) } @Test def testNestedBindingOverwrites(): Unit = { @@ -246,7 +248,7 @@ class ForwardLetsSuite extends HailSuite { val ir = bindIRs(xCast, xCast) { case Seq(x1, x2) => x2 + x2 + x1 } TypeCheck(ctx, ir, BindingEnv(env)) - TypeCheck(ctx, ForwardLets(ctx)(ir), BindingEnv(env)) + TypeCheck(ctx, ForwardLets(ctx, ir), BindingEnv(env)) } @Test def testLetsDoNotForwardInsideArrayAggWithNoOps(): Unit = { @@ -256,6 +258,6 @@ class ForwardLetsSuite extends HailSuite { )(x => streamAggIR(ToStream(In(1, TArray(TInt32))))(_ => y + x)) TypeCheck(ctx, x, BindingEnv(Env(y.name -> TInt32))) - TypeCheck(ctx, ForwardLets(ctx)(x), BindingEnv(Env(y.name -> TInt32))) + TypeCheck(ctx, ForwardLets(ctx, x), BindingEnv(Env(y.name -> TInt32))) } } diff --git a/hail/src/test/scala/is/hail/expr/ir/analyses/SemanticHashSuite.scala b/hail/src/test/scala/is/hail/expr/ir/analyses/SemanticHashSuite.scala index 8866db7d0340..bb38e9e24f17 100644 --- a/hail/src/test/scala/is/hail/expr/ir/analyses/SemanticHashSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/analyses/SemanticHashSuite.scala @@ -295,7 +295,7 @@ class SemanticHashSuite extends HailSuite { isEqual, s"expected semhash($a) ${if (isEqual) "==" else "!="} semhash($b), $comment", )( - SemanticHash(ctx)(a) == SemanticHash(ctx)(b) + SemanticHash(ctx, a) == SemanticHash(ctx, b) ) } @@ -311,12 +311,12 @@ class SemanticHashSuite extends HailSuite { ctx.local(fs = fs) { ctx => assertResult(None, "SemHash should be resilient to FileNotFoundExceptions.")( - SemanticHash(ctx)(ir) + SemanticHash(ctx, ir) ) } } - val fakeFs: FS = + private[this] val fakeFs: FS = new FakeFS { override def eTag(url: FakeURL): Option[String] = Some(url.getPath)
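Endnote on the recurring pattern in this patch: after the change, every pass shares one shape, an uncurried (ExecuteContext, BaseIR) => BaseIR whose body is wrapped in `ctx.time`, so plain passes eta-expand into arrays (see `Optimize.optimizations`) and configured passes such as `NormalizeNames(allowFreeVariables = true)` partially apply to the same shape. Below is a toy sketch of that composition; `Ctx` and `Node` are hypothetical stand-ins for Hail's `ExecuteContext` and `BaseIR`, and unlike Hail's `ctx.time`, which derives its label implicitly via `sourcecode.Enclosing`, the sketch takes the label explicitly:

import scala.collection.mutable.ArrayBuffer

object PassPipelineSketch {
  final case class Ctx(timings: ArrayBuffer[(String, Long)]) {
    // Record how long `body` took under `label`, like ctx.time in the diff.
    def time[A](label: String)(body: => A): A = {
      val start = System.nanoTime()
      try body
      finally timings += label -> (System.nanoTime() - start)
    }
  }

  type Node = List[Int]

  // An uncurried pass eta-expands straight into the array...
  def dedup(ctx: Ctx, n: Node): Node = ctx.time("Dedup")(n.distinct)

  // ...and a configured, curried pass (cf. NormalizeNames(allowFreeVariables = true))
  // partially applies to the same (Ctx, Node) => Node shape.
  def sort(descending: Boolean)(ctx: Ctx, n: Node): Node =
    ctx.time("Sort")(if (descending) n.sorted.reverse else n.sorted)

  val passes: Array[(Ctx, Node) => Node] =
    Array(dedup, sort(descending = false))

  // Optimize's inner loop: sweep all passes in order until a sweep changes nothing.
  def optimize(ctx: Ctx, n0: Node, maxSweeps: Int = 3): Node = {
    var n = n0
    var sweeps = 0
    var changed = true
    while (changed && sweeps < maxSweeps) {
      val last = n
      for (f <- passes) n = f(ctx, n)
      changed = n != last
      sweeps += 1
    }
    n
  }
}

For example, PassPipelineSketch.optimize(PassPipelineSketch.Ctx(ArrayBuffer.empty), List(3, 1, 3, 2)) returns List(1, 2, 3) and records one timing entry per pass per sweep.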