[query] Lowering + Optimisation with implicit timing context
ehigham committed Oct 17, 2024
1 parent 96def90 commit 9cf65a9
Showing 22 changed files with 705 additions and 702 deletions.
@@ -130,7 +130,7 @@ object LocalBackend extends Backend {
Validate(ir)
val queryID = Backend.nextID()
log.info(s"starting execution of query $queryID of initial size ${IRSize(ir)}")
ctx.irMetadata.semhash = SemanticHash(ctx)(ir)
ctx.irMetadata.semhash = SemanticHash(ctx, ir)
val res = _jvmLowerAndExecute(ctx, ir)
log.info(s"finished execution of query $queryID")
res
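The functional change in this hunk, repeated in the ServiceBackend and SparkBackend hunks below, is the SemanticHash call site: the curried SemanticHash(ctx)(ir) becomes the two-argument SemanticHash(ctx, ir). A minimal sketch of the two apply shapes, using local stand-ins rather than Hail's real ExecuteContext, BaseIR, and SemanticHash:

// Stand-in types for illustration only; Hail's ExecuteContext and BaseIR differ.
trait BaseIR
final class ExecuteContext

object SemanticHashCurried {
  // old call shape: SemanticHashCurried(ctx)(ir)
  def apply(ctx: ExecuteContext)(ir: BaseIR): Int = ir.hashCode()
}

object SemanticHashUncurried {
  // new call shape: SemanticHashUncurried(ctx, ir)
  def apply(ctx: ExecuteContext, ir: BaseIR): Int = ir.hashCode()
}

object CallSiteDemo extends App {
  val ctx = new ExecuteContext
  val ir = new BaseIR {}
  println(SemanticHashUncurried(ctx, ir) == SemanticHashCurried(ctx)(ir)) // true
}

Only the call shape changes; both forms hash the same IR.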
@@ -247,7 +247,7 @@ class ServiceBackend(
Validate(ir)
val queryID = Backend.nextID()
log.info(s"starting execution of query $queryID of initial size ${IRSize(ir)}")
ctx.irMetadata.semhash = SemanticHash(ctx)(ir)
ctx.irMetadata.semhash = SemanticHash(ctx, ir)
val res = _jvmLowerAndExecute(ctx, ir)
log.info(s"finished execution of query $queryID")
res
@@ -424,7 +424,7 @@ class SparkBackend(val sc: SparkContext) extends Backend {
ctx.time {
TypeCheck(ctx, ir)
Validate(ir)
ctx.irMetadata.semhash = SemanticHash(ctx)(ir)
ctx.irMetadata.semhash = SemanticHash(ctx, ir)
try {
val lowerTable = ctx.flags.get("lower") != null
val lowerBM = ctx.flags.get("lower_bm") != null
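The recurring change that matches the commit title is wrapping each pass body in ctx.time { ... }. The sketch below is an assumption about how such a block-scoped timer could attribute elapsed time to the enclosing definition through an implicit sourcecode.Name (the same library that already appears in the Compile.scala hunk below); it is not Hail's actual ExecuteContext.time and needs the com.lihaoyi sourcecode dependency:

import scala.collection.mutable

// Hypothetical timing context: accumulates nanoseconds per enclosing-definition name.
final class TimingContext {
  val timings: mutable.Map[String, Long] = mutable.Map.empty

  def time[A](block: => A)(implicit name: sourcecode.Name): A = {
    val start = System.nanoTime()
    try block
    finally {
      val elapsed = System.nanoTime() - start
      timings(name.value) = timings.getOrElse(name.value, 0L) + elapsed
    }
  }
}

object TimingDemo extends App {
  val ctx = new TimingContext

  def lowerAndExecute(): Int =
    ctx.time { (1 to 1000).sum } // recorded under the key "lowerAndExecute"

  lowerAndExecute()
  println(ctx.timings) // e.g. Map(lowerAndExecute -> <nanoseconds>)
}

Because the label arrives as an implicit parameter, call sites such as the ctx.time { ... } blocks in this diff do not have to pass a name explicitly.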
10 changes: 5 additions & 5 deletions hail/src/main/scala/is/hail/expr/ir/BaseIR.scala
@@ -33,11 +33,11 @@ abstract class BaseIR {
// New sentinel values can be obtained by `nextFlag` on `IRMetadata`.
var mark: Int = 0

def isAlphaEquiv(ctx: ExecuteContext, other: BaseIR): Boolean =
/* FIXME: rewrite to not rebuild the irs, by maintaining an env mapping left names to right
* names */
NormalizeNames(ctx, this, allowFreeVariables = true) ==
NormalizeNames(ctx, other, allowFreeVariables = true)
def isAlphaEquiv(ctx: ExecuteContext, other: BaseIR): Boolean = {
// FIXME: rewrite to not rebuild the irs by maintaining an env mapping left to right names
val normalize: (ExecuteContext, BaseIR) => BaseIR = NormalizeNames(allowFreeVariables = true)
normalize(ctx, this) == normalize(ctx, other)
}

def mapChildrenWithIndex(f: (BaseIR, Int) => BaseIR): BaseIR = {
val newChildren = childrenSeq.view.zipWithIndex.map(f.tupled).toArray
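In isAlphaEquiv, NormalizeNames(ctx, this, allowFreeVariables = true) becomes NormalizeNames(allowFreeVariables = true), a configured pass of type (ExecuteContext, BaseIR) => BaseIR that is then applied to both sides; the Compile.scala and ForwardLets.scala hunks below adopt the same shape. A small sketch of that partial-application pattern with placeholder types (the real pass renames bindings rather than returning its input):

// Placeholder types for illustration; not Hail's ExecuteContext or BaseIR.
trait BaseIR
final class ExecuteContext

object NormalizeNamesSketch {
  // Configuring the pass once yields a function reusable on several IRs.
  def apply(allowFreeVariables: Boolean): (ExecuteContext, BaseIR) => BaseIR =
    (_, ir) => ir // stand-in body; a real implementation would normalise names here

  def isAlphaEquiv(ctx: ExecuteContext, left: BaseIR, right: BaseIR): Boolean = {
    val normalize = NormalizeNamesSketch(allowFreeVariables = true)
    normalize(ctx, left) == normalize(ctx, right)
  }
}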
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/expr/ir/Compile.scala
@@ -93,7 +93,7 @@ object compile {
N: sourcecode.Name,
): (Option[SingleCodeType], (HailClassLoader, FS, HailTaskContext, Region) => F with Mixin) =
ctx.time {
val normalizedBody = NormalizeNames(ctx, body, allowFreeVariables = true)
val normalizedBody = NormalizeNames(allowFreeVariables = true)(ctx, body)
ctx.CodeCache.getOrElseUpdate(
CodeCacheKey(aggSigs.getOrElse(Array.empty).toFastSeq, params, normalizedBody), {
var ir = Subst(
73 changes: 37 additions & 36 deletions hail/src/main/scala/is/hail/expr/ir/ExtractIntervalFilters.scala
@@ -24,52 +24,53 @@ object ExtractIntervalFilters {

val MAX_LITERAL_SIZE = 4096

def apply(ctx: ExecuteContext, ir0: BaseIR): BaseIR = {
MapIR.mapBaseIR(
ir0,
(ir: BaseIR) => {
(
ir match {
case TableFilter(child, pred) =>
extractPartitionFilters(
ctx,
pred,
Ref(TableIR.rowName, child.typ.rowType),
child.typ.key,
)
.map { case (newCond, intervals) =>
def apply(ctx: ExecuteContext, ir0: BaseIR): BaseIR =
ctx.time {
MapIR.mapBaseIR(
ir0,
ir =>
(
ir match {
case TableFilter(child, pred) =>
extractPartitionFilters(
ctx,
pred,
Ref(TableIR.rowName, child.typ.rowType),
child.typ.key,
)
.map { case (newCond, intervals) =>
log.info(
s"generated TableFilterIntervals node with ${intervals.length} intervals:\n " +
s"Intervals: ${intervals.mkString(", ")}\n " +
s"Predicate: ${Pretty(ctx, pred)}\n " + s"Post: ${Pretty(ctx, newCond)}"
)
TableFilter(TableFilterIntervals(child, intervals, keep = true), newCond)
}
case MatrixFilterRows(child, pred) =>
extractPartitionFilters(
ctx,
pred,
Ref(MatrixIR.rowName, child.typ.rowType),
child.typ.rowKey,
).map { case (newCond, intervals) =>
log.info(
s"generated TableFilterIntervals node with ${intervals.length} intervals:\n " +
s"generated MatrixFilterIntervals node with ${intervals.length} intervals:\n " +
s"Intervals: ${intervals.mkString(", ")}\n " +
s"Predicate: ${Pretty(ctx, pred)}\n " + s"Post: ${Pretty(ctx, newCond)}"
)
TableFilter(TableFilterIntervals(child, intervals, keep = true), newCond)
MatrixFilterRows(MatrixFilterIntervals(child, intervals, keep = true), newCond)
}
case MatrixFilterRows(child, pred) => extractPartitionFilters(
ctx,
pred,
Ref(MatrixIR.rowName, child.typ.rowType),
child.typ.rowKey,
).map { case (newCond, intervals) =>
log.info(
s"generated MatrixFilterIntervals node with ${intervals.length} intervals:\n " +
s"Intervals: ${intervals.mkString(", ")}\n " +
s"Predicate: ${Pretty(ctx, pred)}\n " + s"Post: ${Pretty(ctx, newCond)}"
)
MatrixFilterRows(MatrixFilterIntervals(child, intervals, keep = true), newCond)
}

case _ => None
}
).getOrElse(ir)
},
)
}
case _ => None
}
).getOrElse(ir),
)
}

def extractPartitionFilters(ctx: ExecuteContext, cond: IR, ref: Ref, key: IndexedSeq[String])
: Option[(IR, IndexedSeq[Interval])] = {
if (key.isEmpty) None
else {
else ctx.time {
val extract =
new ExtractIntervalFilters(ctx, ref.typ.asInstanceOf[TStruct].typeAfterSelectNames(key))
val trueSet = extract.analyze(cond, ref.name)
4 changes: 3 additions & 1 deletion hail/src/main/scala/is/hail/expr/ir/FoldConstants.scala
@@ -6,7 +6,9 @@ import is.hail.utils.HailException

object FoldConstants {
def apply(ctx: ExecuteContext, ir: BaseIR): BaseIR =
ctx.r.pool.scopedRegion(region => ctx.local(r = region)(foldConstants(_, ir)))
ctx.time {
ctx.r.pool.scopedRegion(r => ctx.local(r = r)(foldConstants(_, ir)))
}

private def foldConstants(ctx: ExecuteContext, ir: BaseIR): BaseIR =
RewriteBottomUp(
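Besides the new ctx.time wrapper, FoldConstants keeps its pattern of allocating a scratch region and running the fold under a context whose region field is swapped for it via ctx.local(r = r), so the region is released when the block exits. A rough, hypothetical sketch of that scoped-resource shape; Hail's Region, RegionPool, and ExecuteContext.local have richer APIs than these stand-ins:

// Hypothetical stand-ins for the scoped-region pattern above.
final class Region {
  def close(): Unit = () // a real region would release its buffers here
}

final class RegionPool {
  def scopedRegion[A](body: Region => A): A = {
    val region = new Region
    try body(region)
    finally region.close()
  }
}

final case class Ctx(r: Region, pool: RegionPool) {
  // Run `body` with a copy of this context whose region is replaced.
  def local[A](r: Region = this.r)(body: Ctx => A): A = body(copy(r = r))
}

object FoldConstantsShape {
  def apply(ctx: Ctx): Unit =
    ctx.pool.scopedRegion(region => ctx.local(r = region)(_ => ())) // the real pass folds constants against the scoped context
}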
95 changes: 53 additions & 42 deletions hail/src/main/scala/is/hail/expr/ir/ForwardLets.scala
@@ -2,20 +2,20 @@ package is.hail.expr.ir

import is.hail.backend.ExecuteContext
import is.hail.types.virtual.TVoid
import is.hail.utils.BoxedArrayBuilder
import is.hail.utils.{fatal, BoxedArrayBuilder}

import scala.collection.Set
import scala.util.control.NonFatal

object ForwardLets {
def apply[T <: BaseIR](ctx: ExecuteContext)(ir0: T): T = {
val ir1 = NormalizeNames(ctx, ir0, allowFreeVariables = true)
val UsesAndDefs(uses, defs, _) = ComputeUsesAndDefs(ir1, errorIfFreeVariables = false)
val nestingDepth = NestingDepth(ir1)

def rewrite(ir: BaseIR, env: BindingEnv[IR]): BaseIR = {
def apply[T <: BaseIR](ctx: ExecuteContext, ir0: T): T =
ctx.time {
val ir1 = NormalizeNames(allowFreeVariables = true)(ctx, ir0)
val UsesAndDefs(uses, defs, _) = ComputeUsesAndDefs(ir1, errorIfFreeVariables = false)
val nestingDepth = NestingDepth(ctx, ir1)

def shouldForward(value: IR, refs: Set[RefEquality[BaseRef]], base: Block, scope: Int)
: Boolean = {
: Boolean =
IsPure(value) && (
value.isInstanceOf[Ref] ||
value.isInstanceOf[In] ||
@@ -27,45 +27,56 @@ object ForwardLets {
!ContainsAgg(value)) &&
!ContainsAggIntermediate(value)
)
}

ir match {
case l: Block =>
val keep = new BoxedArrayBuilder[Binding]
val refs = uses(l)
val newEnv = l.bindings.foldLeft(env) {
case (env, Binding(name, value, scope)) =>
val rewriteValue = rewrite(value, env.promoteScope(scope)).asInstanceOf[IR]
if (
rewriteValue.typ != TVoid
&& shouldForward(rewriteValue, refs.filter(_.t.name == name), l, scope)
) {
env.bindInScope(name, rewriteValue, scope)
} else {
keep += Binding(name, rewriteValue, scope)
env
def rewrite(ir: BaseIR, env: BindingEnv[IR]): BaseIR =
ir match {
case l: Block =>
val keep = new BoxedArrayBuilder[Binding]
val refs = uses(l)
val newEnv = l.bindings.foldLeft(env) {
case (env, Binding(name, value, scope)) =>
val rewriteValue = rewrite(value, env.promoteScope(scope)).asInstanceOf[IR]
if (
rewriteValue.typ != TVoid
&& shouldForward(rewriteValue, refs.filter(_.t.name == name), l, scope)
) {
env.bindInScope(name, rewriteValue, scope)
} else {
keep += Binding(name, rewriteValue, scope)
env
}
}

val newBody = rewrite(l.body, newEnv).asInstanceOf[IR]
if (keep.isEmpty) newBody
else Block(keep.result(), newBody)

case x @ Ref(name, _) =>
env.eval
.lookupOption(name)
.map { forwarded =>
if (uses.lookup(defs.lookup(x)).count(_.t.name == name) > 1) forwarded.deepCopy()
else forwarded
}
}
.getOrElse(x)
case _ =>
ir.mapChildrenWithIndex((ir1, i) =>
rewrite(ir1, env.extend(Bindings.get(ir, i).dropBindings))
)
}

val newBody = rewrite(l.body, newEnv).asInstanceOf[IR]
if (keep.isEmpty) newBody
else Block(keep.result(), newBody)
val ir = rewrite(ir1, BindingEnv(Env.empty, Some(Env.empty), Some(Env.empty)))

case x @ Ref(name, _) =>
env.eval
.lookupOption(name)
.map { forwarded =>
if (uses.lookup(defs.lookup(x)).count(_.t.name == name) > 1) forwarded.deepCopy()
else forwarded
}
.getOrElse(x)
case _ =>
ir.mapChildrenWithIndex((ir1, i) =>
rewrite(ir1, env.extend(Bindings.get(ir, i).dropBindings))
try
TypeCheck(ctx, ir)
catch {
case NonFatal(e) =>
fatal(
s"bad ir from ForwardLets, started as\n${Pretty(ctx, ir0, preserveNames = true)}",
e,
)
}
}

rewrite(ir1, BindingEnv(Env.empty, Some(Env.empty), Some(Env.empty))).asInstanceOf[T]
}
ir.asInstanceOf[T]
}
}
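ForwardLets likewise moves from the curried apply(ctx)(ir0) to apply(ctx, ir0), and before returning it re-runs TypeCheck on the rewritten IR, turning any failure into a fatal error that carries the pretty-printed input. A compact sketch of that rewrite-then-verify guard; typeCheck, pretty, and FatalError below are placeholders for Hail's TypeCheck, Pretty, and fatal, and the pass body is elided:

import scala.util.control.NonFatal

// Stand-in types and helpers for illustration only.
trait BaseIR
final class ExecuteContext

object RewriteGuardSketch {
  final case class FatalError(msg: String, cause: Throwable)
      extends RuntimeException(msg, cause)

  def typeCheck(ctx: ExecuteContext, ir: BaseIR): Unit = () // placeholder for TypeCheck
  def pretty(ir: BaseIR): String = ir.toString // placeholder for Pretty

  def apply[T <: BaseIR](ctx: ExecuteContext, ir0: T): T = {
    val rewritten: T = ir0 // a real pass would forward let-bindings here
    try typeCheck(ctx, rewritten)
    catch {
      case NonFatal(e) =>
        throw FatalError(s"bad ir from ForwardLets, started as\n${pretty(ir0)}", e)
    }
    rewritten
  }
}

Re-checking types immediately after an aggressive rewrite localises any failure to the pass that produced the bad IR instead of surfacing it later during lowering.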
