From d7533095ab778135abf6feaee74afa95637bc079 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Tue, 17 Dec 2024 11:16:26 -0500 Subject: [PATCH] codegen all value ir node classes --- hail/build.mill | 18 + hail/hail/ir-gen/src/Main.scala | 1106 +++++++++++++ hail/hail/src/is/hail/expr/Parser.scala | 10 +- hail/hail/src/is/hail/expr/ir/Children.scala | 267 ---- hail/hail/src/is/hail/expr/ir/Copy.scala | 607 -------- hail/hail/src/is/hail/expr/ir/Emit.scala | 6 +- .../is/hail/expr/ir/GenericTableValue.scala | 3 +- hail/hail/src/is/hail/expr/ir/IR.scala | 1379 +++++------------ hail/hail/src/is/hail/expr/ir/Interpret.scala | 12 +- .../src/is/hail/expr/ir/MatrixWriter.scala | 3 +- hail/hail/src/is/hail/expr/ir/Parser.scala | 7 +- hail/hail/src/is/hail/expr/ir/TableIR.scala | 3 +- hail/hail/src/is/hail/expr/ir/TypeCheck.scala | 2 +- .../expr/ir/functions/ArrayFunctions.scala | 4 +- .../is/hail/expr/ir/functions/Functions.scala | 20 +- .../expr/ir/lowering/LoweringPipeline.scala | 4 +- .../is/hail/expr/ir/streams/EmitStream.scala | 3 +- .../src/is/hail/io/index/IndexReader.scala | 3 +- .../src/is/hail/io/index/IndexWriter.scala | 3 +- hail/hail/src/is/hail/io/vcf/LoadVCF.scala | 6 +- .../src/is/hail/rvd/AbstractRVDSpec.scala | 3 +- .../hail/types/virtual/BlockMatrixType.scala | 5 +- .../is/hail/expr/ir/Aggregators2Suite.scala | 3 +- .../is/hail/expr/ir/FoldConstantsSuite.scala | 3 +- .../test/src/is/hail/expr/ir/IRSuite.scala | 6 +- .../ir/{ => defs}/EncodedLiteralSuite.scala | 3 +- .../test/src/is/hail/io/IndexBTreeSuite.scala | 3 +- .../hail/test/src/is/hail/io/IndexSuite.scala | 6 +- .../src/is/hail/linalg/BlockMatrixSuite.scala | 84 +- .../test/src/is/hail/utils/UtilsSuite.scala | 3 +- 30 files changed, 1662 insertions(+), 1923 deletions(-) create mode 100644 hail/hail/ir-gen/src/Main.scala delete mode 100644 hail/hail/src/is/hail/expr/ir/Children.scala delete mode 100644 hail/hail/src/is/hail/expr/ir/Copy.scala rename hail/hail/test/src/is/hail/expr/ir/{ => defs}/EncodedLiteralSuite.scala (88%) diff --git a/hail/build.mill b/hail/build.mill index 6a14da4d4b3..ab7b2d2afc6 100644 --- a/hail/build.mill +++ b/hail/build.mill @@ -120,6 +120,7 @@ trait HailModule extends ScalaModule with ScalafmtModule with ScalafixModule { o "-Xlint", "-Ywarn-unused:_,-explicits,-implicits", "-Wconf:msg=legacy-binding:s", + "-feature", ) ++ ( if (build.env.debugMode()) Seq() else Seq( @@ -175,6 +176,10 @@ object hail extends HailModule { outer => buildInfo(), ) + override def generatedSources: T[Seq[PathRef]] = Task { + Seq(`ir-gen`.generate()) + } + override def unmanagedClasspath: T[Agg[PathRef]] = Agg(shadedazure.assembly()) @@ -246,6 +251,19 @@ object hail extends HailModule { outer => PathRef(T.dest) } + object `ir-gen` extends HailModule { + def ivyDeps = Agg( + ivy"com.lihaoyi::mainargs:0.6.2", + ivy"com.lihaoyi::os-lib:0.10.7", + ivy"com.lihaoyi::sourcecode:0.4.2", + ) + + def generate: T[PathRef] = Task { + runner().run(Args("--path", T.dest).value) + PathRef(T.dest) + } + } + object memory extends JavaModule { // with CrossValue { override def zincIncrementalCompilation: T[Boolean] = false diff --git a/hail/hail/ir-gen/src/Main.scala b/hail/hail/ir-gen/src/Main.scala new file mode 100644 index 00000000000..01186d101b6 --- /dev/null +++ b/hail/hail/ir-gen/src/Main.scala @@ -0,0 +1,1106 @@ +import scala.annotation.nowarn +import scala.language.implicitConversions + +import mainargs.{main, ParserForMethods} + +trait IRDSL { + def node(name: String, attsAndChildren: NamedAttOrChildPack*): IR + + def 
att(typ: String): AttOrChildPack + val binding: AttOrChildPack + def tup(elts: AttOrChildPack*): AttOrChildPack + def child: AttOrChildPack + def child(t: String): AttOrChildPack + + def named(name: String, pack: AttOrChildPack): NamedAttOrChildPack + + val TrivialIR: Trait + val BaseRef: Trait + def TypedIR(t: String): Trait + val NDArrayIR: Trait + // AbstractApplyNodeUnseededMissingness{Aware, Oblivious}JVMFunction + def ApplyNode(missingnessAware: Boolean = false): Trait + + type IR <: IR_Interface + type AttOrChildPack <: AttOrChildPack_Interface + type NamedAttOrChildPack <: NamedAttOrChildPack_Interface + type Trait + + trait IR_Interface { + def withTraits(newTraits: Trait*): IR + def withMethod(methodDef: String): IR + def typed(typ: String): IR + def withConstraint(c: String): IR + def withCompanionExtension: IR + def withClassExtension: IR + def withDocstring(docstring: String): IR + + def generateDef: String + } + + trait AttOrChildPack_Interface { + def withConstraint(c: String => String): AttOrChildPack + def * : AttOrChildPack + def + : AttOrChildPack + def ? : AttOrChildPack + } + + trait NamedAttOrChildPack_Interface { + def withDefault(value: String): NamedAttOrChildPack + def mutable: NamedAttOrChildPack + } + + // each `WithDefaultName` object can be used in a context expecting either + // an `AttOrChildPack` or a `NamedAttOrChildPack`, using a default name in + // the latter case + object name extends WithDefaultName(att("Name"), "name") + object key extends WithDefaultName(att("IndexedSeq[String]"), "key") + object tableChild extends WithDefaultName(child("TableIR"), "child") + object matrixChild extends WithDefaultName(child("MatrixIR"), "child") + object blockMatrixChild extends WithDefaultName(child("BlockMatrixIR"), "child") + + implicit def makeNamedPack(tup: (String, AttOrChildPack)): NamedAttOrChildPack = + named(tup._1, tup._2) + + abstract class WithDefaultName(t: AttOrChildPack, defaultName: String) { + implicit def makeNamedChild(tup: (String, this.type)): NamedAttOrChildPack = + named(tup._1, t) + + implicit def makeDefaultNamedChild(x: this.type): NamedAttOrChildPack = + makeNamedChild((defaultName, x)) + + implicit def makeChild(x: this.type): AttOrChildPack = t + } +} + +object IRDSL_Impl extends IRDSL { + def node(name: String, attsAndChildren: NamedAttOrChildPack*): IR = + IR(name, attsAndChildren) + + def att(typ: String): AttOrChildPack = Att(typ) + val binding: AttOrChildPack = Binding + def tup(elts: AttOrChildPack*): AttOrChildPack = Tup(elts: _*) + def child: AttOrChildPack = Child() + def child(t: String): AttOrChildPack = Child(t) + + def named(name: String, pack: AttOrChildPack): NamedAttOrChildPack = + NamedAttOrChildPack(name, pack) + + final case class Trait(name: String) + + val TrivialIR: Trait = Trait("TrivialIR") + val BaseRef: Trait = Trait("BaseRef") + def TypedIR(typ: String): Trait = Trait(s"TypedIR[$typ]") + val NDArrayIR: Trait = Trait("NDArrayIR") + + def ApplyNode(missingnessAware: Boolean = false): Trait = { + val t = + s"AbstractApplyNode[UnseededMissingness${if (missingnessAware) "Aware" else "Oblivious"}JVMFunction]" + Trait(t) + } + + case class NChildren(static: Int = 0, dynamic: String = "") { + override def toString: String = (static, dynamic) match { + case (s, "") => s.toString + case (0, d) => d + case _ => s"$dynamic + $static" + } + + def getStatic: Option[Int] = if (dynamic.isEmpty) Some(static) else None + + def hasStaticValue(i: Int): Boolean = getStatic.contains(i) + + def +(other: NChildren): NChildren = 
NChildren( + static = static + other.static, + dynamic = (dynamic, other.dynamic) match { + case ("", r) => r + case (l, "") => l + case (l, r) => s"$l + $r" + }, + ) + } + + object NChildren { + implicit def makeStatic(i: Int): NChildren = NChildren(static = i) + } + + sealed abstract class ChildrenSeq { + def asDyn: ChildrenSeq.Dynamic + + override def toString: String = asDyn.children + + def ++(other: ChildrenSeq): ChildrenSeq = (this, other) match { + case (ChildrenSeq.Static(Seq()), r) => r + case (l, ChildrenSeq.Static(Seq())) => l + case (ChildrenSeq.Static(l), ChildrenSeq.Static(r)) => ChildrenSeq.Static(l ++ r) + case _ => ChildrenSeq.Dynamic(this + " ++ " + other) + } + } + + object ChildrenSeq { + val empty: Static = Static(Seq.empty) + + final case class Static(children: Seq[String]) extends ChildrenSeq { + def asDyn: Dynamic = Dynamic(s"FastSeq(${children.mkString(", ")})") + } + + final case class Dynamic(children: String) extends ChildrenSeq { + def asDyn: Dynamic = this + } + } + + sealed abstract class ChildrenSeqSlice { + def hasStaticLen(l: Int): Boolean + def slice(relStart: NChildren, len: NChildren): ChildrenSeqSlice + def apply(i: NChildren): String + } + + object ChildrenSeqSlice { + def apply(seq: ChildrenSeq, baseLen: NChildren, start: NChildren, len: NChildren) + : ChildrenSeqSlice = + Range(seq, baseLen, start, len) + + def apply(value: String): ChildrenSeqSlice = Singleton(value) + + final private case class Range( + seq: ChildrenSeq, + baseLen: NChildren, + start: NChildren, + len: NChildren, + ) extends ChildrenSeqSlice { + override def toString: String = if (start.hasStaticValue(0) && len == baseLen) + seq.toString + else + s"$seq.slice($start, ${start + len})" + + override def hasStaticLen(l: Int): Boolean = len.dynamic == "" && len.static == l + + override def slice(relStart: NChildren, len: NChildren): ChildrenSeqSlice = + Range(seq, baseLen, start + relStart, len) + + override def apply(i: NChildren): String = s"$seq(${start + i})" + } + + final private case class Singleton(value: String) extends ChildrenSeqSlice { + override def toString: String = value + override def hasStaticLen(l: Int): Boolean = l == 1 + + override def slice(relStart: NChildren, len: NChildren): ChildrenSeqSlice = { + if (relStart.hasStaticValue(0) && len.hasStaticValue(1)) + this + else { + assert( + (relStart.hasStaticValue(0) && len.hasStaticValue(0)) + || (relStart.hasStaticValue(1) && len.hasStaticValue(1)) + ) + Empty + } + } + + override def apply(i: NChildren): String = { + assert(i.hasStaticValue(0)) + value + } + } + + private case object Empty extends ChildrenSeqSlice { + override def toString: String = "IndexedSeq.empty" + override def hasStaticLen(l: Int): Boolean = l == 0 + + override def slice(relStart: NChildren, len: NChildren): ChildrenSeqSlice = { + assert(relStart.hasStaticValue(0) && len.hasStaticValue(0)) + this + } + + override def apply(i: NChildren): String = ??? 
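+      // Descriptive note (not in the original patch text): an empty slice has no
+      // element to index, so `apply` is deliberately left unimplemented. The
+      // NotImplementedError it would throw is caught later in `IR.copyMethod`,
+      // which turns it into an assertion failure naming the offending node.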
+ } + } + + final case class NamedAttOrChildPack( + name: String, + pack: AttOrChildPack, + isVar: Boolean = false, + default: Option[String] = None, + ) extends NamedAttOrChildPack_Interface { + def mutable: NamedAttOrChildPack = copy(isVar = true) + def withDefault(value: String): NamedAttOrChildPack = copy(default = Some(value)) + + def generateDeclaration: String = + s"${if (isVar) "var " else ""}$name: ${pack.generateDeclaration}${default.map(d => s" = $d").getOrElse("")}" + + def constraints: Seq[String] = pack.constraints.map(_(name)) + def nChildren: NChildren = pack.nChildren(name) + def childrenSeq: ChildrenSeq = pack.childrenSeq(name) + } + + sealed abstract class AttOrChildPack extends AttOrChildPack_Interface { + def * : AttOrChildPack = Collection(this) + def + : AttOrChildPack = Collection(this, allowEmpty = false) + def ? : AttOrChildPack = Optional(this) + + def generateDeclaration: String + def constraints: Seq[String => String] = Seq.empty + def withConstraint(c: String => String): AttOrChildPack = Constrained(this, Seq(c)) + def nChildren(self: String): NChildren = NChildren() + def childrenSeq(self: String): ChildrenSeq + def copyWithNewChildren(self: String, newChildren: ChildrenSeqSlice): String + } + + final case class Constrained(value: AttOrChildPack, newConstraints: Seq[String => String]) + extends AttOrChildPack { + override def generateDeclaration: String = value.generateDeclaration + override def constraints: Seq[String => String] = value.constraints ++ newConstraints + + override def withConstraint(c: String => String): AttOrChildPack = + copy(newConstraints = newConstraints :+ c) + + override def nChildren(self: String): NChildren = value.nChildren(self) + override def childrenSeq(self: String): ChildrenSeq = value.childrenSeq(self) + + override def copyWithNewChildren(self: String, newChildren: ChildrenSeqSlice): String = + value.copyWithNewChildren(self, newChildren) + } + + final case class Att(typ: String) extends AttOrChildPack { + override def generateDeclaration: String = typ + override def childrenSeq(self: String): ChildrenSeq = ChildrenSeq.empty + + override def copyWithNewChildren(self: String, newChildren: ChildrenSeqSlice): String = { + assert(newChildren.hasStaticLen(0)) + self + } + } + + final case class Child(t: String = "IR") extends AttOrChildPack { + override def generateDeclaration: String = t + override def nChildren(self: String): NChildren = 1 + override def childrenSeq(self: String): ChildrenSeq = ChildrenSeq.Static(Seq(self)) + + override def copyWithNewChildren(self: String, newChildren: ChildrenSeqSlice): String = { + assert(newChildren.hasStaticLen(1)) + newChildren(0) + s".asInstanceOf[$t]" + } + } + + case object Binding extends AttOrChildPack { + override def generateDeclaration: String = "Binding" + override def nChildren(self: String): NChildren = 1 + override def childrenSeq(self: String): ChildrenSeq = ChildrenSeq.Static(Seq(s"$self.value")) + + override def copyWithNewChildren(self: String, newChildren: ChildrenSeqSlice): String = { + assert(newChildren.hasStaticLen(1)) + s"$self.copy(value = ${newChildren(0)}.asInstanceOf[IR])" + } + } + + final case class Optional(elt: AttOrChildPack) extends AttOrChildPack { + override def generateDeclaration: String = s"Option[${elt.generateDeclaration}]" + + override def nChildren(self: String): NChildren = elt.nChildren("").getStatic match { + case Some(0) => 0 + case _ => NChildren(dynamic = s"$self.map(x => ${elt.nChildren("x")}).sum") + } + + override def childrenSeq(self: 
String): ChildrenSeq = elt match { + case Att(_) | Constrained(Att(_), _) => ChildrenSeq.empty + case Child(_) | Constrained(Child(_), _) => ChildrenSeq.Dynamic(s"$self.toSeq") + case _ => + ChildrenSeq.Dynamic(s"$self.toSeq.flatMap(x => ${elt.childrenSeq("x")})") + } + + override def copyWithNewChildren(self: String, newChildren: ChildrenSeqSlice): String = + s"$self.map(x => ${elt.copyWithNewChildren("x", newChildren.slice(0, 1))})" + } + + final case class Collection(elt: AttOrChildPack, allowEmpty: Boolean = true) + extends AttOrChildPack { + override def generateDeclaration: String = s"IndexedSeq[${elt.generateDeclaration}]" + + override def constraints: Seq[String => String] = { + val nestedConstraints: Seq[String => String] = if (elt.constraints.nonEmpty) + Seq(elts => + s"$elts.forall(x => ${elt.constraints.map(c => s"(${c("x")})").mkString(" && ")})" + ) + else Seq() + val nonEmptyConstraint: Seq[String => String] = + if (allowEmpty) Seq() else Seq(x => s"$x.nonEmpty") + nestedConstraints ++ nonEmptyConstraint + } + + override def nChildren(self: String): NChildren = elt.nChildren("").getStatic match { + case Some(0) => 0 + case Some(1) => NChildren(dynamic = s"$self.size") + case _ => NChildren(dynamic = s"$self.map(x => ${elt.nChildren("x")}).sum") + } + + override def childrenSeq(self: String): ChildrenSeq = elt match { + case Att(_) | Constrained(Att(_), _) => ChildrenSeq.empty + case Child(_) | Constrained(Child(_), _) => ChildrenSeq.Dynamic(self) + case _ => ChildrenSeq.Dynamic(s"$self.flatMap(x => ${elt.childrenSeq("x")})") + } + + override def copyWithNewChildren(self: String, newChildren: ChildrenSeqSlice): String = + elt match { + case Att(_) | Constrained(Att(_), _) => self + case Child(t) => s"$newChildren.map(_.asInstanceOf[$t])" + case Constrained(Child(t), _) => s"$newChildren.map(_.asInstanceOf[$t])" + case _ => + assert(elt.nChildren("").hasStaticValue(1)) + s"($self, $newChildren).zipped.map { (x, newChild) => ${elt.copyWithNewChildren("x", ChildrenSeqSlice("newChild"))} }" + } + } + + final case class Tup(elts: AttOrChildPack*) extends AttOrChildPack { + override def generateDeclaration: String = + elts.map(_.generateDeclaration).mkString("(", ", ", ")") + + override def constraints: Seq[String => String] = + elts.zipWithIndex.flatMap { case (elt, i) => + elt.constraints.map(c => (t: String) => c(s"$t._${i + 1}")) + } + + override def nChildren(self: String): NChildren = + elts.zipWithIndex + .map { case (elt, i) => elt.nChildren(s"$self._${i + 1}") } + .foldLeft(NChildren())(_ + _) + + override def childrenSeq(self: String): ChildrenSeq = + elts + .zipWithIndex.map { case (elt, i) => elt.childrenSeq(s"$self._${i + 1}") } + .foldLeft[ChildrenSeq](ChildrenSeq.empty)(_ ++ _) + + override def copyWithNewChildren(self: String, newChildren: ChildrenSeqSlice): String = { + val offsets = elts.zipWithIndex.scanLeft(NChildren()) { case (acc, (elt, i)) => + acc + elt.nChildren(s"$self._${i + 1}") + } + elts.zipWithIndex.zip(offsets).map { case ((elt, i), n) => + elt.copyWithNewChildren( + s"$self._${i + 1}", + newChildren.slice(n, elt.nChildren(s"$self._${i + 1}")), + ) + }.mkString("(", ", ", ")") + } + } + + case class IR( + name: String, + attsAndChildren: Seq[NamedAttOrChildPack], + traits: Seq[Trait] = Seq.empty, + constraints: Seq[String] = Seq.empty, + extraMethods: Seq[String] = Seq.empty, + staticMethods: Seq[String] = Seq.empty, + docstring: String = "", + hasCompanionExtension: Boolean = false, + ) extends IR_Interface { + def withTraits(newTraits: Trait*): IR 
= copy(traits = traits ++ newTraits) + def withMethod(methodDef: String): IR = copy(extraMethods = extraMethods :+ methodDef) + def typed(typ: String): IR = withTraits(TypedIR(typ)) + def withConstraint(c: String): IR = copy(constraints = constraints :+ c) + + def withCompanionExtension: IR = copy(hasCompanionExtension = true) + + def withClassExtension: IR = withTraits(Trait(s"${name}Ext")) + + def withDocstring(docstring: String): IR = copy(docstring = docstring) + + private def nChildren: NChildren = attsAndChildren.foldLeft(NChildren())(_ + _.nChildren) + + private def children: String = + attsAndChildren.foldLeft[ChildrenSeq](ChildrenSeq.empty)(_ ++ _.childrenSeq).toString + + private def childrenOffsets: Seq[NChildren] = + attsAndChildren.scanLeft(NChildren())(_ + _.nChildren) + + private def copyMethod: String = { + val decl = s"override def copyWithNewChildren(newChildren: IndexedSeq[BaseIR]): $name = " + val assertion = s"assert(newChildren.length == $nChildren)" + val body = name + attsAndChildren.zipWithIndex.map { case (x, i) => + try + x.pack.copyWithNewChildren( + x.name, + ChildrenSeqSlice( + ChildrenSeq.Dynamic("newChildren"), + nChildren, + childrenOffsets(i), + x.nChildren, + ), + ) + catch { + case _: NotImplementedError => + assert(false, name) + } + }.mkString("(", ", ", ")") + decl + "{\n " + assertion + "\n " + body + "}" + } + + private def paramList = s"$name(${attsAndChildren.map(_.generateDeclaration).mkString(", ")})" + + private def classDecl = + s"final case class $paramList extends IR" + traits.map(" with " + _.name).mkString + + private def classBody = { + val childrenSeqDef = if (nChildren.hasStaticValue(0)) + s"override def childrenSeq: IndexedSeq[BaseIR] = IndexedSeq.empty" + else + s"override lazy val childrenSeq: IndexedSeq[BaseIR] = $children" + val extraMethods = + this.extraMethods :+ childrenSeqDef :+ copyMethod + val constraints = this.constraints ++ attsAndChildren.flatMap(_.constraints) + if (constraints.nonEmpty || extraMethods.nonEmpty) { + ( + " {" + + (if (constraints.nonEmpty) + constraints.map(c => s" require($c)").mkString("\n", "\n", "\n") + else "") + + ( + if (extraMethods.nonEmpty) + extraMethods.map(" " + _).mkString("\n", "\n", "\n") + else "" + ) + + "}" + ) + } else "" + } + + private def classDef = + (if (docstring.nonEmpty) s"\n/** $docstring*/\n" else "") + classDecl + classBody + + private def companionDef = if (hasCompanionExtension) + s"object $name extends ${name}CompanionExt\n" + else "" + + def generateDef: String = companionDef + classDef + "\n" + } +} + +object Main { + val irdsl: IRDSL = IRDSL_Impl + import irdsl._ + + private val errorID = ("errorID", att("Int")).withDefault("ErrorIDs.NO_ERROR") + + private def _typ(t: String = "Type") = ("_typ", att(t)) + + private val mmPerElt = ("requiresMemoryManagementPerElement", att("Boolean")).withDefault("false") + + private def allNodes: Seq[IR] = { + // scalafmt: {} + + val r = Seq.newBuilder[IR] + + r += node("I32", ("x", att("Int"))).withTraits(TrivialIR) + r += node("I64", ("x", att("Long"))).withTraits(TrivialIR) + r += node("F32", ("x", att("Float"))).withTraits(TrivialIR) + r += node("F64", ("x", att("Double"))).withTraits(TrivialIR) + r += node("Str", ("x", att("String"))).withTraits(TrivialIR) + .withMethod( + "override def toString(): String = s\"\"\"Str(\"${StringEscapeUtils.escapeString(x)}\")\"\"\"" + ): @nowarn("msg=possible missing interpolator") + r += node("True").withTraits(TrivialIR) + r += node("False").withTraits(TrivialIR) + r += 
node("Void").withTraits(TrivialIR) + r += node("NA", _typ()).withTraits(TrivialIR) + r += node("UUID4", ("id", att("String"))) + .withDocstring( + """WARNING! This node can only be used when trying to append a one-off, + |random string that will not be reused elsewhere in the pipeline. + |Any other uses will need to write and then read again; this node is non-deterministic + |and will not e.g. exhibit the correct semantics when self-joining on streams. + |""".stripMargin + ) + .withCompanionExtension + + r += node( + "Literal", + ("_typ", att("Type").withConstraint(self => s"!CanEmit($self)")), + ("value", att("Annotation").withConstraint(self => s"$self != null")), + ) + .withCompanionExtension + .withMethod( + """// expensive, for debugging + |// require(SafeRow.isSafe(value)) + |// assert(_typ.typeCheck(value), s"literal invalid:\n ${_typ}\n $value") + |""".stripMargin + ) + + r += node( + "EncodedLiteral", + ( + "codec", + att("AbstractTypedCodecSpec").withConstraint(self => + s"!CanEmit($self.encodedVirtualType)" + ), + ), + ("value", att("WrappedByteArrays").withConstraint(self => s"$self != null")), + ) + .withCompanionExtension + + r += node("Cast", ("v", child), _typ()) + r += node("CastRename", ("v", child), _typ()) + + r += node("IsNA", ("value", child)) + r += node("Coalesce", ("values", child.+)) + r += node("Consume", ("value", child)) + + r += node("If", ("cond", child), ("cnsq", child), ("altr", child)) + r += node("Switch", ("x", child), ("default", child), ("cases", child.*)) + .withMethod("override lazy val size: Int = 2 + cases.length") + + r += node("Block", ("bindings", binding.*), ("body", child)) + .withMethod("override lazy val size: Int = bindings.length + 1") + .withCompanionExtension + + r += node("Ref", name, _typ().mutable).withTraits(BaseRef) + + r += node( + "TailLoop", + name, + ("params", tup(name, child).*), + ("resultType", att("Type")), + ("body", child), + ) + .withDocstring( + """Recur can't exist outside of loop. Loops can be nested, but we can't call outer + |loops in terms of inner loops so there can only be one loop "active" in a given + |context. + |""".stripMargin + ) + .withMethod("lazy val paramIdx: Map[Name, Int] = params.map(_._1).zipWithIndex.toMap") + r += node("Recur", name, ("args", child.*), _typ().mutable).withTraits(BaseRef) + + r += node("RelationalLet", name, ("value", child), ("body", child)) + r += node("RelationalRef", name, _typ()).withTraits(BaseRef) + + r += node("ApplyBinaryPrimOp", ("op", att("BinaryOp")), ("l", child), ("r", child)) + r += node("ApplyUnaryPrimOp", ("op", att("UnaryOp")), ("x", child)) + r += node( + "ApplyComparisonOp", + ("op", att("ComparisonOp[_]")).mutable, + ("l", child), + ("r", child), + ) + + r += node("MakeArray", ("args", child.*), _typ("TArray")).withCompanionExtension + r += node("MakeStream", ("args", child.*), _typ("TStream"), mmPerElt).withCompanionExtension + r += node("ArrayRef", ("a", child), ("i", child), errorID) + r += node( + "ArraySlice", + ("a", child), + ("start", child), + ("stop", child.?), + ("step", child).withDefault("I32(1)"), + errorID, + ) + r += node("ArrayLen", ("a", child)) + r += node("ArrayZeros", ("length", child)) + r += node( + "ArrayMaximalIndependentSet", + ("edges", child), + ("tieBreaker", tup(name, name, child).?), + ) + + r += node("StreamIota", ("start", child), ("step", child), mmPerElt) + .withDocstring( + """[[StreamIota]] is an infinite stream producer, whose element is an integer starting at + |`start`, updated by `step` at each iteration. 
The name comes from APL: + |[[https://stackoverflow.com/questions/9244879/what-does-iota-of-stdiota-stand-for]] + |""".stripMargin + ) + r += node("StreamRange", ("start", child), ("stop", child), ("step", child), mmPerElt, errorID) + + r += node("ArraySort", ("a", child), ("left", name), ("right", name), ("lessThan", child)) + .withCompanionExtension + + r += node("ToSet", ("a", child)) + r += node("ToDict", ("a", child)) + r += node("ToArray", ("a", child)) + r += node("CastToArray", ("a", child)) + r += node("ToStream", ("a", child), mmPerElt) + r += node("GroupByKey", ("collection", child)) + + r += node( + "StreamBufferedAggregate", + ("streamChild", child), + ("initAggs", child), + ("newKey", child), + ("seqOps", child), + name, + ("aggSignature", att("IndexedSeq[PhysicalAggSig]")), + ("bufferSize", att("Int")), + ) + r += node( + "LowerBoundOnOrderedCollection", + ("orderedCollection", child), + ("elem", child), + ("onKey", att("Boolean")), + ) + + r += node("RNGStateLiteral") + r += node("RNGSplit", ("state", child), ("dynBitstring", child)) + + r += node("StreamLen", ("a", child)) + r += node("StreamGrouped", ("a", child), ("groupSize", child)) + r += node("StreamGroupByKey", ("a", child), key, ("missingEqual", att("Boolean"))) + r += node("StreamMap", ("a", child), name, ("body", child)).typed("TStream") + r += node("StreamTakeWhile", ("a", child), ("elementName", name), ("body", child)) + .typed("TStream") + r += node("StreamDropWhile", ("a", child), ("elementName", name), ("body", child)) + .typed("TStream") + r += node("StreamTake", ("a", child), ("num", child)).typed("TStream") + r += node("StreamDrop", ("a", child), ("num", child)).typed("TStream") + + r += node( + "SeqSample", + ("totalRange", child), + ("numToSample", child), + ("rngState", child), + mmPerElt, + ) + .typed("TStream") + .withDocstring( + """Generate, in ascending order, a uniform random sample, without replacement, of + |numToSample integers in the range [0, totalRange) + |""".stripMargin + ) + + r += node( + "StreamDistribute", + ("child", child), + ("pivots", child), + ("path", child), + ("comparisonOp", att("ComparisonOp[_]")), + ("spec", att("AbstractTypedCodecSpec")), + ) + .withDocstring( + """Take the child stream and sort each element into buckets based on the provided pivots. + |The first and last elements of pivots are the endpoints of the first and last interval + |respectively, should not be contained in the dataset. + |""".stripMargin + ) + + r += node( + "StreamWhiten", + ("stream", child), + ("newChunk", att("String")), + ("prevWindow", att("String")), + ("vecSize", att("Int")), + ("windowSize", att("Int")), + ("chunkSize", att("Int")), + ("blockSize", att("Int")), + ("normalizeAfterWhiten", att("Boolean")), + ) + .typed("TStream") + .withDocstring( + """"Whiten" a stream of vectors by regressing out from each vector all components + |in the direction of vectors in the preceding window. For efficiency, takes + |a stream of "chunks" of vectors. + |Takes a stream of structs, with two designated fields: `prevWindow` is the + |previous window (e.g. from the previous partition), if there is one, and + |`newChunk` is the new chunk to whiten. 
+ |""".stripMargin + ) + + r += node( + "StreamZip", + ("as", child.*), + ("names", name.*), + ("body", child), + ("behavior", att("ArrayZipBehavior.ArrayZipBehavior")), + errorID, + ) + .typed("TStream") + + r += node("StreamMultiMerge", ("as", child.*), key).typed("TStream") + + r += node( + "StreamZipJoinProducers", + ("contexts", child), + ("ctxName", name), + ("makeProducer", child), + key, + ("curKey", name), + ("curVals", name), + ("joinF", child), + ) + .typed("TStream") + + r += node( + "StreamZipJoin", + ("as", child.*), + key, + ("curKey", name), + ("curVals", name), + ("joinF", child), + ) + .typed("TStream") + .withDocstring( + """The StreamZipJoin node assumes that input streams have distinct keys. If input streams do not + |have distinct keys, the key that is included in the result is undefined, but is likely the + |last. + |""".stripMargin + ) + + r += node("StreamFilter", ("a", child), name, ("cond", child)).typed("TStream") + r += node("StreamFlatMap", ("a", child), name, ("cond", child)).typed("TStream") + + r += node( + "StreamFold", + ("a", child), + ("zero", child), + ("accumName", name), + ("valueName", name), + ("body", child), + ) + + r += node( + "StreamFold2", + ("a", child), + ("accum", tup(name, child).*), + ("valueName", name), + ("seq", child.*), + ("result", child), + ) + .withConstraint("accum.length == seq.length") + .withMethod("val nameIdx: Map[Name, Int] = accum.map(_._1).zipWithIndex.toMap") + .withCompanionExtension + + r += node( + "StreamScan", + ("a", child), + ("zero", child), + ("accumName", name), + ("valueName", name), + ("body", child), + ) + .typed("TStream") + + r += node("StreamFor", ("a", child), ("valueName", name), ("body", child)).typed("TVoid.type") + r += node("StreamAgg", ("a", child), name, ("query", child)) + r += node("StreamAggScan", ("a", child), name, ("query", child)).typed("TStream") + + r += node( + "StreamLeftIntervalJoin", + ("left", child), + ("right", child), + ("lKeyFieldName", att("String")), + ("rIntervalFieldName", att("String")), + ("lname", name), + ("rname", name), + ("body", child), + ) + .typed("TStream") + + r += node( + "StreamJoinRightDistinct", + ("left", child), + ("right", child), + ("lKey", att("IndexedSeq[String]")), + ("rKey", att("IndexedSeq[String]")), + ("l", name), + ("r", name), + ("joinF", child), + ("joinType", att("String")), + ) + .typed("TStream").withClassExtension + + r += node( + "StreamLocalLDPrune", + ("child", child), + ("r2Threshold", child), + ("windowSize", child), + ("maxQueueSize", child), + ("nSamples", child), + ) + .typed("TStream") + + r += node("MakeNDArray", ("data", child), ("shape", child), ("rowMajor", child), errorID) + .withTraits(NDArrayIR).withCompanionExtension + r += node("NDArrayShape", ("nd", child)) + r += node("NDArrayReshape", ("nd", child), ("shape", child), errorID).withTraits(NDArrayIR) + r += node("NDArrayConcat", ("nds", child), ("axis", att("Int"))).withTraits(NDArrayIR) + r += node("NDArrayRef", ("nd", child), ("idxs", child.*), errorID) + r += node("NDArraySlice", ("nd", child), ("slices", child)).withTraits(NDArrayIR) + r += node("NDArrayFilter", ("nd", child), ("keep", child.*)).withTraits(NDArrayIR) + r += node("NDArrayMap", ("nd", child), ("valueName", name), ("body", child)) + .withTraits(NDArrayIR) + r += node( + "NDArrayMap2", + ("l", child), + ("r", child), + ("lName", name), + ("rName", name), + ("body", child), + errorID, + ) + .withTraits(NDArrayIR) + r += node("NDArrayReindex", ("nd", child), ("indexExpr", att("IndexedSeq[Int]"))) + 
.withTraits(NDArrayIR) + r += node("NDArrayAgg", ("nd", child), ("axes", att("IndexedSeq[Int]"))) + r += node("NDArrayWrite", ("nd", child), ("path", child)).typed("TVoid.type") + r += node("NDArrayMatMul", ("l", child), ("r", child), errorID).withTraits(NDArrayIR) + r += node("NDArrayQR", ("nd", child), ("mode", att("String")), errorID).withCompanionExtension + r += node( + "NDArraySVD", + ("nd", child), + ("fullMatrices", att("Boolean")), + ("computeUV", att("Boolean")), + errorID, + ) + .withCompanionExtension + r += node("NDArrayEigh", ("nd", child), ("eigvalsOnly", att("Boolean")), errorID) + .withCompanionExtension + r += node("NDArrayInv", ("nd", child), errorID).withTraits(NDArrayIR).withCompanionExtension + + val isScan = ("isScan", att("Boolean")) + + r += node("AggFilter", ("cond", child), ("aggIR", child), isScan) + r += node("AggExplode", ("array", child), name, ("aggBody", child), isScan) + r += node("AggGroupBy", ("key", child), ("aggIR", child), isScan) + r += node( + "AggArrayPerElement", + ("a", child), + ("elementName", name), + ("indexName", name), + ("aggBody", child), + ("knownLength", child.?), + isScan, + ) + r += node( + "AggFold", + ("zero", child), + ("seqOp", child), + ("combOp", child), + ("accumName", name), + ("otherAccumName", name), + isScan, + ) + .withCompanionExtension + + r += node( + "ApplyAggOp", + ("initOpArgs", child.*), + ("seqOpArgs", child.*), + ("aggSig", att("AggSignature")), + ) + .withClassExtension.withCompanionExtension + r += node( + "ApplyScanOp", + ("initOpArgs", child.*), + ("seqOpArgs", child.*), + ("aggSig", att("AggSignature")), + ) + .withClassExtension.withCompanionExtension + r += node("InitOp", ("i", att("Int")), ("args", child.*), ("aggSig", att("PhysicalAggSig"))) + r += node("SeqOp", ("i", att("Int")), ("args", child.*), ("aggSig", att("PhysicalAggSig"))) + r += node("CombOp", ("i1", att("Int")), ("i2", att("Int")), ("aggSig", att("PhysicalAggSig"))) + r += node("ResultOp", ("idx", att("Int")), ("aggSig", att("PhysicalAggSig"))) + .withCompanionExtension + r += node("CombOpValue", ("i", att("Int")), ("value", child), ("aggSig", att("PhysicalAggSig"))) + r += node("AggStateValue", ("i", att("Int")), ("aggSig", att("AggStateSig"))) + r += node( + "InitFromSerializedValue", + ("i", att("Int")), + ("value", child), + ("aggSig", att("AggStateSig")), + ) + r += node( + "SerializeAggs", + ("startIdx", att("Int")), + ("serializedIdx", att("Int")), + ("spec", att("BufferSpec")), + ("aggSigs", att("IndexedSeq[AggStateSig]")), + ) + r += node( + "DeserializeAggs", + ("startIdx", att("Int")), + ("serializedIdx", att("Int")), + ("spec", att("BufferSpec")), + ("aggSigs", att("IndexedSeq[AggStateSig]")), + ) + r += node( + "RunAgg", + ("body", child), + ("result", child), + ("signature", att("IndexedSeq[AggStateSig]")), + ) + r += node( + "RunAggScan", + ("array", child), + name, + ("init", child), + ("seqs", child), + ("result", child), + ("signature", att("IndexedSeq[AggStateSig]")), + ) + + r += node("MakeStruct", ("fields", tup(att("String"), child).*)).typed("TStruct") + r += node("SelectFields", ("old", child), ("fields", att("IndexedSeq[String]"))) + .typed("TStruct") + r += node( + "InsertFields", + ("old", child), + ("fields", tup(att("String"), child).*), + ("fieldOrder", att("Option[IndexedSeq[String]]")).withDefault("None"), + ) + .typed("TStruct") + r += node("GetField", ("o", child), ("name", att("String"))) + r += node("MakeTuple", ("fields", tup(att("Int"), child).*)) + .typed("TTuple").withCompanionExtension + r += 
node("GetTupleElement", ("o", child), ("idx", att("Int"))) + + r += node("In", ("i", att("Int")), ("_typ", att("EmitParamType"))) + .withDocstring("Function input").withCompanionExtension + + r += node("Die", ("message", child), ("_typ", att("Type")), errorID).withCompanionExtension + r += node("Trap", ("child", child)).withDocstring( + """The Trap node runs the `child` node with an exception handler. If the child throws a + |HailException (user exception), then we return the tuple ((msg, errorId), NA). If the child + |throws any other exception, we raise that exception. If the child does not throw, then we + |return the tuple (NA, child value). + |""".stripMargin + ) + r += node("ConsoleLog", ("message", child), ("result", child)) + + r += node( + "ApplyIR", + ("function", att("String")), + ("typeArgs", att("Seq[Type]")), + ("args", child.*), + ("returnType", att("Type")), + errorID, + ) + .withClassExtension + + r += node( + "Apply", + ("function", att("String")), + ("typeArgs", att("Seq[Type]")), + ("args", child.*), + ("returnType", att("Type")), + errorID, + ).withTraits(ApplyNode()) + + r += node( + "ApplySeeded", + ("function", att("String")), + ("_args", child.*), + ("rngState", child), + ("staticUID", att("Long")), + ("returnType", att("Type")), + ).withTraits(ApplyNode()) + .withMethod("val args = rngState +: _args") + .withMethod("val typeArgs: Seq[Type] = Seq.empty[Type]") + + r += node( + "ApplySpecial", + ("function", att("String")), + ("typeArgs", att("Seq[Type]")), + ("args", child.*), + ("returnType", att("Type")), + errorID, + ).withTraits(ApplyNode(missingnessAware = true)) + + r += node("LiftMeOut", ("child", child)) + + r += node("TableCount", tableChild) + r += node("MatrixCount", matrixChild) + r += node("TableAggregate", tableChild, ("query", child)) + r += node("MatrixAggregate", matrixChild, ("query", child)) + r += node("TableWrite", tableChild, ("writer", att("TableWriter"))) + r += node( + "TableMultiWrite", + ("_children", tableChild.*), + ("writer", att("WrappedMatrixNativeMultiWriter")), + ) + r += node("TableGetGlobals", tableChild) + r += node("TableCollect", tableChild) + r += node("MatrixWrite", matrixChild, ("writer", att("MatrixWriter"))) + r += node( + "MatrixMultiWrite", + ("_children", matrixChild.*), + ("writer", att("MatrixNativeMultiWriter")), + ) + r += node("TableToValueApply", tableChild, ("function", att("TableToValueFunction"))) + r += node("MatrixToValueApply", matrixChild, ("function", att("MatrixToValueFunction"))) + r += node( + "BlockMatrixToValueApply", + blockMatrixChild, + ("function", att("BlockMatrixToValueFunction")), + ) + r += node("BlockMatrixCollect", blockMatrixChild) + r += node("BlockMatrixWrite", blockMatrixChild, ("writer", att("BlockMatrixWriter"))) + r += node( + "BlockMatrixMultiWrite", + ("blockMatrices", blockMatrixChild.*), + ("writer", att("BlockMatrixMultiWriter")), + ) + + r += node( + "CollectDistributedArray", + ("contexts", child), + ("globals", child), + ("cname", name), + ("gname", name), + ("body", child), + ("dynamicID", child), + ("staticID", att("String")), + ("tsd", att("Option[TableStageDependency]")).withDefault("None"), + ) + + r += node( + "ReadPartition", + ("context", child), + ("rowType", att("TStruct")), + ("reader", att("PartitionReader")), + ) + r += node( + "WritePartition", + ("value", child), + ("writeCtx", child), + ("writer", att("PartitionWriter")), + ) + r += node("WriteMetadata", ("writeAnnotations", child), ("writer", att("MetadataWriter"))) + r += node( + "ReadValue", + ("path", 
child), + ("reader", att("ValueReader")), + ("requestedType", att("Type")), + ) + r += node( + "WriteValue", + ("value", child), + ("path", child), + ("writer", att("ValueWriter")), + ("stagingFile", att("Option[IR]")).withDefault("None"), + ) + + r.result() + } + + @main + def main(path: String) = { + val pack = "package is.hail.expr.ir.defs" + val imports = Seq( + "is.hail.annotations.Annotation", + "is.hail.io.{AbstractTypedCodecSpec, BufferSpec}", + "is.hail.types.virtual.{Type, TArray, TStream, TVoid, TStruct, TTuple}", + "is.hail.utils.{FastSeq, StringEscapeUtils}", + "is.hail.expr.ir.{BaseIR, IR, TableIR, MatrixIR, BlockMatrixIR, Name, UnaryOp, BinaryOp, " + + "ComparisonOp, CanEmit, AggSignature, EmitParamType, TableWriter, " + + "WrappedMatrixNativeMultiWriter, MatrixWriter, MatrixNativeMultiWriter, BlockMatrixWriter, " + + "BlockMatrixMultiWriter, ValueReader, ValueWriter}", + "is.hail.expr.ir.lowering.TableStageDependency", + "is.hail.expr.ir.agg.{PhysicalAggSig, AggStateSig}", + "is.hail.expr.ir.functions.{UnseededMissingnessAwareJVMFunction, " + + "UnseededMissingnessObliviousJVMFunction, TableToValueFunction, MatrixToValueFunction, " + + "BlockMatrixToValueFunction}", + "is.hail.expr.ir.defs.exts._", + ) + val gen = pack + "\n\n" + imports.map(i => s"import $i").mkString("\n") + "\n\n" + allNodes.map( + _.generateDef + ).mkString("\n") + os.write(os.Path(path) / "IR_gen.scala", gen) + } + + def main(args: Array[String]): Unit = ParserForMethods(this).runOrExit(args) +} diff --git a/hail/hail/src/is/hail/expr/Parser.scala b/hail/hail/src/is/hail/expr/Parser.scala index b71ea4961e4..5c581e8a136 100644 --- a/hail/hail/src/is/hail/expr/Parser.scala +++ b/hail/hail/src/is/hail/expr/Parser.scala @@ -27,9 +27,8 @@ object ParserUtils { fatal( s"""$msg |$prefix$lineContents - |${" " * prefix.length}${lineContents.take(pos.column - 1).map { c => - if (c == '\t') c else ' ' - }}^""".stripMargin + |${" " * prefix.length}${lineContents.take(pos.column - 1) + .map(c => if (c == '\t') c else ' ')}^""".stripMargin ) } @@ -39,9 +38,8 @@ object ParserUtils { fatal( s"""$msg |$prefix$lineContents - |${" " * prefix.length}${lineContents.take(pos.column - 1).map { c => - if (c == '\t') c else ' ' - }}^""".stripMargin, + |${" " * prefix.length}${lineContents.take(pos.column - 1) + .map(c => if (c == '\t') c else ' ')}^""".stripMargin, tr, ) } diff --git a/hail/hail/src/is/hail/expr/ir/Children.scala b/hail/hail/src/is/hail/expr/ir/Children.scala deleted file mode 100644 index 9eb780815d4..00000000000 --- a/hail/hail/src/is/hail/expr/ir/Children.scala +++ /dev/null @@ -1,267 +0,0 @@ -package is.hail.expr.ir - -import is.hail.expr.ir.defs._ -import is.hail.utils._ - -object Children { - private val none: IndexedSeq[BaseIR] = Array.empty[BaseIR] - - def apply(x: IR): IndexedSeq[BaseIR] = x match { - case I32(_) => none - case I64(_) => none - case F32(_) => none - case F64(_) => none - case Str(_) => none - case UUID4(_) => none - case True() => none - case False() => none - case Literal(_, _) => none - case EncodedLiteral(_, _) => none - case Void() => none - case Cast(v, _) => - Array(v) - case CastRename(v, _) => - Array(v) - case NA(_) => none - case IsNA(value) => - Array(value) - case Coalesce(values) => values.toFastSeq - case Consume(value) => FastSeq(value) - case If(cond, cnsq, altr) => - Array(cond, cnsq, altr) - case s @ Switch(x, default, cases) => - val children = Array.ofDim[BaseIR](s.size) - children(0) = x - children(1) = default - for (i <- cases.indices) children(2 + i) = cases(i) 
- children - case Block(bindings, body) => - val children = Array.ofDim[BaseIR](x.size) - for (i <- bindings.indices) children(i) = bindings(i).value - children(bindings.size) = body - children - case RelationalLet(_, value, body) => - Array(value, body) - case TailLoop(_, args, _, body) => - args.map(_._2).toFastSeq :+ body - case Recur(_, args, _) => - args.toFastSeq - case Ref(_, _) => - none - case RelationalRef(_, _) => - none - case ApplyBinaryPrimOp(_, l, r) => - Array(l, r) - case ApplyUnaryPrimOp(_, x) => - Array(x) - case ApplyComparisonOp(_, l, r) => - Array(l, r) - case MakeArray(args, _) => - args.toFastSeq - case MakeStream(args, _, _) => - args.toFastSeq - case ArrayRef(a, i, _) => - Array(a, i) - case ArraySlice(a, start, stop, step, _) => - if (stop.isEmpty) - Array(a, start, step) - else - Array(a, start, stop.get, step) - case ArrayLen(a) => - Array(a) - case StreamIota(start, step, _) => - Array(start, step) - case StreamRange(start, stop, step, _, _) => - Array(start, stop, step) - case SeqSample(totalRange, numToSample, rngState, _) => - Array(totalRange, numToSample, rngState) - case StreamDistribute(child, pivots, path, _, _) => - Array(child, pivots, path) - case StreamWhiten(stream, _, _, _, _, _, _, _) => - Array(stream) - case ArrayZeros(length) => - Array(length) - case MakeNDArray(data, shape, rowMajor, _) => - Array(data, shape, rowMajor) - case NDArrayShape(nd) => - Array(nd) - case NDArrayReshape(nd, shape, _) => - Array(nd, shape) - case NDArrayConcat(nds, _) => - Array(nds) - case ArraySort(a, _, _, lessThan) => - Array(a, lessThan) - case ArrayMaximalIndependentSet(a, tieBreaker) => - Array(a) ++ tieBreaker.map { case (_, _, tb) => tb } - case ToSet(a) => - Array(a) - case ToDict(a) => - Array(a) - case ToArray(a) => - Array(a) - case CastToArray(a) => - Array(a) - case ToStream(a, _) => - Array(a) - case LowerBoundOnOrderedCollection(orderedCollection, elem, _) => - Array(orderedCollection, elem) - case GroupByKey(collection) => - Array(collection) - case RNGStateLiteral() => none - case RNGSplit(state, split) => - Array(state, split) - case StreamLen(a) => - Array(a) - case StreamTake(a, len) => - Array(a, len) - case StreamDrop(a, len) => - Array(a, len) - case StreamGrouped(a, size) => - Array(a, size) - case StreamGroupByKey(a, _, _) => - Array(a) - case StreamMap(a, _, body) => - Array(a, body) - case StreamZip(as, _, body, _, _) => - as :+ body - case StreamZipJoin(as, _, _, _, joinF) => - as :+ joinF - case StreamZipJoinProducers(contexts, _, makeProducer, _, _, _, joinF) => - Array(contexts, makeProducer, joinF) - case StreamMultiMerge(as, _) => - as - case StreamFilter(a, _, cond) => - Array(a, cond) - case StreamTakeWhile(a, _, cond) => - Array(a, cond) - case StreamDropWhile(a, _, cond) => - Array(a, cond) - case StreamFlatMap(a, _, body) => - Array(a, body) - case StreamFold(a, zero, _, _, body) => - Array(a, zero, body) - case StreamFold2(a, accum, _, seq, result) => - Array(a) ++ accum.map(_._2) ++ seq ++ Array(result) - case StreamScan(a, zero, _, _, body) => - Array(a, zero, body) - case StreamJoinRightDistinct(left, right, _, _, _, _, join, _) => - Array(left, right, join) - case StreamFor(a, _, body) => - Array(a, body) - case StreamAgg(a, _, query) => - Array(a, query) - case StreamAggScan(a, _, query) => - Array(a, query) - case StreamBufferedAggregate(streamChild, initAggs, newKey, seqOps, _, _, _) => - Array(streamChild, initAggs, newKey, seqOps) - case StreamLocalLDPrune(streamChild, r2Threshold, windowSize, maxQueueSize, 
nSamples) => - Array(streamChild, r2Threshold, windowSize, maxQueueSize, nSamples) - case RunAggScan(array, _, init, seq, result, _) => - Array(array, init, seq, result) - case RunAgg(body, result, _) => - Array(body, result) - case NDArrayRef(nd, idxs, _) => - nd +: idxs - case NDArraySlice(nd, slices) => - Array(nd, slices) - case NDArrayFilter(nd, keep) => - nd +: keep - case NDArrayMap(nd, _, body) => - Array(nd, body) - case NDArrayMap2(l, r, _, _, body, _) => - Array(l, r, body) - case NDArrayReindex(nd, _) => - Array(nd) - case NDArrayAgg(nd, _) => - Array(nd) - case NDArrayMatMul(l, r, _) => - Array(l, r) - case NDArrayQR(nd, _, _) => - Array(nd) - case NDArraySVD(nd, _, _, _) => - Array(nd) - case NDArrayEigh(nd, _, _) => - Array(nd) - case NDArrayInv(nd, _) => - Array(nd) - case NDArrayWrite(nd, path) => - Array(nd, path) - case AggFilter(cond, aggIR, _) => - Array(cond, aggIR) - case AggExplode(array, _, aggBody, _) => - Array(array, aggBody) - case AggGroupBy(key, aggIR, _) => - Array(key, aggIR) - case AggArrayPerElement(a, _, _, aggBody, knownLength, _) => - Array(a, aggBody) ++ knownLength.toArray[IR] - case MakeStruct(fields) => - fields.map(_._2).toFastSeq - case SelectFields(old, _) => - Array(old) - case InsertFields(old, fields, _) => - (old +: fields.map(_._2)).toFastSeq - case InitOp(_, args, _) => args - case SeqOp(_, args, _) => args - case _: ResultOp => none - case _: AggStateValue => none - case _: CombOp => none - case CombOpValue(_, value, _) => Array(value) - case InitFromSerializedValue(_, value, _) => Array(value) - case SerializeAggs(_, _, _, _) => none - case DeserializeAggs(_, _, _, _) => none - case ApplyAggOp(initOpArgs, seqOpArgs, _) => - initOpArgs ++ seqOpArgs - case ApplyScanOp(initOpArgs, seqOpArgs, _) => - initOpArgs ++ seqOpArgs - case AggFold(zero, seqOp, combOp, _, _, _) => - Array(zero, seqOp, combOp) - case GetField(o, _) => - Array(o) - case MakeTuple(fields) => - fields.map(_._2).toFastSeq - case GetTupleElement(o, _) => - Array(o) - case In(_, _) => - none - case Die(message, _, _) => - Array(message) - case Trap(child) => Array(child) - case ConsoleLog(message, result) => - Array(message, result) - case ApplyIR(_, _, args, _, _) => - args.toFastSeq - case Apply(_, _, args, _, _) => - args.toFastSeq - case ApplySeeded(_, args, rngState, _, _) => - args.toFastSeq :+ rngState - case ApplySpecial(_, _, args, _, _) => - args.toFastSeq - // from MatrixIR - case MatrixWrite(child, _) => Array(child) - case MatrixMultiWrite(children, _) => children - // from TableIR - case TableCount(child) => Array(child) - case MatrixCount(child) => Array(child) - case TableGetGlobals(child) => Array(child) - case TableCollect(child) => Array(child) - case TableAggregate(child, query) => Array(child, query) - case MatrixAggregate(child, query) => Array(child, query) - case TableWrite(child, _) => Array(child) - case TableMultiWrite(children, _) => children - case TableToValueApply(child, _) => Array(child) - case MatrixToValueApply(child, _) => Array(child) - // from BlockMatrixIR - case BlockMatrixToValueApply(child, _) => Array(child) - case BlockMatrixCollect(child) => Array(child) - case BlockMatrixWrite(child, _) => Array(child) - case BlockMatrixMultiWrite(blockMatrices, _) => blockMatrices - case CollectDistributedArray(ctxs, globals, _, _, body, dynamicID, _, _) => - Array(ctxs, globals, body, dynamicID) - case ReadPartition(path, _, _) => Array(path) - case WritePartition(stream, ctx, _) => Array(stream, ctx) - case WriteMetadata(writeAnnotations, _) 
=> Array(writeAnnotations) - case ReadValue(path, _, _) => Array(path) - case WriteValue(value, path, _, staged) => Array(value, path) ++ staged.toArray[IR] - case LiftMeOut(child) => Array(child) - } -} diff --git a/hail/hail/src/is/hail/expr/ir/Copy.scala b/hail/hail/src/is/hail/expr/ir/Copy.scala deleted file mode 100644 index f1529baf809..00000000000 --- a/hail/hail/src/is/hail/expr/ir/Copy.scala +++ /dev/null @@ -1,607 +0,0 @@ -package is.hail.expr.ir - -import is.hail.expr.ir.defs._ - -object Copy { - def apply(x: IR, newChildren: IndexedSeq[BaseIR]): IR = { - x match { - case I32(value) => I32(value) - case I64(value) => I64(value) - case F32(value) => F32(value) - case F64(value) => F64(value) - case Str(value) => Str(value) - case UUID4(id) => UUID4(id) - case True() => True() - case False() => False() - case Literal(typ, value) => Literal(typ, value) - case EncodedLiteral(codec, value) => EncodedLiteral(codec, value) - case Void() => Void() - case Cast(_, typ) => - assert(newChildren.length == 1) - Cast(newChildren(0).asInstanceOf[IR], typ) - case CastRename(_, typ) => - assert(newChildren.length == 1) - CastRename(newChildren(0).asInstanceOf[IR], typ) - case NA(t) => NA(t) - case IsNA(_) => - assert(newChildren.length == 1) - IsNA(newChildren(0).asInstanceOf[IR]) - case Coalesce(_) => - Coalesce(newChildren.map(_.asInstanceOf[IR])) - case Consume(_) => - Consume(newChildren(0).asInstanceOf[IR]) - case If(_, _, _) => - assert(newChildren.length == 3) - If( - newChildren(0).asInstanceOf[IR], - newChildren(1).asInstanceOf[IR], - newChildren(2).asInstanceOf[IR], - ) - case s: Switch => - assert(s.size == newChildren.size) - Switch( - newChildren(0).asInstanceOf[IR], - newChildren(1).asInstanceOf[IR], - newChildren.drop(2).asInstanceOf[IndexedSeq[IR]], - ) - case Block(bindings, _) => - assert(newChildren.length == x.size) - val newBindings = - (bindings, newChildren.init) - .zipped - .map { case (binding, ir: IR) => binding.copy(value = ir) } - Block(newBindings, newChildren.last.asInstanceOf[IR]) - case TailLoop(name, params, resultType, _) => - assert(newChildren.length == params.length + 1) - TailLoop( - name, - params.map(_._1).zip(newChildren.init.map(_.asInstanceOf[IR])), - resultType, - newChildren.last.asInstanceOf[IR], - ) - case Recur(name, args, t) => - assert(newChildren.length == args.length) - Recur(name, newChildren.map(_.asInstanceOf[IR]), t) - case Ref(name, t) => Ref(name, t) - case RelationalRef(name, t) => RelationalRef(name, t) - case RelationalLet(name, _, _) => - assert(newChildren.length == 2) - RelationalLet(name, newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR]) - case ApplyBinaryPrimOp(op, _, _) => - assert(newChildren.length == 2) - ApplyBinaryPrimOp(op, newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR]) - case ApplyUnaryPrimOp(op, _) => - assert(newChildren.length == 1) - ApplyUnaryPrimOp(op, newChildren(0).asInstanceOf[IR]) - case ApplyComparisonOp(op, _, _) => - assert(newChildren.length == 2) - ApplyComparisonOp(op, newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR]) - case MakeArray(args, typ) => - assert(args.length == newChildren.length) - MakeArray(newChildren.map(_.asInstanceOf[IR]), typ) - case MakeStream(args, typ, requiresMemoryManagementPerElement) => - assert(args.length == newChildren.length) - MakeStream(newChildren.map(_.asInstanceOf[IR]), typ, requiresMemoryManagementPerElement) - case ArrayRef(_, _, errorID) => - assert(newChildren.length == 2) - ArrayRef(newChildren(0).asInstanceOf[IR], 
newChildren(1).asInstanceOf[IR], errorID) - case ArraySlice(_, _, stop, _, errorID) => - if (stop.isEmpty) { - assert(newChildren.length == 3) - ArraySlice( - newChildren(0).asInstanceOf[IR], - newChildren(1).asInstanceOf[IR], - None, - newChildren(2).asInstanceOf[IR], - errorID, - ) - } else { - assert(newChildren.length == 4) - ArraySlice( - newChildren(0).asInstanceOf[IR], - newChildren(1).asInstanceOf[IR], - Some(newChildren(2).asInstanceOf[IR]), - newChildren(3).asInstanceOf[IR], - errorID, - ) - } - case ArrayLen(_) => - assert(newChildren.length == 1) - ArrayLen(newChildren(0).asInstanceOf[IR]) - case StreamIota(_, _, requiresMemoryManagementPerElement) => - assert(newChildren.length == 2) - StreamIota( - newChildren(0).asInstanceOf[IR], - newChildren(1).asInstanceOf[IR], - requiresMemoryManagementPerElement, - ) - case StreamRange(_, _, _, requiresMemoryManagementPerElement, errorID) => - assert(newChildren.length == 3) - StreamRange( - newChildren(0).asInstanceOf[IR], - newChildren(1).asInstanceOf[IR], - newChildren(2).asInstanceOf[IR], - requiresMemoryManagementPerElement, - errorID, - ) - case SeqSample(_, _, _, requiresMemoryManagementPerElement) => - assert(newChildren.length == 3) - SeqSample( - newChildren(0).asInstanceOf[IR], - newChildren(1).asInstanceOf[IR], - newChildren(2).asInstanceOf[IR], - requiresMemoryManagementPerElement, - ) - case StreamDistribute(_, _, _, op, spec) => - StreamDistribute( - newChildren(0).asInstanceOf[IR], - newChildren(1).asInstanceOf[IR], - newChildren(2).asInstanceOf[IR], - op, - spec, - ) - case StreamWhiten(_, newChunk, prevWindow, vecSize, windowSize, chunkSize, blockSize, - normalizeAfterWhiten) => - StreamWhiten( - newChildren(0).asInstanceOf[IR], - newChunk, - prevWindow, - vecSize, - windowSize, - chunkSize, - blockSize, - normalizeAfterWhiten, - ) - case ArrayZeros(_) => - assert(newChildren.length == 1) - ArrayZeros(newChildren(0).asInstanceOf[IR]) - case MakeNDArray(_, _, _, errorId) => - assert(newChildren.length == 3) - MakeNDArray( - newChildren(0).asInstanceOf[IR], - newChildren(1).asInstanceOf[IR], - newChildren(2).asInstanceOf[IR], - errorId, - ) - case NDArrayShape(_) => - assert(newChildren.length == 1) - NDArrayShape(newChildren(0).asInstanceOf[IR]) - case NDArrayReshape(_, _, errorID) => - assert(newChildren.length == 2) - NDArrayReshape(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR], errorID) - case NDArrayConcat(_, axis) => - assert(newChildren.length == 1) - NDArrayConcat(newChildren(0).asInstanceOf[IR], axis) - case NDArrayRef(_, _, errorId) => - NDArrayRef( - newChildren(0).asInstanceOf[IR], - newChildren.tail.map(_.asInstanceOf[IR]), - errorId, - ) - case NDArraySlice(_, _) => - assert(newChildren.length == 2) - NDArraySlice(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR]) - case NDArrayFilter(_, _) => - NDArrayFilter(newChildren(0).asInstanceOf[IR], newChildren.tail.map(_.asInstanceOf[IR])) - case NDArrayMap(_, name, _) => - assert(newChildren.length == 2) - NDArrayMap(newChildren(0).asInstanceOf[IR], name, newChildren(1).asInstanceOf[IR]) - case NDArrayMap2(_, _, lName, rName, _, errorID) => - assert(newChildren.length == 3) - NDArrayMap2( - newChildren(0).asInstanceOf[IR], - newChildren(1).asInstanceOf[IR], - lName, - rName, - newChildren(2).asInstanceOf[IR], - errorID, - ) - case NDArrayReindex(_, indexExpr) => - assert(newChildren.length == 1) - NDArrayReindex(newChildren(0).asInstanceOf[IR], indexExpr) - case NDArrayAgg(_, axes) => - assert(newChildren.length == 1) - 
NDArrayAgg(newChildren(0).asInstanceOf[IR], axes) - case NDArrayMatMul(_, _, errorID) => - assert(newChildren.length == 2) - NDArrayMatMul(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR], errorID) - case NDArrayQR(_, mode, errorID) => - assert(newChildren.length == 1) - NDArrayQR(newChildren(0).asInstanceOf[IR], mode, errorID) - case NDArraySVD(_, fullMatrices, computeUV, errorID) => - assert(newChildren.length == 1) - NDArraySVD(newChildren(0).asInstanceOf[IR], fullMatrices, computeUV, errorID) - case NDArrayEigh(_, eigvalsOnly, errorID) => - assert(newChildren.length == 1) - NDArrayEigh(newChildren(0).asInstanceOf[IR], eigvalsOnly, errorID) - case NDArrayInv(_, errorID) => - assert(newChildren.length == 1) - NDArrayInv(newChildren(0).asInstanceOf[IR], errorID) - case NDArrayWrite(_, _) => - assert(newChildren.length == 2) - NDArrayWrite(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR]) - case ArraySort(_, l, r, _) => - assert(newChildren.length == 2) - ArraySort(newChildren(0).asInstanceOf[IR], l, r, newChildren(1).asInstanceOf[IR]) - case ArrayMaximalIndependentSet(_, tb) => - ArrayMaximalIndependentSet( - newChildren(0).asInstanceOf[IR], - tb.map { case (l, r, _) => (l, r, newChildren(1).asInstanceOf[IR]) }, - ) - case ToSet(_) => - assert(newChildren.length == 1) - ToSet(newChildren(0).asInstanceOf[IR]) - case ToDict(_) => - assert(newChildren.length == 1) - ToDict(newChildren(0).asInstanceOf[IR]) - case ToArray(_) => - assert(newChildren.length == 1) - ToArray(newChildren(0).asInstanceOf[IR]) - case CastToArray(_) => - assert(newChildren.length == 1) - CastToArray(newChildren(0).asInstanceOf[IR]) - case ToStream(_, requiresMemoryManagementPerElement) => - assert(newChildren.length == 1) - ToStream(newChildren(0).asInstanceOf[IR], requiresMemoryManagementPerElement) - case LowerBoundOnOrderedCollection(_, _, asKey) => - assert(newChildren.length == 2) - LowerBoundOnOrderedCollection( - newChildren(0).asInstanceOf[IR], - newChildren(1).asInstanceOf[IR], - asKey, - ) - case GroupByKey(_) => - assert(newChildren.length == 1) - GroupByKey(newChildren(0).asInstanceOf[IR]) - case RNGStateLiteral() => RNGStateLiteral() - case RNGSplit(_, _) => - assert(newChildren.nonEmpty) - RNGSplit(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR]) - case StreamLen(_) => - StreamLen(newChildren(0).asInstanceOf[IR]) - case StreamTake(_, _) => - assert(newChildren.length == 2) - StreamTake(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR]) - case StreamDrop(_, _) => - assert(newChildren.length == 2) - StreamDrop(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR]) - case StreamGrouped(_, _) => - assert(newChildren.length == 2) - StreamGrouped(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR]) - case StreamGroupByKey(_, key, missingEqual) => - assert(newChildren.length == 1) - StreamGroupByKey(newChildren(0).asInstanceOf[IR], key, missingEqual) - case StreamMap(_, name, _) => - assert(newChildren.length == 2) - StreamMap(newChildren(0).asInstanceOf[IR], name, newChildren(1).asInstanceOf[IR]) - case StreamZip(_, names, _, behavior, errorID) => - assert(newChildren.length == names.length + 1) - StreamZip( - newChildren.init.asInstanceOf[IndexedSeq[IR]], - names, - newChildren(names.length).asInstanceOf[IR], - behavior, - errorID, - ) - case StreamZipJoin(as, key, curKey, curVals, _) => - assert(newChildren.length == as.length + 1) - StreamZipJoin( - newChildren.init.asInstanceOf[IndexedSeq[IR]], - key, - curKey, - 
curVals, - newChildren(as.length).asInstanceOf[IR], - ) - case StreamZipJoinProducers(_, ctxName, _, key, curKey, curVals, _) => - assert(newChildren.length == 3) - StreamZipJoinProducers( - newChildren(0).asInstanceOf[IR], - ctxName, - newChildren(1).asInstanceOf[IR], - key, - curKey, - curVals, - newChildren(2).asInstanceOf[IR], - ) - case StreamMultiMerge(as, key) => - assert(newChildren.length == as.length) - StreamMultiMerge(newChildren.asInstanceOf[IndexedSeq[IR]], key) - case StreamFilter(_, name, _) => - assert(newChildren.length == 2) - StreamFilter(newChildren(0).asInstanceOf[IR], name, newChildren(1).asInstanceOf[IR]) - case StreamTakeWhile(_, name, _) => - assert(newChildren.length == 2) - StreamTakeWhile(newChildren(0).asInstanceOf[IR], name, newChildren(1).asInstanceOf[IR]) - case StreamDropWhile(_, name, _) => - assert(newChildren.length == 2) - StreamDropWhile(newChildren(0).asInstanceOf[IR], name, newChildren(1).asInstanceOf[IR]) - case StreamFlatMap(_, name, _) => - assert(newChildren.length == 2) - StreamFlatMap(newChildren(0).asInstanceOf[IR], name, newChildren(1).asInstanceOf[IR]) - case StreamFold(_, _, accumName, valueName, _) => - assert(newChildren.length == 3) - StreamFold( - newChildren(0).asInstanceOf[IR], - newChildren(1).asInstanceOf[IR], - accumName, - valueName, - newChildren(2).asInstanceOf[IR], - ) - case StreamFold2(_, accum, valueName, seq, _) => - val ncIR = newChildren.map(_.asInstanceOf[IR]) - assert(newChildren.length == 2 + accum.length + seq.length) - StreamFold2( - ncIR(0), - accum.indices.map(i => (accum(i)._1, ncIR(i + 1))), - valueName, - seq.indices.map(i => ncIR(i + 1 + accum.length)), - ncIR.last, - ) - case StreamScan(_, _, accumName, valueName, _) => - assert(newChildren.length == 3) - StreamScan( - newChildren(0).asInstanceOf[IR], - newChildren(1).asInstanceOf[IR], - accumName, - valueName, - newChildren(2).asInstanceOf[IR], - ) - case StreamLeftIntervalJoin(_, _, lKeyNames, rIntrvlName, lname, rname, _) => - assert(newChildren.length == 3) - StreamLeftIntervalJoin( - newChildren(0).asInstanceOf[IR], - newChildren(1).asInstanceOf[IR], - lKeyNames, - rIntrvlName, - lname, - rname, - newChildren(2).asInstanceOf[IR], - ) - case StreamJoinRightDistinct(_, _, lKey, rKey, l, r, _, joinType) => - assert(newChildren.length == 3) - StreamJoinRightDistinct( - newChildren(0).asInstanceOf[IR], - newChildren(1).asInstanceOf[IR], - lKey, - rKey, - l, - r, - newChildren(2).asInstanceOf[IR], - joinType, - ) - case _: StreamLocalLDPrune => - val IndexedSeq(child: IR, r2Threshold: IR, windowSize: IR, maxQueueSize: IR, nSamples: IR) = - newChildren - StreamLocalLDPrune(child, r2Threshold, windowSize, maxQueueSize, nSamples) - case StreamFor(_, valueName, _) => - assert(newChildren.length == 2) - StreamFor(newChildren(0).asInstanceOf[IR], valueName, newChildren(1).asInstanceOf[IR]) - case StreamAgg(_, name, _) => - assert(newChildren.length == 2) - StreamAgg(newChildren(0).asInstanceOf[IR], name, newChildren(1).asInstanceOf[IR]) - case StreamAggScan(_, name, _) => - assert(newChildren.length == 2) - StreamAggScan(newChildren(0).asInstanceOf[IR], name, newChildren(1).asInstanceOf[IR]) - case RunAgg(_, _, signatures) => - RunAgg(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR], signatures) - case RunAggScan(_, name, _, _, _, signatures) => - RunAggScan( - newChildren(0).asInstanceOf[IR], - name, - newChildren(1).asInstanceOf[IR], - newChildren(2).asInstanceOf[IR], - newChildren(3).asInstanceOf[IR], - signatures, - ) - case 
StreamBufferedAggregate(_, _, _, _, name, aggSignatures, bufferSize) => - StreamBufferedAggregate( - newChildren(0).asInstanceOf[IR], - newChildren(1).asInstanceOf[IR], - newChildren(2).asInstanceOf[IR], - newChildren(3).asInstanceOf[IR], - name, - aggSignatures, - bufferSize, - ) - case AggFilter(_, _, isScan) => - assert(newChildren.length == 2) - AggFilter(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR], isScan) - case AggExplode(_, name, _, isScan) => - assert(newChildren.length == 2) - AggExplode(newChildren(0).asInstanceOf[IR], name, newChildren(1).asInstanceOf[IR], isScan) - case AggGroupBy(_, _, isScan) => - assert(newChildren.length == 2) - AggGroupBy(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR], isScan) - case AggArrayPerElement(_, elementName, indexName, _, _, isScan) => - val newKnownLength = if (newChildren.length == 3) - Some(newChildren(2).asInstanceOf[IR]) - else { - assert(newChildren.length == 2) - None - } - AggArrayPerElement( - newChildren(0).asInstanceOf[IR], - elementName, - indexName, - newChildren(1).asInstanceOf[IR], - newKnownLength, - isScan, - ) - case MakeStruct(fields) => - assert(fields.length == newChildren.length) - MakeStruct(fields.zip(newChildren).map { case ((n, _), a) => (n, a.asInstanceOf[IR]) }) - case SelectFields(_, fields) => - assert(newChildren.length == 1) - SelectFields(newChildren(0).asInstanceOf[IR], fields) - case InsertFields(_, fields, fieldOrder) => - assert(newChildren.length == fields.length + 1) - InsertFields( - newChildren.head.asInstanceOf[IR], - fields.zip(newChildren.tail).map { case ((n, _), a) => (n, a.asInstanceOf[IR]) }, - fieldOrder, - ) - case GetField(_, name) => - assert(newChildren.length == 1) - GetField(newChildren(0).asInstanceOf[IR], name) - case InitOp(i, _, aggSig) => - InitOp(i, newChildren.map(_.asInstanceOf[IR]), aggSig) - case SeqOp(i, _, aggSig) => - SeqOp(i, newChildren.map(_.asInstanceOf[IR]), aggSig) - case ResultOp(i, aggSigs) => - ResultOp(i, aggSigs) - case CombOp(i1, i2, aggSig) => - CombOp(i1, i2, aggSig) - case AggStateValue(i, aggSig) => - AggStateValue(i, aggSig) - case CombOpValue(i, _, aggSig) => - assert(newChildren.length == 1) - CombOpValue(i, newChildren(0).asInstanceOf[IR], aggSig) - case InitFromSerializedValue(i, _, aggSig) => - assert(newChildren.length == 1) - InitFromSerializedValue(i, newChildren(0).asInstanceOf[IR], aggSig) - case SerializeAggs(startIdx, serIdx, spec, aggSigs) => - SerializeAggs(startIdx, serIdx, spec, aggSigs) - case DeserializeAggs(startIdx, serIdx, spec, aggSigs) => - DeserializeAggs(startIdx, serIdx, spec, aggSigs) - case x @ ApplyAggOp(_, _, aggSig) => - val args = newChildren.map(_.asInstanceOf[IR]) - assert(args.length == x.nInitArgs + x.nSeqOpArgs) - ApplyAggOp( - args.take(x.nInitArgs), - args.drop(x.nInitArgs), - aggSig, - ) - case x @ ApplyScanOp(_, _, aggSig) => - val args = newChildren.map(_.asInstanceOf[IR]) - assert(args.length == x.nInitArgs + x.nSeqOpArgs) - ApplyScanOp( - args.take(x.nInitArgs), - args.drop(x.nInitArgs), - aggSig, - ) - case AggFold(_, _, _, accumName, otherAccumName, isScan) => - AggFold( - newChildren(0).asInstanceOf[IR], - newChildren(1).asInstanceOf[IR], - newChildren(2).asInstanceOf[IR], - accumName, - otherAccumName, - isScan, - ) - case MakeTuple(fields) => - assert(fields.length == newChildren.length) - MakeTuple(fields.zip(newChildren).map { case ((i, _), newValue) => - (i, newValue.asInstanceOf[IR]) - }) - case GetTupleElement(_, idx) => - assert(newChildren.length == 1) - 
GetTupleElement(newChildren(0).asInstanceOf[IR], idx) - case In(i, t) => In(i, t) - case Die(_, typ, errorId) => - assert(newChildren.length == 1) - Die(newChildren(0).asInstanceOf[IR], typ, errorId) - case Trap(_) => - assert(newChildren.length == 1) - Trap(newChildren(0).asInstanceOf[IR]) - case ConsoleLog(_, _) => - assert(newChildren.length == 2) - ConsoleLog(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR]) - case x @ ApplyIR(fn, typeArgs, _, rt, errorID) => - val r = ApplyIR(fn, typeArgs, newChildren.map(_.asInstanceOf[IR]), rt, errorID) - r.conversion = x.conversion - r.inline = x.inline - r - case Apply(fn, typeArgs, _, t, errorID) => - Apply(fn, typeArgs, newChildren.map(_.asInstanceOf[IR]), t, errorID) - case ApplySeeded(fn, _, _, staticUID, t) => - ApplySeeded( - fn, - newChildren.init.map(_.asInstanceOf[IR]), - newChildren.last.asInstanceOf[IR], - staticUID, - t, - ) - case ApplySpecial(fn, typeArgs, _, t, errorID) => - ApplySpecial(fn, typeArgs, newChildren.map(_.asInstanceOf[IR]), t, errorID) - // from MatrixIR - case MatrixWrite(_, writer) => - assert(newChildren.length == 1) - MatrixWrite(newChildren(0).asInstanceOf[MatrixIR], writer) - case MatrixMultiWrite(_, writer) => - MatrixMultiWrite(newChildren.map(_.asInstanceOf[MatrixIR]), writer) - case MatrixCount(_) => - assert(newChildren.length == 1) - MatrixCount(newChildren(0).asInstanceOf[MatrixIR]) - // from TableIR - case TableCount(_) => - assert(newChildren.length == 1) - TableCount(newChildren(0).asInstanceOf[TableIR]) - case TableGetGlobals(_) => - assert(newChildren.length == 1) - TableGetGlobals(newChildren(0).asInstanceOf[TableIR]) - case TableCollect(_) => - assert(newChildren.length == 1) - TableCollect(newChildren(0).asInstanceOf[TableIR]) - case TableAggregate(_, _) => - assert(newChildren.length == 2) - TableAggregate(newChildren(0).asInstanceOf[TableIR], newChildren(1).asInstanceOf[IR]) - case MatrixAggregate(_, _) => - assert(newChildren.length == 2) - MatrixAggregate(newChildren(0).asInstanceOf[MatrixIR], newChildren(1).asInstanceOf[IR]) - case TableWrite(_, writer) => - assert(newChildren.length == 1) - TableWrite(newChildren(0).asInstanceOf[TableIR], writer) - case TableMultiWrite(_, writer) => - TableMultiWrite(newChildren.map(_.asInstanceOf[TableIR]), writer) - case TableToValueApply(_, function) => - assert(newChildren.length == 1) - TableToValueApply(newChildren(0).asInstanceOf[TableIR], function) - case MatrixToValueApply(_, function) => - assert(newChildren.length == 1) - MatrixToValueApply(newChildren(0).asInstanceOf[MatrixIR], function) - case BlockMatrixToValueApply(_, function) => - assert(newChildren.length == 1) - BlockMatrixToValueApply(newChildren(0).asInstanceOf[BlockMatrixIR], function) - case BlockMatrixCollect(_) => - assert(newChildren.length == 1) - BlockMatrixCollect(newChildren(0).asInstanceOf[BlockMatrixIR]) - case BlockMatrixWrite(_, writer) => - assert(newChildren.length == 1) - BlockMatrixWrite(newChildren(0).asInstanceOf[BlockMatrixIR], writer) - case BlockMatrixMultiWrite(_, writer) => - BlockMatrixMultiWrite(newChildren.map(_.asInstanceOf[BlockMatrixIR]), writer) - case CollectDistributedArray(_, _, cname, gname, _, _, id, tsd) => - assert(newChildren.length == 4) - CollectDistributedArray( - newChildren(0).asInstanceOf[IR], - newChildren(1).asInstanceOf[IR], - cname, - gname, - newChildren(2).asInstanceOf[IR], - newChildren(3).asInstanceOf[IR], - id, - tsd, - ) - case ReadPartition(_, rowType, reader) => - assert(newChildren.length == 1) - 
ReadPartition(newChildren(0).asInstanceOf[IR], rowType, reader) - case WritePartition(_, _, writer) => - assert(newChildren.length == 2) - WritePartition(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR], writer) - case WriteMetadata(_, writer) => - assert(newChildren.length == 1) - WriteMetadata(newChildren(0).asInstanceOf[IR], writer) - case ReadValue(_, writer, requestedType) => - assert(newChildren.length == 1) - ReadValue(newChildren(0).asInstanceOf[IR], writer, requestedType) - case WriteValue(_, _, writer, _) => - assert(newChildren.length == 2 || newChildren.length == 3) - val value = newChildren(0).asInstanceOf[IR] - val path = newChildren(1).asInstanceOf[IR] - val stage = if (newChildren.length == 3) Some(newChildren(2).asInstanceOf[IR]) else None - WriteValue(value, path, writer, stage) - case LiftMeOut(_) => - LiftMeOut(newChildren(0).asInstanceOf[IR]) - } - } -} diff --git a/hail/hail/src/is/hail/expr/ir/Emit.scala b/hail/hail/src/is/hail/expr/ir/Emit.scala index be2e79a091b..ad8e9761456 100644 --- a/hail/hail/src/is/hail/expr/ir/Emit.scala +++ b/hail/hail/src/is/hail/expr/ir/Emit.scala @@ -1903,7 +1903,8 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) { newStrides, dataPtr, cb, - region) + region, + ) } case NDArrayRef(nd, idxs, errorId) => @@ -1946,7 +1947,8 @@ class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) { ) if ( - (lSType.elementType.virtualType == TFloat64 || lSType.elementType.virtualType == TFloat32) && lSType.nDims == 2 && rSType.nDims == 2 + (lSType.elementType.virtualType == TFloat64 || lSType.elementType + .virtualType == TFloat32) && lSType.nDims == 2 && rSType.nDims == 2 ) { val leftDataAddress = leftPVal.firstDataAddress val rightDataAddress = rightPVal.firstDataAddress diff --git a/hail/hail/src/is/hail/expr/ir/GenericTableValue.scala b/hail/hail/src/is/hail/expr/ir/GenericTableValue.scala index 325b7d8deae..f248de463c4 100644 --- a/hail/hail/src/is/hail/expr/ir/GenericTableValue.scala +++ b/hail/hail/src/is/hail/expr/ir/GenericTableValue.scala @@ -222,7 +222,8 @@ class GenericTableValue( ctx, globalsIR, contextType, contexts, - requestedBody) + requestedBody, + ) } } } diff --git a/hail/hail/src/is/hail/expr/ir/IR.scala b/hail/hail/src/is/hail/expr/ir/IR.scala index 32dcf26f3f6..460f5219ccb 100644 --- a/hail/hail/src/is/hail/expr/ir/IR.scala +++ b/hail/hail/src/is/hail/expr/ir/IR.scala @@ -1,12 +1,11 @@ package is.hail.expr.ir -import is.hail.annotations.{Annotation, Region} +import is.hail.annotations.Region import is.hail.asm4s.Value import is.hail.backend.ExecuteContext -import is.hail.expr.ir.agg.{AggStateSig, PhysicalAggSig} +import is.hail.expr.ir.agg.PhysicalAggSig import is.hail.expr.ir.defs.ApplyIR import is.hail.expr.ir.functions._ -import is.hail.expr.ir.lowering.TableStageDependency import is.hail.expr.ir.streams.StreamProducer import is.hail.io.{AbstractTypedCodecSpec, BufferSpec, TypedCodecSpec} import is.hail.io.avro.{AvroPartitionReader, AvroSchemaSerializer} @@ -27,25 +26,21 @@ import java.io.OutputStream import org.json4s.{DefaultFormats, Extraction, Formats, JValue, ShortTypeHints} import org.json4s.JsonAST.{JNothing, JString} -sealed trait IR extends BaseIR { +trait IR extends BaseIR { private var _typ: Type = null - def typ: Type = { - if (_typ == null) + override def typ: Type = { + if (_typ == null) { try _typ = InferType(this) catch { case e: Throwable => throw new RuntimeException(s"typ: inference failure:", e) } + assert(_typ != null) + } _typ } - protected lazy val childrenSeq: 
IndexedSeq[BaseIR] = - Children(this) - - override protected def copyWithNewChildren(newChildren: IndexedSeq[BaseIR]): IR = - Copy(this, newChildren) - override def mapChildren(f: BaseIR => BaseIR): IR = super.mapChildren(f).asInstanceOf[IR] override def mapChildrenWithIndex(f: (BaseIR, Int) => BaseIR): IR = @@ -74,65 +69,12 @@ sealed trait IR extends BaseIR { package defs { - import is.hail.expr.ir.defs.ArrayZipBehavior.ArrayZipBehavior - - sealed trait TypedIR[T <: Type] extends IR { + trait TypedIR[T <: Type] extends IR { override def typ: T = tcoerce[T](super.typ) } // Mark Refs and constants as IRs that are safe to duplicate - sealed trait TrivialIR extends IR - - object Literal { - def coerce(t: Type, x: Any): IR = { - if (x == null) - return NA(t) - t match { - case TInt32 => I32(x.asInstanceOf[Number].intValue()) - case TInt64 => I64(x.asInstanceOf[Number].longValue()) - case TFloat32 => F32(x.asInstanceOf[Number].floatValue()) - case TFloat64 => F64(x.asInstanceOf[Number].doubleValue()) - case TBoolean => if (x.asInstanceOf[Boolean]) True() else False() - case TString => Str(x.asInstanceOf[String]) - case _ => Literal(t, x) - } - } - } - - final case class Literal(_typ: Type, value: Annotation) extends IR { - require(!CanEmit(_typ)) - require(value != null) - // expensive, for debugging - // require(SafeRow.isSafe(value)) - // assert(_typ.typeCheck(value), s"literal invalid:\n ${_typ}\n $value") - } - - object EncodedLiteral { - def apply(codec: AbstractTypedCodecSpec, value: Array[Array[Byte]]): EncodedLiteral = - EncodedLiteral(codec, new WrappedByteArrays(value)) - - def fromPTypeAndAddress(pt: PType, addr: Long, ctx: ExecuteContext): IR = { - pt match { - case _: PInt32 => I32(Region.loadInt(addr)) - case _: PInt64 => I64(Region.loadLong(addr)) - case _: PFloat32 => F32(Region.loadFloat(addr)) - case _: PFloat64 => F64(Region.loadDouble(addr)) - case _: PBoolean => if (Region.loadBoolean(addr)) True() else False() - case ts: PString => Str(ts.loadString(addr)) - case _ => - val etype = EType.defaultFromPType(ctx, pt) - val codec = TypedCodecSpec(etype, pt.virtualType, BufferSpec.wireSpec) - val bytes = codec.encodeArrays(ctx, pt, addr) - EncodedLiteral(codec, bytes) - } - } - } - - final case class EncodedLiteral(codec: AbstractTypedCodecSpec, value: WrappedByteArrays) - extends IR { - require(!CanEmit(codec.encodedVirtualType)) - require(value != null) - } + trait TrivialIR extends IR class WrappedByteArrays(val ba: Array[Array[Byte]]) { override def hashCode(): Int = @@ -150,49 +92,6 @@ package defs { } } - final case class I32(x: Int) extends IR with TrivialIR - final case class I64(x: Long) extends IR with TrivialIR - final case class F32(x: Float) extends IR with TrivialIR - final case class F64(x: Double) extends IR with TrivialIR - - final case class Str(x: String) extends IR with TrivialIR { - override def toString(): String = s"""Str("${StringEscapeUtils.escapeString(x)}")""" - } - - final case class True() extends IR with TrivialIR - final case class False() extends IR with TrivialIR - final case class Void() extends IR with TrivialIR - - object UUID4 { - def apply(): UUID4 = UUID4(genUID()) - } - -// WARNING! This node can only be used when trying to append a one-off, -// random string that will not be reused elsewhere in the pipeline. -// Any other uses will need to write and then read again; this node is -// non-deterministic and will not e.g. exhibit the correct semantics when -// self-joining on streams. 
- final case class UUID4(id: String) extends IR - - final case class Cast(v: IR, _typ: Type) extends IR - final case class CastRename(v: IR, _typ: Type) extends IR - - final case class NA(_typ: Type) extends IR with TrivialIR - final case class IsNA(value: IR) extends IR - - final case class Coalesce(values: Seq[IR]) extends IR { - require(values.nonEmpty) - } - - final case class Consume(value: IR) extends IR - - final case class If(cond: IR, cnsq: IR, altr: IR) extends IR - - final case class Switch(x: IR, default: IR, cases: IndexedSeq[IR]) extends IR { - override lazy val size: Int = - 2 + cases.length - } - object AggLet { def apply(name: Name, value: IR, body: IR, isScan: Boolean): IR = { val scope = if (isScan) Scope.SCAN else Scope.AGG @@ -215,284 +114,23 @@ package defs { Let(bindings.init, bindings.last._2) } } - } - case class Binding(name: Name, value: IR, scope: Int = Scope.EVAL) - - final case class Block(bindings: IndexedSeq[Binding], body: IR) extends IR { - override lazy val size: Int = - bindings.length + 1 + object Begin { + def apply(xs: IndexedSeq[IR]): IR = + if (xs.isEmpty) + Void() + else + Let(xs.init.map(x => (freshName(), x)), xs.last) } - object Block { - object Insert { - def unapply(bindings: IndexedSeq[Binding]) - : Option[(IndexedSeq[Binding], Binding, IndexedSeq[Binding])] = { - val idx = bindings.indexWhere(_.value.isInstanceOf[InsertFields]) - if (idx == -1) None else Some((bindings.take(idx), bindings(idx), bindings.drop(idx + 1))) - } - } - - object Nested { - def unapply(bindings: IndexedSeq[Binding]): Option[(Int, IndexedSeq[Binding])] = { - val idx = bindings.indexWhere(_.value.isInstanceOf[Block]) - if (idx == -1) None else Some((idx, bindings)) - } - } - } + case class Binding(name: Name, value: IR, scope: Int = Scope.EVAL) - sealed abstract class BaseRef extends IR with TrivialIR { + trait BaseRef extends IR with TrivialIR { def name: Name def _typ: Type } - final case class Ref(name: Name, var _typ: Type) extends BaseRef { - override def typ: Type = { - assert(_typ != null) - _typ - } - } - -// Recur can't exist outside of loop -// Loops can be nested, but we can't call outer loops in terms of inner loops so there can only be one loop "active" in a given context - final case class TailLoop( - name: Name, - params: IndexedSeq[(Name, IR)], - resultType: Type, - body: IR, - ) extends IR { - lazy val paramIdx: Map[Name, Int] = params.map(_._1).zipWithIndex.toMap - } - - final case class Recur(name: Name, args: IndexedSeq[IR], var _typ: Type) extends BaseRef - - final case class RelationalLet(name: Name, value: IR, body: IR) extends IR - final case class RelationalRef(name: Name, _typ: Type) extends BaseRef - - final case class ApplyBinaryPrimOp(op: BinaryOp, l: IR, r: IR) extends IR - final case class ApplyUnaryPrimOp(op: UnaryOp, x: IR) extends IR - final case class ApplyComparisonOp(var op: ComparisonOp[_], l: IR, r: IR) extends IR - - object MakeArray { - def apply(args: IR*): MakeArray = { - assert(args.nonEmpty) - MakeArray(args.toArray, TArray(args.head.typ)) - } - - def unify(ctx: ExecuteContext, args: IndexedSeq[IR], requestedType: TArray = null) - : MakeArray = { - assert(requestedType != null || args.nonEmpty) - - if (args.nonEmpty) - if (args.forall(_.typ == args.head.typ)) - return MakeArray(args, TArray(args.head.typ)) - - MakeArray( - args.map { arg => - val upcast = PruneDeadFields.upcast(ctx, arg, requestedType.elementType) - assert(upcast.typ == requestedType.elementType) - upcast - }, - requestedType, - ) - } - } - - final case 
class MakeArray(args: IndexedSeq[IR], _typ: TArray) extends IR - - object MakeStream { - def unify( - ctx: ExecuteContext, - args: IndexedSeq[IR], - requiresMemoryManagementPerElement: Boolean = false, - requestedType: TStream = null, - ): MakeStream = { - assert(requestedType != null || args.nonEmpty) - - if (args.nonEmpty) - if (args.forall(_.typ == args.head.typ)) - return MakeStream(args, TStream(args.head.typ), requiresMemoryManagementPerElement) - - MakeStream( - args.map { arg => - val upcast = PruneDeadFields.upcast(ctx, arg, requestedType.elementType) - assert(upcast.typ == requestedType.elementType) - upcast - }, - requestedType, - requiresMemoryManagementPerElement, - ) - } - } - - final case class MakeStream( - args: IndexedSeq[IR], - _typ: TStream, - requiresMemoryManagementPerElement: Boolean = false, - ) extends IR - - object ArrayRef { - def apply(a: IR, i: IR): ArrayRef = ArrayRef(a, i, ErrorIDs.NO_ERROR) - } - - final case class ArrayRef(a: IR, i: IR, errorID: Int) extends IR - - final case class ArraySlice( - a: IR, - start: IR, - stop: Option[IR], - step: IR = I32(1), - errorID: Int = ErrorIDs.NO_ERROR, - ) extends IR - - final case class ArrayLen(a: IR) extends IR - - final case class ArrayZeros(length: IR) extends IR - - final case class ArrayMaximalIndependentSet(edges: IR, tieBreaker: Option[(Name, Name, IR)]) - extends IR - - /** [[StreamIota]] is an infinite stream producer, whose element is an integer starting at - * `start`, updated by `step` at each iteration. The name comes from APL: - * [[https://stackoverflow.com/questions/9244879/what-does-iota-of-stdiota-stand-for]] - */ - final case class StreamIota( - start: IR, - step: IR, - requiresMemoryManagementPerElement: Boolean = false, - ) extends IR - - final case class StreamRange( - start: IR, - stop: IR, - step: IR, - requiresMemoryManagementPerElement: Boolean = false, - errorID: Int = ErrorIDs.NO_ERROR, - ) extends IR - - object ArraySort { - def apply(a: IR, ascending: IR = True(), onKey: Boolean = false): ArraySort = { - val l = freshName() - val r = freshName() - val atyp = tcoerce[TStream](a.typ) - val compare = if (onKey) { - val elementType = atyp.elementType.asInstanceOf[TBaseStruct] - elementType match { - case _: TStruct => - val elt = tcoerce[TStruct](atyp.elementType) - ApplyComparisonOp( - Compare(elt.types(0)), - GetField(Ref(l, elt), elt.fieldNames(0)), - GetField(Ref(r, atyp.elementType), elt.fieldNames(0)), - ) - case _: TTuple => - val elt = tcoerce[TTuple](atyp.elementType) - ApplyComparisonOp( - Compare(elt.types(0)), - GetTupleElement(Ref(l, elt), elt.fields(0).index), - GetTupleElement(Ref(r, atyp.elementType), elt.fields(0).index), - ) - } - } else { - ApplyComparisonOp( - Compare(atyp.elementType), - Ref(l, atyp.elementType), - Ref(r, atyp.elementType), - ) - } - - ArraySort(a, l, r, If(ascending, compare < 0, compare > 0)) - } - } - - final case class ArraySort(a: IR, left: Name, right: Name, lessThan: IR) extends IR - - final case class ToSet(a: IR) extends IR - - final case class ToDict(a: IR) extends IR - - final case class ToArray(a: IR) extends IR - - final case class CastToArray(a: IR) extends IR - - final case class ToStream(a: IR, requiresMemoryManagementPerElement: Boolean = false) extends IR - - final case class StreamBufferedAggregate( - streamChild: IR, - initAggs: IR, - newKey: IR, - seqOps: IR, - name: Name, - aggSignatures: IndexedSeq[PhysicalAggSig], - bufferSize: Int, - ) extends IR - - final case class LowerBoundOnOrderedCollection(orderedCollection: IR, elem: IR, 
onKey: Boolean) - extends IR - - final case class GroupByKey(collection: IR) extends IR - - final case class RNGStateLiteral() extends IR - - final case class RNGSplit(state: IR, dynBitstring: IR) extends IR - - final case class StreamLen(a: IR) extends IR - - final case class StreamGrouped(a: IR, groupSize: IR) extends IR - - final case class StreamGroupByKey(a: IR, key: IndexedSeq[String], missingEqual: Boolean) - extends IR - - final case class StreamMap(a: IR, name: Name, body: IR) extends TypedIR[TStream] { - def elementTyp: Type = typ.elementType - } - - final case class StreamTakeWhile(a: IR, elementName: Name, body: IR) extends IR - - final case class StreamDropWhile(a: IR, elementName: Name, body: IR) extends IR - - final case class StreamTake(a: IR, num: IR) extends IR - - final case class StreamDrop(a: IR, num: IR) extends IR - - /* Generate, in ascending order, a uniform random sample, without replacement, of numToSample - * integers in the range [0, totalRange) */ - final case class SeqSample( - totalRange: IR, - numToSample: IR, - rngState: IR, - requiresMemoryManagementPerElement: Boolean, - ) extends IR - - /* Take the child stream and sort each element into buckets based on the provided pivots. The - * first and last elements of pivots are the endpoints of the first and last interval - * respectively, should not be contained in the dataset. */ - final case class StreamDistribute( - child: IR, - pivots: IR, - path: IR, - comparisonOp: ComparisonOp[_], - spec: AbstractTypedCodecSpec, - ) extends IR - - // "Whiten" a stream of vectors by regressing out from each vector all components - // in the direction of vectors in the preceding window. For efficiency, takes - // a stream of "chunks" of vectors. - // Takes a stream of structs, with two designated fields: `prevWindow` is the - // previous window (e.g. from the previous partition), if there is one, and - // `newChunk` is the new chunk to whiten. - final case class StreamWhiten( - stream: IR, - newChunk: String, - prevWindow: String, - vecSize: Int, - windowSize: Int, - chunkSize: Int, - blockSize: Int, - normalizeAfterWhiten: Boolean, - ) extends IR - object ArrayZipBehavior extends Enumeration { type ArrayZipBehavior = Value val AssumeSameLength: Value = Value(0) @@ -501,77 +139,6 @@ package defs { val ExtendNA: Value = Value(3) } - final case class StreamZip( - as: IndexedSeq[IR], - names: IndexedSeq[Name], - body: IR, - behavior: ArrayZipBehavior, - errorID: Int = ErrorIDs.NO_ERROR, - ) extends TypedIR[TStream] - - final case class StreamMultiMerge(as: IndexedSeq[IR], key: IndexedSeq[String]) - extends TypedIR[TStream] - - final case class StreamZipJoinProducers( - contexts: IR, - ctxName: Name, - makeProducer: IR, - key: IndexedSeq[String], - curKey: Name, - curVals: Name, - joinF: IR, - ) extends TypedIR[TStream] - - /** The StreamZipJoin node assumes that input streams have distinct keys. If input streams do not - * have distinct keys, the key that is included in the result is undefined, but is likely the - * last. 
- */ - final case class StreamZipJoin( - as: IndexedSeq[IR], - key: IndexedSeq[String], - curKey: Name, - curVals: Name, - joinF: IR, - ) extends TypedIR[TStream] - - final case class StreamFilter(a: IR, name: Name, cond: IR) extends TypedIR[TStream] - - final case class StreamFlatMap(a: IR, name: Name, body: IR) extends TypedIR[TStream] - - final case class StreamFold(a: IR, zero: IR, accumName: Name, valueName: Name, body: IR) - extends IR - - object StreamFold2 { - def apply(a: StreamFold): StreamFold2 = - StreamFold2( - a.a, - FastSeq((a.accumName, a.zero)), - a.valueName, - FastSeq(a.body), - Ref(a.accumName, a.zero.typ), - ) - } - - final case class StreamFold2( - a: IR, - accum: IndexedSeq[(Name, IR)], - valueName: Name, - seq: IndexedSeq[IR], - result: IR, - ) extends IR { - assert(accum.length == seq.length) - val nameIdx: Map[Name, Int] = accum.map(_._1).zipWithIndex.toMap - } - - final case class StreamScan(a: IR, zero: IR, accumName: Name, valueName: Name, body: IR) - extends IR - - final case class StreamFor(a: IR, valueName: Name, body: IR) extends IR - - final case class StreamAgg(a: IR, name: Name, query: IR) extends IR - - final case class StreamAggScan(a: IR, name: Name, query: IR) extends IR - object StreamJoin { def apply( left: IR, @@ -647,355 +214,10 @@ package defs { } } - final case class StreamLeftIntervalJoin( - // input streams - left: IR, - right: IR, - - // names for joiner - lKeyFieldName: String, - rIntervalFieldName: String, - - // how to combine records - lname: Name, - rname: Name, - body: IR, - ) extends IR { - override protected lazy val childrenSeq: IndexedSeq[BaseIR] = - FastSeq(left, right, body) - } - - final case class StreamJoinRightDistinct( - left: IR, - right: IR, - lKey: IndexedSeq[String], - rKey: IndexedSeq[String], - l: Name, - r: Name, - joinF: IR, - joinType: String, - ) extends IR { - def isIntervalJoin: Boolean = { - if (rKey.size != 1) return false - val lKeyTyp = tcoerce[TStruct](tcoerce[TStream](left.typ).elementType).fieldType(lKey(0)) - val rKeyTyp = tcoerce[TStruct](tcoerce[TStream](right.typ).elementType).fieldType(rKey(0)) - - rKeyTyp.isInstanceOf[TInterval] && lKeyTyp != rKeyTyp - } - } - - final case class StreamLocalLDPrune( - child: IR, - r2Threshold: IR, - windowSize: IR, - maxQueueSize: IR, - nSamples: IR, - ) extends IR - - sealed trait NDArrayIR extends TypedIR[TNDArray] { + trait NDArrayIR extends TypedIR[TNDArray] { def elementTyp: Type = typ.elementType } - object MakeNDArray { - def fill(elt: IR, shape: IndexedSeq[IR], rowMajor: IR): MakeNDArray = { - val flatSize: IR = if (shape.nonEmpty) - shape.reduce((l, r) => l * r) - else - 0L - MakeNDArray( - ToArray(mapIR(rangeIR(flatSize.toI))(_ => elt)), - MakeTuple.ordered(shape), - rowMajor, - ErrorIDs.NO_ERROR, - ) - } - } - - final case class MakeNDArray(data: IR, shape: IR, rowMajor: IR, errorId: Int) extends NDArrayIR - - final case class NDArrayShape(nd: IR) extends IR - - final case class NDArrayReshape(nd: IR, shape: IR, errorID: Int) extends NDArrayIR - - final case class NDArrayConcat(nds: IR, axis: Int) extends NDArrayIR - - final case class NDArrayRef(nd: IR, idxs: IndexedSeq[IR], errorId: Int) extends IR - - final case class NDArraySlice(nd: IR, slices: IR) extends NDArrayIR - - final case class NDArrayFilter(nd: IR, keep: IndexedSeq[IR]) extends NDArrayIR - - final case class NDArrayMap(nd: IR, valueName: Name, body: IR) extends NDArrayIR - - final case class NDArrayMap2(l: IR, r: IR, lName: Name, rName: Name, body: IR, errorID: Int) - extends NDArrayIR - - 
final case class NDArrayReindex(nd: IR, indexExpr: IndexedSeq[Int]) extends NDArrayIR - - final case class NDArrayAgg(nd: IR, axes: IndexedSeq[Int]) extends IR - - final case class NDArrayWrite(nd: IR, path: IR) extends IR - - final case class NDArrayMatMul(l: IR, r: IR, errorID: Int) extends NDArrayIR - - object NDArrayQR { - def pType(mode: String, req: Boolean): PType = { - mode match { - case "r" => PCanonicalNDArray(PFloat64Required, 2, req) - case "raw" => PCanonicalTuple( - req, - PCanonicalNDArray(PFloat64Required, 2, true), - PCanonicalNDArray(PFloat64Required, 1, true), - ) - case "reduced" => PCanonicalTuple( - req, - PCanonicalNDArray(PFloat64Required, 2, true), - PCanonicalNDArray(PFloat64Required, 2, true), - ) - case "complete" => PCanonicalTuple( - req, - PCanonicalNDArray(PFloat64Required, 2, true), - PCanonicalNDArray(PFloat64Required, 2, true), - ) - } - } - } - - object NDArraySVD { - def pTypes(computeUV: Boolean, req: Boolean): PType = { - if (computeUV) { - PCanonicalTuple( - req, - PCanonicalNDArray(PFloat64Required, 2, true), - PCanonicalNDArray(PFloat64Required, 1, true), - PCanonicalNDArray(PFloat64Required, 2, true), - ) - } else { - PCanonicalNDArray(PFloat64Required, 1, req) - } - } - } - - object NDArrayInv { - val pType = PCanonicalNDArray(PFloat64Required, 2) - } - - final case class NDArrayQR(nd: IR, mode: String, errorID: Int) extends IR - - final case class NDArraySVD(nd: IR, fullMatrices: Boolean, computeUV: Boolean, errorID: Int) - extends IR - - object NDArrayEigh { - def pTypes(eigvalsOnly: Boolean, req: Boolean): PType = - if (eigvalsOnly) { - PCanonicalNDArray(PFloat64Required, 1, req) - } else { - PCanonicalTuple( - req, - PCanonicalNDArray(PFloat64Required, 1, true), - PCanonicalNDArray(PFloat64Required, 2, true), - ) - } - } - - final case class NDArrayEigh(nd: IR, eigvalsOnly: Boolean, errorID: Int) extends IR - - final case class NDArrayInv(nd: IR, errorID: Int) extends IR - - final case class AggFilter(cond: IR, aggIR: IR, isScan: Boolean) extends IR - - final case class AggExplode(array: IR, name: Name, aggBody: IR, isScan: Boolean) extends IR - - final case class AggGroupBy(key: IR, aggIR: IR, isScan: Boolean) extends IR - - final case class AggArrayPerElement( - a: IR, - elementName: Name, - indexName: Name, - aggBody: IR, - knownLength: Option[IR], - isScan: Boolean, - ) extends IR - - object ApplyAggOp { - def apply(op: AggOp, initOpArgs: IR*)(seqOpArgs: IR*): ApplyAggOp = - ApplyAggOp( - initOpArgs.toIndexedSeq, - seqOpArgs.toIndexedSeq, - AggSignature(op, initOpArgs.map(_.typ), seqOpArgs.map(_.typ)), - ) - } - - final case class ApplyAggOp( - initOpArgs: IndexedSeq[IR], - seqOpArgs: IndexedSeq[IR], - aggSig: AggSignature, - ) extends IR { - - def nSeqOpArgs = seqOpArgs.length - - def nInitArgs = initOpArgs.length - - def op: AggOp = aggSig.op - } - - object AggFold { - - def min(element: IR, sortFields: IndexedSeq[SortField]): IR = { - val elementType = element.typ.asInstanceOf[TStruct] - val keyType = elementType.select(sortFields.map(_.field))._1 - minAndMaxHelper(element, keyType, StructLT(keyType, sortFields)) - } - - def max(element: IR, sortFields: IndexedSeq[SortField]): IR = { - val elementType = element.typ.asInstanceOf[TStruct] - val keyType = elementType.select(sortFields.map(_.field))._1 - minAndMaxHelper(element, keyType, StructGT(keyType, sortFields)) - } - - def all(element: IR): IR = - aggFoldIR(True()) { accum => - ApplySpecial("land", Seq.empty[Type], Seq(accum, element), TBoolean, ErrorIDs.NO_ERROR) - } { (accum1, 
accum2) => - ApplySpecial("land", Seq.empty[Type], Seq(accum1, accum2), TBoolean, ErrorIDs.NO_ERROR) - } - - private def minAndMaxHelper(element: IR, keyType: TStruct, comp: ComparisonOp[Boolean]): IR = { - val keyFields = keyType.fields.map(_.name) - - val minAndMaxZero = NA(keyType) - val aggFoldMinAccumName1 = freshName() - val aggFoldMinAccumName2 = freshName() - val aggFoldMinAccumRef1 = Ref(aggFoldMinAccumName1, keyType) - val aggFoldMinAccumRef2 = Ref(aggFoldMinAccumName2, keyType) - val minSeq = bindIR(SelectFields(element, keyFields)) { keyOfCurElementRef => - If( - IsNA(aggFoldMinAccumRef1), - keyOfCurElementRef, - If( - ApplyComparisonOp(comp, aggFoldMinAccumRef1, keyOfCurElementRef), - aggFoldMinAccumRef1, - keyOfCurElementRef, - ), - ) - } - val minComb = - If( - IsNA(aggFoldMinAccumRef1), - aggFoldMinAccumRef2, - If( - ApplyComparisonOp(comp, aggFoldMinAccumRef1, aggFoldMinAccumRef2), - aggFoldMinAccumRef1, - aggFoldMinAccumRef2, - ), - ) - - AggFold(minAndMaxZero, minSeq, minComb, aggFoldMinAccumName1, aggFoldMinAccumName2, false) - } - } - - final case class AggFold( - zero: IR, - seqOp: IR, - combOp: IR, - accumName: Name, - otherAccumName: Name, - isScan: Boolean, - ) extends IR - - object ApplyScanOp { - def apply(op: AggOp, initOpArgs: IR*)(seqOpArgs: IR*): ApplyScanOp = - ApplyScanOp( - initOpArgs.toIndexedSeq, - seqOpArgs.toIndexedSeq, - AggSignature(op, initOpArgs.map(_.typ), seqOpArgs.map(_.typ)), - ) - } - - final case class ApplyScanOp( - initOpArgs: IndexedSeq[IR], - seqOpArgs: IndexedSeq[IR], - aggSig: AggSignature, - ) extends IR { - - def nSeqOpArgs = seqOpArgs.length - - def nInitArgs = initOpArgs.length - - def op: AggOp = aggSig.op - } - - final case class InitOp(i: Int, args: IndexedSeq[IR], aggSig: PhysicalAggSig) extends IR - - final case class SeqOp(i: Int, args: IndexedSeq[IR], aggSig: PhysicalAggSig) extends IR - - final case class CombOp(i1: Int, i2: Int, aggSig: PhysicalAggSig) extends IR - - object ResultOp { - def makeTuple(aggs: IndexedSeq[PhysicalAggSig]) = - MakeTuple.ordered(aggs.zipWithIndex.map { case (aggSig, index) => - ResultOp(index, aggSig) - }) - } - - final case class ResultOp(idx: Int, aggSig: PhysicalAggSig) extends IR - - final private[ir] case class CombOpValue(i: Int, value: IR, aggSig: PhysicalAggSig) extends IR - - final case class AggStateValue(i: Int, aggSig: AggStateSig) extends IR - - final case class InitFromSerializedValue(i: Int, value: IR, aggSig: AggStateSig) extends IR - - final case class SerializeAggs( - startIdx: Int, - serializedIdx: Int, - spec: BufferSpec, - aggSigs: IndexedSeq[AggStateSig], - ) extends IR - - final case class DeserializeAggs( - startIdx: Int, - serializedIdx: Int, - spec: BufferSpec, - aggSigs: IndexedSeq[AggStateSig], - ) extends IR - - final case class RunAgg(body: IR, result: IR, signature: IndexedSeq[AggStateSig]) extends IR - - final case class RunAggScan( - array: IR, - name: Name, - init: IR, - seqs: IR, - result: IR, - signature: IndexedSeq[AggStateSig], - ) extends IR - - object Begin { - def apply(xs: IndexedSeq[IR]): IR = - if (xs.isEmpty) - Void() - else - Let(xs.init.map(x => (freshName(), x)), xs.last) - } - - final case class Begin(xs: IndexedSeq[IR]) extends IR - - final case class MakeStruct(fields: IndexedSeq[(String, IR)]) extends IR - - final case class SelectFields(old: IR, fields: IndexedSeq[String]) extends IR - - object InsertFields { - def apply(old: IR, fields: IndexedSeq[(String, IR)]): InsertFields = - InsertFields(old, fields, None) - } - - final case class 
InsertFields( - old: IR, - fields: IndexedSeq[(String, IR)], - fieldOrder: Option[IndexedSeq[String]], - ) extends TypedIR[TStruct] - object GetFieldByIdx { def apply(s: IR, field: Int): IR = (s.typ: @unchecked) match { @@ -1004,80 +226,7 @@ package defs { } } - final case class GetField(o: IR, name: String) extends IR - - object MakeTuple { - def ordered(types: IndexedSeq[IR]): MakeTuple = MakeTuple(types.zipWithIndex.map { - case (ir, i) => - (i, ir) - }) - } - - final case class MakeTuple(fields: IndexedSeq[(Int, IR)]) extends IR - - final case class GetTupleElement(o: IR, idx: Int) extends IR - - object In { - def apply(i: Int, typ: Type): In = In( - i, - SingleCodeEmitParamType( - false, - typ match { - case TInt32 => Int32SingleCodeType - case TInt64 => Int64SingleCodeType - case TFloat32 => Float32SingleCodeType - case TFloat64 => Float64SingleCodeType - case TBoolean => BooleanSingleCodeType - case _: TStream => throw new UnsupportedOperationException - case t => PTypeReferenceSingleCodeType(PType.canonical(t)) - }, - ), - ) - } - - // Function Input - final case class In(i: Int, _typ: EmitParamType) extends IR - - // FIXME: should be type any - object Die { - def apply(message: String, typ: Type): Die = Die(Str(message), typ, ErrorIDs.NO_ERROR) - - def apply(message: String, typ: Type, errorId: Int): Die = Die(Str(message), typ, errorId) - } - - /** the Trap node runs the `child` node with an exception handler. If the child throws a - * HailException (user exception), then we return the tuple ((msg, errorId), NA). If the child - * throws any other exception, we raise that exception. If the child does not throw, then we - * return the tuple (NA, child value). - */ - final case class Trap(child: IR) extends IR - - final case class Die(message: IR, _typ: Type, errorId: Int) extends IR - - final case class ConsoleLog(message: IR, result: IR) extends IR - - final case class ApplyIR( - function: String, - typeArgs: Seq[Type], - args: Seq[IR], - returnType: Type, - errorID: Int, - ) extends IR { - var conversion: (Seq[Type], Seq[IR], Int) => IR = _ - var inline: Boolean = _ - - private lazy val refs = args.map(a => Ref(freshName(), a.typ)).toArray - lazy val body: IR = conversion(typeArgs, refs, errorID).deepCopy() - lazy val refIdx: Map[Name, Int] = refs.map(_.name).zipWithIndex.toMap - - lazy val explicitNode: IR = { - val ir = Let(refs.map(_.name).zip(args), body) - assert(ir.typ == returnType) - ir - } - } - - sealed abstract class AbstractApplyNode[F <: JVMFunction] extends IR { + trait AbstractApplyNode[F <: JVMFunction] extends IR { def function: String def args: Seq[IR] @@ -1093,90 +242,6 @@ package defs { .asInstanceOf[F] } - final case class Apply( - function: String, - typeArgs: Seq[Type], - args: Seq[IR], - returnType: Type, - errorID: Int, - ) extends AbstractApplyNode[UnseededMissingnessObliviousJVMFunction] - - final case class ApplySeeded( - function: String, - _args: Seq[IR], - rngState: IR, - staticUID: Long, - returnType: Type, - ) extends AbstractApplyNode[UnseededMissingnessObliviousJVMFunction] { - val args = rngState +: _args - val typeArgs: Seq[Type] = Seq.empty[Type] - } - - final case class ApplySpecial( - function: String, - typeArgs: Seq[Type], - args: Seq[IR], - returnType: Type, - errorID: Int, - ) extends AbstractApplyNode[UnseededMissingnessAwareJVMFunction] - - final case class LiftMeOut(child: IR) extends IR - - final case class TableCount(child: TableIR) extends IR - - final case class MatrixCount(child: MatrixIR) extends IR - - final case class 
TableAggregate(child: TableIR, query: IR) extends IR - - final case class MatrixAggregate(child: MatrixIR, query: IR) extends IR - - final case class TableWrite(child: TableIR, writer: TableWriter) extends IR - - final case class TableMultiWrite( - _children: IndexedSeq[TableIR], - writer: WrappedMatrixNativeMultiWriter, - ) extends IR - - final case class TableGetGlobals(child: TableIR) extends IR - - final case class TableCollect(child: TableIR) extends IR - - final case class MatrixWrite(child: MatrixIR, writer: MatrixWriter) extends IR - - final case class MatrixMultiWrite( - _children: IndexedSeq[MatrixIR], - writer: MatrixNativeMultiWriter, - ) extends IR - - final case class TableToValueApply(child: TableIR, function: TableToValueFunction) extends IR - - final case class MatrixToValueApply(child: MatrixIR, function: MatrixToValueFunction) extends IR - - final case class BlockMatrixToValueApply( - child: BlockMatrixIR, - function: BlockMatrixToValueFunction, - ) extends IR - - final case class BlockMatrixCollect(child: BlockMatrixIR) extends NDArrayIR - - final case class BlockMatrixWrite(child: BlockMatrixIR, writer: BlockMatrixWriter) extends IR - - final case class BlockMatrixMultiWrite( - blockMatrices: IndexedSeq[BlockMatrixIR], - writer: BlockMatrixMultiWriter, - ) extends IR - - final case class CollectDistributedArray( - contexts: IR, - globals: IR, - cname: Name, - gname: Name, - body: IR, - dynamicID: IR, - staticID: String, - tsd: Option[TableStageDependency] = None, - ) extends IR - object PartitionReader { implicit val formats: Formats = new DefaultFormats() { @@ -1399,21 +464,6 @@ package defs { writeAnnotations.consume(cb, {}, _ => ()) } - final case class ReadPartition(context: IR, rowType: TStruct, reader: PartitionReader) extends IR - - final case class WritePartition(value: IR, writeCtx: IR, writer: PartitionWriter) extends IR - - final case class WriteMetadata(writeAnnotations: IR, writer: MetadataWriter) extends IR - - final case class ReadValue(path: IR, reader: ValueReader, requestedType: Type) extends IR - - final case class WriteValue( - value: IR, - path: IR, - writer: ValueWriter, - stagingFile: Option[IR] = None, - ) extends IR - class PrimitiveIR(val self: IR) extends AnyVal { def +(other: IR): IR = { assert(self.typ == other.typ) @@ -1463,4 +513,395 @@ package defs { object ErrorIDs { val NO_ERROR = -1 } + + package exts { + + abstract class UUID4CompanionExt { + def apply(): UUID4 = UUID4(genUID()) + } + + abstract class MakeArrayCompanionExt { + def apply(args: IR*): MakeArray = { + assert(args.nonEmpty) + MakeArray(args.toFastSeq, TArray(args.head.typ)) + } + + def unify(ctx: ExecuteContext, args: IndexedSeq[IR], requestedType: TArray = null) + : MakeArray = { + assert(requestedType != null || args.nonEmpty) + + if (args.nonEmpty) + if (args.forall(_.typ == args.head.typ)) + return MakeArray(args, TArray(args.head.typ)) + + MakeArray( + args.map { arg => + val upcast = PruneDeadFields.upcast(ctx, arg, requestedType.elementType) + assert(upcast.typ == requestedType.elementType) + upcast + }, + requestedType, + ) + } + } + + abstract class LiteralCompanionExt { + def coerce(t: Type, x: Any): IR = { + if (x == null) + return NA(t) + t match { + case TInt32 => I32(x.asInstanceOf[Number].intValue()) + case TInt64 => I64(x.asInstanceOf[Number].longValue()) + case TFloat32 => F32(x.asInstanceOf[Number].floatValue()) + case TFloat64 => F64(x.asInstanceOf[Number].doubleValue()) + case TBoolean => if (x.asInstanceOf[Boolean]) True() else False() + case TString 
=> Str(x.asInstanceOf[String]) + case _ => Literal(t, x) + } + } + } + + abstract class EncodedLiteralCompanionExt { + def apply(codec: AbstractTypedCodecSpec, value: Array[Array[Byte]]): EncodedLiteral = + EncodedLiteral(codec, new WrappedByteArrays(value)) + + def fromPTypeAndAddress(pt: PType, addr: Long, ctx: ExecuteContext): IR = { + pt match { + case _: PInt32 => I32(Region.loadInt(addr)) + case _: PInt64 => I64(Region.loadLong(addr)) + case _: PFloat32 => F32(Region.loadFloat(addr)) + case _: PFloat64 => F64(Region.loadDouble(addr)) + case _: PBoolean => if (Region.loadBoolean(addr)) True() else False() + case ts: PString => Str(ts.loadString(addr)) + case _ => + val etype = EType.defaultFromPType(ctx, pt) + val codec = TypedCodecSpec(etype, pt.virtualType, BufferSpec.wireSpec) + val bytes = codec.encodeArrays(ctx, pt, addr) + EncodedLiteral(codec, bytes) + } + } + } + + abstract class BlockCompanionExt { + object Insert { + def unapply(bindings: IndexedSeq[Binding]) + : Option[(IndexedSeq[Binding], Binding, IndexedSeq[Binding])] = { + val idx = bindings.indexWhere(_.value.isInstanceOf[InsertFields]) + if (idx == -1) None else Some((bindings.take(idx), bindings(idx), bindings.drop(idx + 1))) + } + } + + object Nested { + def unapply(bindings: IndexedSeq[Binding]): Option[(Int, IndexedSeq[Binding])] = { + val idx = bindings.indexWhere(_.value.isInstanceOf[Block]) + if (idx == -1) None else Some((idx, bindings)) + } + } + } + + abstract class MakeStreamCompanionExt { + def unify( + ctx: ExecuteContext, + args: IndexedSeq[IR], + requiresMemoryManagementPerElement: Boolean = false, + requestedType: TStream = null, + ): MakeStream = { + assert(requestedType != null || args.nonEmpty) + + if (args.nonEmpty) + if (args.forall(_.typ == args.head.typ)) + return MakeStream(args, TStream(args.head.typ), requiresMemoryManagementPerElement) + + MakeStream( + args.map { arg => + val upcast = PruneDeadFields.upcast(ctx, arg, requestedType.elementType) + assert(upcast.typ == requestedType.elementType) + upcast + }, + requestedType, + requiresMemoryManagementPerElement, + ) + } + } + + abstract class ArraySortCompanionExt { + def apply(a: IR, ascending: IR = True(), onKey: Boolean = false): ArraySort = { + val l = freshName() + val r = freshName() + val atyp = tcoerce[TStream](a.typ) + val compare = if (onKey) { + val elementType = atyp.elementType.asInstanceOf[TBaseStruct] + elementType match { + case _: TStruct => + val elt = tcoerce[TStruct](atyp.elementType) + ApplyComparisonOp( + Compare(elt.types(0)), + GetField(Ref(l, elt), elt.fieldNames(0)), + GetField(Ref(r, atyp.elementType), elt.fieldNames(0)), + ) + case _: TTuple => + val elt = tcoerce[TTuple](atyp.elementType) + ApplyComparisonOp( + Compare(elt.types(0)), + GetTupleElement(Ref(l, elt), elt.fields(0).index), + GetTupleElement(Ref(r, atyp.elementType), elt.fields(0).index), + ) + } + } else { + ApplyComparisonOp( + Compare(atyp.elementType), + Ref(l, atyp.elementType), + Ref(r, atyp.elementType), + ) + } + + ArraySort(a, l, r, If(ascending, compare < 0, compare > 0)) + } + } + + abstract class StreamFold2CompanionExt { + def apply(a: StreamFold): StreamFold2 = + StreamFold2( + a.a, + FastSeq((a.accumName, a.zero)), + a.valueName, + FastSeq(a.body), + Ref(a.accumName, a.zero.typ), + ) + } + + trait StreamJoinRightDistinctExt { self: StreamJoinRightDistinct => + def isIntervalJoin: Boolean = { + if (rKey.size != 1) return false + val lKeyTyp = tcoerce[TStruct](tcoerce[TStream](left.typ).elementType).fieldType(lKey(0)) + val rKeyTyp = 
tcoerce[TStruct](tcoerce[TStream](right.typ).elementType).fieldType(rKey(0)) + + rKeyTyp.isInstanceOf[TInterval] && lKeyTyp != rKeyTyp + } + } + + abstract class MakeNDArrayCompanionExt { + def fill(elt: IR, shape: IndexedSeq[IR], rowMajor: IR): MakeNDArray = { + val flatSize: IR = if (shape.nonEmpty) + shape.reduce((l, r) => l * r) + else + 0L + MakeNDArray( + ToArray(mapIR(rangeIR(flatSize.toI))(_ => elt)), + MakeTuple.ordered(shape), + rowMajor, + ErrorIDs.NO_ERROR, + ) + } + } + + abstract class NDArrayQRCompanionExt { + def pType(mode: String, req: Boolean): PType = { + mode match { + case "r" => PCanonicalNDArray(PFloat64Required, 2, req) + case "raw" => PCanonicalTuple( + req, + PCanonicalNDArray(PFloat64Required, 2, true), + PCanonicalNDArray(PFloat64Required, 1, true), + ) + case "reduced" => PCanonicalTuple( + req, + PCanonicalNDArray(PFloat64Required, 2, true), + PCanonicalNDArray(PFloat64Required, 2, true), + ) + case "complete" => PCanonicalTuple( + req, + PCanonicalNDArray(PFloat64Required, 2, true), + PCanonicalNDArray(PFloat64Required, 2, true), + ) + } + } + } + + abstract class NDArraySVDCompanionExt { + def pTypes(computeUV: Boolean, req: Boolean): PType = { + if (computeUV) { + PCanonicalTuple( + req, + PCanonicalNDArray(PFloat64Required, 2, true), + PCanonicalNDArray(PFloat64Required, 1, true), + PCanonicalNDArray(PFloat64Required, 2, true), + ) + } else { + PCanonicalNDArray(PFloat64Required, 1, req) + } + } + } + + abstract class NDArrayEighCompanionExt { + def pTypes(eigvalsOnly: Boolean, req: Boolean): PType = + if (eigvalsOnly) { + PCanonicalNDArray(PFloat64Required, 1, req) + } else { + PCanonicalTuple( + req, + PCanonicalNDArray(PFloat64Required, 1, true), + PCanonicalNDArray(PFloat64Required, 2, true), + ) + } + } + + abstract class NDArrayInvCompanionExt { + val pType = PCanonicalNDArray(PFloat64Required, 2) + } + + abstract class ApplyAggOpCompanionExt { + def apply(op: AggOp, initOpArgs: IR*)(seqOpArgs: IR*): ApplyAggOp = + ApplyAggOp( + initOpArgs.toIndexedSeq, + seqOpArgs.toIndexedSeq, + AggSignature(op, initOpArgs.map(_.typ), seqOpArgs.map(_.typ)), + ) + } + + trait ApplyAggOpExt { self: ApplyAggOp => + def nSeqOpArgs = seqOpArgs.length + + def nInitArgs = initOpArgs.length + + def op: AggOp = aggSig.op + } + + abstract class AggFoldCompanionExt { + def min(element: IR, sortFields: IndexedSeq[SortField]): IR = { + val elementType = element.typ.asInstanceOf[TStruct] + val keyType = elementType.select(sortFields.map(_.field))._1 + minAndMaxHelper(element, keyType, StructLT(keyType, sortFields)) + } + + def max(element: IR, sortFields: IndexedSeq[SortField]): IR = { + val elementType = element.typ.asInstanceOf[TStruct] + val keyType = elementType.select(sortFields.map(_.field))._1 + minAndMaxHelper(element, keyType, StructGT(keyType, sortFields)) + } + + def all(element: IR): IR = + aggFoldIR(True()) { accum => + ApplySpecial( + "land", + Seq.empty[Type], + FastSeq(accum, element), + TBoolean, + ErrorIDs.NO_ERROR, + ) + } { (accum1, accum2) => + ApplySpecial( + "land", + Seq.empty[Type], + FastSeq(accum1, accum2), + TBoolean, + ErrorIDs.NO_ERROR, + ) + } + + private def minAndMaxHelper(element: IR, keyType: TStruct, comp: ComparisonOp[Boolean]) + : IR = { + val keyFields = keyType.fields.map(_.name) + + val minAndMaxZero = NA(keyType) + val aggFoldMinAccumName1 = freshName() + val aggFoldMinAccumName2 = freshName() + val aggFoldMinAccumRef1 = Ref(aggFoldMinAccumName1, keyType) + val aggFoldMinAccumRef2 = Ref(aggFoldMinAccumName2, keyType) + val minSeq = 
bindIR(SelectFields(element, keyFields)) { keyOfCurElementRef =>
+          If(
+            IsNA(aggFoldMinAccumRef1),
+            keyOfCurElementRef,
+            If(
+              ApplyComparisonOp(comp, aggFoldMinAccumRef1, keyOfCurElementRef),
+              aggFoldMinAccumRef1,
+              keyOfCurElementRef,
+            ),
+          )
+        }
+      val minComb =
+        If(
+          IsNA(aggFoldMinAccumRef1),
+          aggFoldMinAccumRef2,
+          If(
+            ApplyComparisonOp(comp, aggFoldMinAccumRef1, aggFoldMinAccumRef2),
+            aggFoldMinAccumRef1,
+            aggFoldMinAccumRef2,
+          ),
+        )
+
+      AggFold(minAndMaxZero, minSeq, minComb, aggFoldMinAccumName1, aggFoldMinAccumName2, false)
+    }
+  }
+
+  abstract class ApplyScanOpCompanionExt {
+    def apply(op: AggOp, initOpArgs: IR*)(seqOpArgs: IR*): ApplyScanOp =
+      ApplyScanOp(
+        initOpArgs.toIndexedSeq,
+        seqOpArgs.toIndexedSeq,
+        AggSignature(op, initOpArgs.map(_.typ), seqOpArgs.map(_.typ)),
+      )
+  }
+
+  trait ApplyScanOpExt { self: ApplyScanOp =>
+    def nSeqOpArgs = seqOpArgs.length
+
+    def nInitArgs = initOpArgs.length
+
+    def op: AggOp = aggSig.op
+  }
+
+  abstract class ResultOpCompanionExt {
+    def makeTuple(aggs: IndexedSeq[PhysicalAggSig]) =
+      MakeTuple.ordered(aggs.zipWithIndex.map { case (aggSig, index) =>
+        ResultOp(index, aggSig)
+      })
+  }
+
+  abstract class MakeTupleCompanionExt {
+    def ordered(types: IndexedSeq[IR]): MakeTuple = MakeTuple(types.zipWithIndex.map {
+      case (ir, i) =>
+        (i, ir)
+    })
+  }
+
+  abstract class InCompanionExt {
+    def apply(i: Int, typ: Type): In = In(
+      i,
+      SingleCodeEmitParamType(
+        false,
+        typ match {
+          case TInt32 => Int32SingleCodeType
+          case TInt64 => Int64SingleCodeType
+          case TFloat32 => Float32SingleCodeType
+          case TFloat64 => Float64SingleCodeType
+          case TBoolean => BooleanSingleCodeType
+          case _: TStream => throw new UnsupportedOperationException
+          case t => PTypeReferenceSingleCodeType(PType.canonical(t))
+        },
+      ),
+    )
+  }
+
+  abstract class DieCompanionExt {
+    def apply(message: String, typ: Type): Die = Die(Str(message), typ, ErrorIDs.NO_ERROR)
+
+    def apply(message: String, typ: Type, errorId: Int): Die = Die(Str(message), typ, errorId)
+  }
+
+  trait ApplyIRExt { self: ApplyIR =>
+    var conversion: (Seq[Type], IndexedSeq[IR], Int) => IR = _
+    var inline: Boolean = _
+
+    private lazy val refs = args.map(a => Ref(freshName(), a.typ)).toArray
+    lazy val body: IR = conversion(typeArgs, refs, errorID).deepCopy()
+    lazy val refIdx: Map[Name, Int] = refs.map(_.name).zipWithIndex.toMap
+
+    lazy val explicitNode: IR = {
+      val ir = Let(refs.map(_.name).zip(args), body)
+      assert(ir.typ == returnType)
+      ir
+    }
+  }
+}
 }
diff --git a/hail/hail/src/is/hail/expr/ir/Interpret.scala b/hail/hail/src/is/hail/expr/ir/Interpret.scala
index 1c0328804b2..ea964ef995e 100644
--- a/hail/hail/src/is/hail/expr/ir/Interpret.scala
+++ b/hail/hail/src/is/hail/expr/ir/Interpret.scala
@@ -862,14 +862,7 @@ object Interpret {
         val (rt, f) = functionMemo.getOrElseUpdate(
           ir, {
             val in = Ref(freshName(), argTuple.virtualType)
-            val wrappedArgs: IndexedSeq[BaseIR] = ir.args.zipWithIndex.map { case (_, i) =>
-              GetTupleElement(in, i)
-            }.toFastSeq
-            val newChildren = ir match {
-              case _: ApplySeeded => wrappedArgs :+ NA(TRNGState)
-              case _ => wrappedArgs
-            }
-            val wrappedIR = Copy(ir, newChildren)
+            val wrappedIR = ir.mapChildrenWithIndex { case (_, i) => GetTupleElement(in, i) }

             val (rt, makeFunction) = Compile[AsmFunction2RegionLongLong](
               ctx,
@@ -1067,7 +1060,8 @@ object Interpret {
           }

           val rv = value.rvd.combine[WrappedByteArray, RegionValue](
-            ctx, mkZero, itF, read, write, combOpF, isCommutative, useTreeAggregate)
+            ctx, mkZero, itF, read, write, combOpF, isCommutative, useTreeAggregate,
+          )

           val (Some(PTypeReferenceSingleCodeType(rTyp: PTuple)), f) =
             CompileWithAggregators[AsmFunction2RegionLongLong](
diff --git a/hail/hail/src/is/hail/expr/ir/MatrixWriter.scala b/hail/hail/src/is/hail/expr/ir/MatrixWriter.scala
index 812e6175216..cf2e7d23edd 100644
--- a/hail/hail/src/is/hail/expr/ir/MatrixWriter.scala
+++ b/hail/hail/src/is/hail/expr/ir/MatrixWriter.scala
@@ -401,7 +401,8 @@ case class MatrixNativeWriter(

     val components = MatrixNativeWriter.generateComponentFunctions(
       colsFieldName, entriesFieldName, colKey, ctx, tablestage, r,
-      path, overwrite, stageLocally, codecSpecJSONStr, partitions, partitionsTypeStr)
+      path, overwrite, stageLocally, codecSpecJSONStr, partitions, partitionsTypeStr,
+    )

     Begin(FastSeq(
       components.setup,
diff --git a/hail/hail/src/is/hail/expr/ir/Parser.scala b/hail/hail/src/is/hail/expr/ir/Parser.scala
index 4c4a5203055..c99241beb1f 100644
--- a/hail/hail/src/is/hail/expr/ir/Parser.scala
+++ b/hail/hail/src/is/hail/expr/ir/Parser.scala
@@ -798,7 +798,7 @@ object IRParser {

   def apply_like(
     env: IRParserEnvironment,
-    cons: (String, Seq[Type], Seq[IR], Type, Int) => IR,
+    cons: (String, Seq[Type], IndexedSeq[IR], Type, Int) => IR,
   )(
     it: TokenIterator
   ): StackFrame[IR] = {
@@ -1397,7 +1397,7 @@ object IRParser {
         args <- ir_value_children(env)(it)
       } yield ApplySeeded(function, args, rngState, staticUID, rt)
     case "ApplyIR" =>
-      apply_like(env, ApplyIR)(it)
+      apply_like(env, ApplyIR.apply)(it)
     case "ApplySpecial" =>
      apply_like(env, ApplySpecial)(it)
     case "Apply" =>
@@ -1820,7 +1820,8 @@ object IRParser {
     case "MatrixRead" =>
       val requestedTypeRaw = it.head match {
         case x: IdentifierToken
-            if x.value == "None" || x.value == "DropColUIDs" || x.value == "DropRowUIDs" || x.value == "DropRowColUIDs" =>
+            if x.value == "None" || x.value == "DropColUIDs" || x.value == "DropRowUIDs" || x
+              .value == "DropRowColUIDs" =>
           consumeToken(it)
           Left(x.value)
         case _ =>
diff --git a/hail/hail/src/is/hail/expr/ir/TableIR.scala b/hail/hail/src/is/hail/expr/ir/TableIR.scala
index 27404c71c01..f6b661b455b 100644
--- a/hail/hail/src/is/hail/expr/ir/TableIR.scala
+++ b/hail/hail/src/is/hail/expr/ir/TableIR.scala
@@ -3830,7 +3830,8 @@ case class TableKeyByAndAggregate(
         makeKey,
         seqOp,
         serializeAndCleanupAggs,
-        localBufferSize)
+        localBufferSize,
+      )
     }.aggregateByKey(initAggs, nPartitions.getOrElse(prev.rvd.getNumPartitions))(combOp, combOp)

     val crdd = ContextRDD.weaken(rdd).cmapPartitionsWithIndex({ (i, ctx, it) =>
diff --git a/hail/hail/src/is/hail/expr/ir/TypeCheck.scala b/hail/hail/src/is/hail/expr/ir/TypeCheck.scala
index 98bb166b171..937c8588100 100644
--- a/hail/hail/src/is/hail/expr/ir/TypeCheck.scala
+++ b/hail/hail/src/is/hail/expr/ir/TypeCheck.scala
@@ -345,7 +345,7 @@ object TypeCheck {
         assert(key.forall(structType.hasField))
       case x @ StreamMap(a, _, body) =>
         assert(a.typ.isInstanceOf[TStream])
-        assert(x.elementTyp == body.typ)
+        assert(x.typ.elementType == body.typ)
       case x @ StreamZip(as, names, body, _, _) =>
         assert(as.length == names.length)
         assert(x.typ.elementType == body.typ)
diff --git a/hail/hail/src/is/hail/expr/ir/functions/ArrayFunctions.scala b/hail/hail/src/is/hail/expr/ir/functions/ArrayFunctions.scala
index 582e238668b..7cada291e5b 100644
--- a/hail/hail/src/is/hail/expr/ir/functions/ArrayFunctions.scala
+++ b/hail/hail/src/is/hail/expr/ir/functions/ArrayFunctions.scala
@@ -55,14 +55,14 @@ object ArrayFunctions extends RegistryFunctions {
       tnum("T"),
       TFloat64,
       (ir1: IR, ir2: IR, errorID: Int) =>
-        Apply("pow", Seq(), Seq(ir1, ir2), TFloat64, errorID),
+        Apply("pow", Seq(), FastSeq(ir1, ir2), TFloat64, errorID),
     ),
     (
       "mod",
       tnum("T"),
       tv("T"),
       (ir1: IR, ir2: IR, errorID: Int) =>
-        Apply("mod", Seq(), Seq(ir1, ir2), ir2.typ, errorID),
+        Apply("mod", Seq(), FastSeq(ir1, ir2), ir2.typ, errorID),
     ),
   )

diff --git a/hail/hail/src/is/hail/expr/ir/functions/Functions.scala b/hail/hail/src/is/hail/expr/ir/functions/Functions.scala
index d171fd351e2..21c5f3c97d6 100644
--- a/hail/hail/src/is/hail/expr/ir/functions/Functions.scala
+++ b/hail/hail/src/is/hail/expr/ir/functions/Functions.scala
@@ -178,7 +178,7 @@ object IRFunctionRegistry {
     : Option[(Seq[IR], IR) => IR] =
     lookupFunction(name, returnType, Array.empty[Type], TRNGState +: arguments)
       .map { f => (irArguments: Seq[IR], rngState: IR) =>
-        ApplySeeded(name, irArguments, rngState, staticUID, f.returnType.subst())
+        ApplySeeded(name, irArguments.toFastSeq, rngState, staticUID, f.returnType.subst())
       }

   def lookupUnseeded(name: String, returnType: Type, arguments: Seq[Type])
@@ -194,7 +194,7 @@ object IRFunctionRegistry {
     val validIR: Option[IRFunctionImplementation] =
       lookupIR(name, returnType, typeParameters, arguments).map {
         case ((_, _, _, inline), conversion) => (typeParametersPassed, args, errorID) =>
-            val x = ApplyIR(name, typeParametersPassed, args, returnType, errorID)
+            val x = ApplyIR(name, typeParametersPassed, args.toFastSeq, returnType, errorID)
             x.conversion = conversion
             x.inline = inline
             x
@@ -205,9 +205,21 @@ object IRFunctionRegistry {

       { (irValueParametersTypes: Seq[Type], irArguments: Seq[IR], errorID: Int) =>
         f match {
           case _: UnseededMissingnessObliviousJVMFunction =>
-            Apply(name, irValueParametersTypes, irArguments, f.returnType.subst(), errorID)
+            Apply(
+              name,
+              irValueParametersTypes,
+              irArguments.toFastSeq,
+              f.returnType.subst(),
+              errorID,
+            )
           case _: UnseededMissingnessAwareJVMFunction =>
-            ApplySpecial(name, irValueParametersTypes, irArguments, f.returnType.subst(), errorID)
+            ApplySpecial(
+              name,
+              irValueParametersTypes,
+              irArguments.toFastSeq,
+              f.returnType.subst(),
+              errorID,
+            )
         }
       }
diff --git a/hail/hail/src/is/hail/expr/ir/lowering/LoweringPipeline.scala b/hail/hail/src/is/hail/expr/ir/lowering/LoweringPipeline.scala
index fd3a329abd8..3fcd4e4d591 100644
--- a/hail/hail/src/is/hail/expr/ir/lowering/LoweringPipeline.scala
+++ b/hail/hail/src/is/hail/expr/ir/lowering/LoweringPipeline.scala
@@ -100,9 +100,7 @@ object LoweringPipeline {
       DArrayLowering.All,
       DArrayLowering.TableOnly,
       DArrayLowering.BMOnly,
-    ).map { lv =>
-      (lv -> fullLoweringPipeline("darrayLowerer", LowerToDistributedArrayPass(lv)))
-    }.toMap
+    ).map(lv => (lv -> fullLoweringPipeline("darrayLowerer", LowerToDistributedArrayPass(lv)))).toMap

   private val _dArrayLowerersNoOpt = _dArrayLowerers.mapValues(_.noOptimization()).toMap

diff --git a/hail/hail/src/is/hail/expr/ir/streams/EmitStream.scala b/hail/hail/src/is/hail/expr/ir/streams/EmitStream.scala
index be3f971c52b..da7507d9a59 100644
--- a/hail/hail/src/is/hail/expr/ir/streams/EmitStream.scala
+++ b/hail/hail/src/is/hail/expr/ir/streams/EmitStream.scala
@@ -984,7 +984,8 @@ object EmitStream {
             )
             val fC = cb.newLocal[Double](
               "seq_sample_Fc",
-              (totalSizeVal.value - candidate - nRemaining).toD / (totalSizeVal.value - candidate).toD,
+              (totalSizeVal.value - candidate - nRemaining)
+                .toD / (totalSizeVal.value - candidate).toD,
             )

             cb.while_(
diff --git a/hail/hail/src/is/hail/io/index/IndexReader.scala b/hail/hail/src/is/hail/io/index/IndexReader.scala
index 4ae410403cd..6f32cf7b3ea 100644
--- a/hail/hail/src/is/hail/io/index/IndexReader.scala
+++ b/hail/hail/src/is/hail/io/index/IndexReader.scala
@@ -42,7 +42,8 @@ object IndexReaderBuilder {
     (theHailClassLoader, fs, path, cacheCapacity, pool) =>
       new IndexReader(
         theHailClassLoader, fs, path, cacheCapacity, leafDec, intDec, keyType, annotationType,
-        leafPType, intPType, pool, sm)
+        leafPType, intPType, pool, sm,
+      )
   }
 }

diff --git a/hail/hail/src/is/hail/io/index/IndexWriter.scala b/hail/hail/src/is/hail/io/index/IndexWriter.scala
index 20a2e974240..9d26f1351ce 100644
--- a/hail/hail/src/is/hail/io/index/IndexWriter.scala
+++ b/hail/hail/src/is/hail/io/index/IndexWriter.scala
@@ -55,7 +55,8 @@ case class IndexMetadataUntypedJSON(

   def toMetadata(keyType: Type, annotationType: Type): IndexMetadata = IndexMetadata(
     fileVersion, branchingFactor, height, keyType, annotationType,
-    nKeys, indexPath, rootOffset, attributes)
+    nKeys, indexPath, rootOffset, attributes,
+  )

   def toFileMetadata: VariableMetadata = VariableMetadata(
     branchingFactor, height, nKeys, rootOffset, attributes,
diff --git a/hail/hail/src/is/hail/io/vcf/LoadVCF.scala b/hail/hail/src/is/hail/io/vcf/LoadVCF.scala
index 8cb2f464bd3..36ac86e2551 100644
--- a/hail/hail/src/is/hail/io/vcf/LoadVCF.scala
+++ b/hail/hail/src/is/hail/io/vcf/LoadVCF.scala
@@ -1438,7 +1438,8 @@ object LoadVCF {
       filterAttrs,
       infoAttrs,
       formatAttrs,
-      infoFlagFields)
+      infoFlagFields,
+    )
   }

   def getHeaderLines[T](
@@ -1772,7 +1773,8 @@ object MatrixVCFReader {
       files, callFields, entryFloatTypeName, headerFile, sampleIDs, nPartitions, blockSizeInMB,
       minPartitions, rg, contigRecoding, arrayElementsRequired, skipInvalidLoci, gzAsBGZ,
       forceGZ, filterAndReplace,
-      partitionsJSON, partitionsTypeStr),
+      partitionsJSON, partitionsTypeStr,
+      ),
   )

   def apply(ctx: ExecuteContext, params: MatrixVCFReaderParameters): MatrixVCFReader = {
diff --git a/hail/hail/src/is/hail/rvd/AbstractRVDSpec.scala b/hail/hail/src/is/hail/rvd/AbstractRVDSpec.scala
index 32a2a75110f..7b036f27a0c 100644
--- a/hail/hail/src/is/hail/rvd/AbstractRVDSpec.scala
+++ b/hail/hail/src/is/hail/rvd/AbstractRVDSpec.scala
@@ -405,7 +405,8 @@ case class RVDSpecMaker(
         ais,
         partFiles,
         bounds,
-        attrs)
+        attrs,
+      )
     case None => OrderedRVDSpec2(
         key,
         codecSpec,
diff --git a/hail/hail/src/is/hail/types/virtual/BlockMatrixType.scala b/hail/hail/src/is/hail/types/virtual/BlockMatrixType.scala
index 3c8fdc5ff28..0be80d599bd 100644
--- a/hail/hail/src/is/hail/types/virtual/BlockMatrixType.scala
+++ b/hail/hail/src/is/hail/types/virtual/BlockMatrixType.scala
@@ -308,9 +308,8 @@ case class BlockMatrixSparsity(definedBlocks: Option[IndexedSeq[(Int, Int)]]) {
     BlockMatrixSparsity(definedBlocks.map(_.map { case (i, j) => (j, i) }))

   override def toString: String =
-    definedBlocks.map { blocks =>
-      blocks.map { case (i, j) => s"($i,$j)" }.mkString("[", ",", "]")
-    }.getOrElse("None")
+    definedBlocks.map(blocks => blocks.map { case (i, j) => s"($i,$j)" }.mkString("[", ",", "]"))
+      .getOrElse("None")
 }

 object BlockMatrixType {
diff --git a/hail/hail/test/src/is/hail/expr/ir/Aggregators2Suite.scala b/hail/hail/test/src/is/hail/expr/ir/Aggregators2Suite.scala
index abef7b459e9..2f57c423fda 100644
--- a/hail/hail/test/src/is/hail/expr/ir/Aggregators2Suite.scala
+++ b/hail/hail/test/src/is/hail/expr/ir/Aggregators2Suite.scala
@@ -1056,7 +1056,8 @@ class Aggregators2Suite extends HailSuite {
       FastSeq(
         0.0, 0.0,
         1.0, 0.0, 0.0,
         1.0, 3.0,
-        0.0, 0.0, 1.0, 3.0, 6.0),
+        0.0, 0.0, 1.0, 3.0, 6.0,
+      ),
     )
   }
diff --git a/hail/hail/test/src/is/hail/expr/ir/FoldConstantsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/FoldConstantsSuite.scala
index 422985d1b55..ba9e5a89b8f 100644
--- a/hail/hail/test/src/is/hail/expr/ir/FoldConstantsSuite.scala
+++ b/hail/hail/test/src/is/hail/expr/ir/FoldConstantsSuite.scala
@@ -5,12 +5,13 @@ import is.hail.expr.ir.defs.{
   AggLet, ApplyAggOp, ApplyScanOp, ApplySeeded, F64, I32, I64, RNGStateLiteral, Str,
 }
 import is.hail.types.virtual.{TFloat64, TInt32}
+import is.hail.utils.FastSeq

 import org.testng.annotations.{DataProvider, Test}

 class FoldConstantsSuite extends HailSuite {
   @Test def testRandomBlocksFolding(): Unit = {
-    val x = ApplySeeded("rand_norm", Seq(F64(0d), F64(0d)), RNGStateLiteral(), 0L, TFloat64)
+    val x = ApplySeeded("rand_norm", FastSeq(F64(0d), F64(0d)), RNGStateLiteral(), 0L, TFloat64)
     assert(FoldConstants(ctx, x) == x)
   }

diff --git a/hail/hail/test/src/is/hail/expr/ir/IRSuite.scala b/hail/hail/test/src/is/hail/expr/ir/IRSuite.scala
index 54b085d09c7..f142e52a4cf 100644
--- a/hail/hail/test/src/is/hail/expr/ir/IRSuite.scala
+++ b/hail/hail/test/src/is/hail/expr/ir/IRSuite.scala
@@ -1964,7 +1964,8 @@ class IRSuite extends HailSuite {
     val nd1 = (
       FastSeq(
         0, 1, 2,
-        3, 4, 5),
+        3, 4, 5,
+      ),
       2L,
       3L,
     )
@@ -1973,7 +1974,8 @@ class IRSuite extends HailSuite {
       FastSeq(
         6, 7, 8,
         9, 10, 11,
-        12, 13, 14),
+        12, 13, 14,
+      ),
       3L,
       3L,
     )
diff --git a/hail/hail/test/src/is/hail/expr/ir/EncodedLiteralSuite.scala b/hail/hail/test/src/is/hail/expr/ir/defs/EncodedLiteralSuite.scala
similarity index 88%
rename from hail/hail/test/src/is/hail/expr/ir/EncodedLiteralSuite.scala
rename to hail/hail/test/src/is/hail/expr/ir/defs/EncodedLiteralSuite.scala
index 78fb54ea776..53b5ad147db 100644
--- a/hail/hail/test/src/is/hail/expr/ir/EncodedLiteralSuite.scala
+++ b/hail/hail/test/src/is/hail/expr/ir/defs/EncodedLiteralSuite.scala
@@ -1,7 +1,6 @@
-package is.hail.expr.ir
+package is.hail.expr.ir.defs

 import is.hail.HailSuite
-import is.hail.expr.ir.defs.WrappedByteArrays

 import org.testng.annotations.Test

diff --git a/hail/hail/test/src/is/hail/io/IndexBTreeSuite.scala b/hail/hail/test/src/is/hail/io/IndexBTreeSuite.scala
index 4682309ba0b..453d9e45f05 100644
--- a/hail/hail/test/src/is/hail/io/IndexBTreeSuite.scala
+++ b/hail/hail/test/src/is/hail/io/IndexBTreeSuite.scala
@@ -114,7 +114,8 @@ class IndexBTreeSuite extends HailSuite {
         0, 0, 0, 0, 0, 0, 0, 6,
         0, 0, 0, 0, 0, 0, 0, 5,
         0, 0, 0, 0, 0, 0, 0, 4,
-        0, 0, 0, 0, 0, 0, 0, 3)
+        0, 0, 0, 0, 0, 0, 0, 3,
+      )

     assert(IndexBTree.btreeBytes(in, branchingFactor = 8) sameElements bigEndianBytes)
   }
diff --git a/hail/hail/test/src/is/hail/io/IndexSuite.scala b/hail/hail/test/src/is/hail/io/IndexSuite.scala
index bfaba257e68..384cc79a2aa 100644
--- a/hail/hail/test/src/is/hail/io/IndexSuite.scala
+++ b/hail/hail/test/src/is/hail/io/IndexSuite.scala
@@ -16,14 +16,16 @@ class IndexSuite extends HailSuite {
     "lion", "mouse", "parrot",
     "quail", "rabbit", "raccoon", "rat",
     "raven", "skunk", "snail", "squirrel", "vole",
-    "weasel", "whale", "yak", "zebra")
+    "weasel", "whale", "yak", "zebra",
+  )

   val stringsWithDups = Array(
     "bear", "bear", "cat", "cat",
     "cat", "cat", "cat", "cat",
     "cat", "dog", "mouse", "mouse",
     "skunk", "skunk", "skunk", "whale",
-    "whale", "zebra", "zebra", "zebra")
+    "whale", "zebra", "zebra", "zebra",
+  )

   val leafsWithDups = stringsWithDups.zipWithIndex.map { case (s, i) => LeafChild(s, i, Row()) }

diff --git a/hail/hail/test/src/is/hail/linalg/BlockMatrixSuite.scala b/hail/hail/test/src/is/hail/linalg/BlockMatrixSuite.scala
index e32d0c7038e..af7f09e94c7 100644
--- a/hail/hail/test/src/is/hail/linalg/BlockMatrixSuite.scala
+++ b/hail/hail/test/src/is/hail/linalg/BlockMatrixSuite.scala
@@ -128,7 +128,8 @@ class BlockMatrixSuite extends HailSuite {
         1, 2, 3, 4,
         5, 6, 7, 8,
         9, 10, 11, 12,
-        13, 14, 15, 16),
+        13, 14, 15, 16,
+      ),
     )

     val expected = toLM(
@@ -138,7 +139,8 @@ class BlockMatrixSuite extends HailSuite {
         0, -3, -6, -9,
         3, 0, -3, -6,
         6, 3, 0, -3,
-        9, 6, 3, 0),
+        9, 6, 3, 0,
+      ),
     )

     val actual = (m - m.T).toBreezeMatrix()
@@ -154,7 +156,8 @@ class BlockMatrixSuite extends HailSuite {
         1, 2, 3, 4,
         5, 6, 7, 8,
         9, 10, 11, 12,
-        13, 14, 15, 16),
+        13, 14, 15, 16,
+      ),
     )

     val l = toBM(ll)
@@ -250,7 +253,8 @@ class BlockMatrixSuite extends HailSuite {
         1, 2, 3, 4,
         5, 6, 7, 8,
         9, 10, 11, 12,
-        13, 14, 15, 16),
+        13, 14, 15, 16,
+      ),
     )

     val v = Array[Double](1, 2, 3, 4)
@@ -262,7 +266,8 @@ class BlockMatrixSuite extends HailSuite {
         1, 4, 9, 16,
         5, 12, 21, 32,
         9, 20, 33, 48,
-        13, 28, 45, 64),
+        13, 28, 45, 64,
+      ),
     )

     assert(l.rowVectorMul(v).toBreezeMatrix() == result)
@@ -294,7 +299,8 @@ class BlockMatrixSuite extends HailSuite {
         1, 2, 3, 4,
         5, 6, 7, 8,
         9, 10, 11, 12,
-        13, 14, 15, 16),
+        13, 14, 15, 16,
+      ),
     )

     val v = Array[Double](1, 2, 3, 4)
@@ -306,7 +312,8 @@ class BlockMatrixSuite extends HailSuite {
         1, 2, 3, 4,
         10, 12, 14, 16,
         27, 30, 33, 36,
-        52, 56, 60, 64),
+        52, 56, 60, 64,
+      ),
     )

     assert(l.colVectorMul(v).toBreezeMatrix() == result)
@@ -343,7 +350,8 @@ class BlockMatrixSuite extends HailSuite {
         1, 2, 3, 4,
         5, 6, 7, 8,
         9, 10, 11, 12,
-        13, 14, 15, 16),
+        13, 14, 15, 16,
+      ),
     )

     val v = Array[Double](1, 2, 3, 4)
@@ -355,7 +363,8 @@ class BlockMatrixSuite extends HailSuite {
         2, 3, 4, 5,
         7, 8, 9, 10,
         12, 13, 14, 15,
-        17, 18, 19, 20),
+        17, 18, 19, 20,
+      ),
     )

     assert(l.colVectorAdd(v).toBreezeMatrix() == result)
@@ -370,7 +379,8 @@ class BlockMatrixSuite extends HailSuite {
         1, 2, 3, 4,
         5, 6, 7, 8,
         9, 10, 11, 12,
-        13, 14, 15, 16),
+        13, 14, 15, 16,
+      ),
     )

     val v = Array[Double](1, 2, 3, 4)
@@ -382,7 +392,8 @@ class BlockMatrixSuite extends HailSuite {
         2, 4, 6, 8,
         6, 8, 10, 12,
         10, 12, 14, 16,
-        14, 16, 18, 20),
+        14, 16, 18, 20,
+      ),
     )

     assert(l.rowVectorAdd(v).toBreezeMatrix() == result)
@@ -396,7 +407,8 @@ class BlockMatrixSuite extends HailSuite {
       Array[Double](
         1, 2, 3, 4,
         5, 6, 7, 8,
-        9, 10, 11, 12),
+        9, 10, 11, 12,
+      ),
     )

     val m = toBM(lm, blockSize = 2)
@@ -442,7 +454,8 @@ class BlockMatrixSuite extends HailSuite {
         1, 2, 3, 4,
         5, 6, 7, 8,
         9, 10, 11, 12,
-        13, 14, 15, 16),
+        13, 14, 15, 16,
+      ),
     )

     val fname = ctx.createTmpPath("test")
@@ -463,7 +476,8 @@ class BlockMatrixSuite extends HailSuite {
         1, 2, 3, 4,
         5, 6, 7, 8,
         9, 10, 11, 12,
-        13, 14, 15, 16),
+        13, 14, 15, 16,
+      ),
     )

     val fname = ctx.createTmpPath("test")
@@ -565,14 +579,16 @@ class BlockMatrixSuite extends HailSuite {
         1, 2,
         3, 4,
         5, 6,
-        7, 8),
+        7, 8,
+      ),
     )

     val lmt = toLM(
       2, 4,
       Array[Double](
         1, 3, 5, 7,
-        2, 4, 6, 8),
+        2, 4, 6, 8,
+      ),
     )

     val m = toBM(lm)
@@ -594,14 +610,16 @@ class BlockMatrixSuite extends HailSuite {
         1, 2,
         3, 4,
         5, 6,
-        7, 8),
+        7, 8,
+      ),
     )

     val lmt = toLM(
       2, 4,
       Array[Double](
         1, 3, 5, 7,
-        2, 4, 6, 8),
+        2, 4, 6, 8,
+      ),
     )

     val m = toBM(lm)
@@ -620,14 +638,16 @@ class BlockMatrixSuite extends HailSuite {
         1, 2,
         3, 4,
         5, 6,
-        7, 8),
+        7, 8,
+      ),
     )

     val lmt = toLM(
       2, 4,
       Array[Double](
         1, 3, 5, 7,
-        2, 4, 6, 8),
+        2, 4, 6, 8,
+      ),
     )

     val m = toBM(lm)
@@ -647,14 +667,16 @@ class BlockMatrixSuite extends HailSuite {
         1, 2,
         3, 4,
         5, 6,
-        7, 8),
+        7, 8,
+      ),
     )

     val lmt = toLM(
       2, 4,
       Array[Double](
         1, 3, 5, 7,
-        2, 4, 6, 8),
+        2, 4, 6, 8,
+      ),
     )

     val m = toBM(lm)
@@ -683,14 +705,16 @@ class BlockMatrixSuite extends HailSuite {
         1, 2,
         3, 4,
         5, 6,
-        7, 8),
+        7, 8,
+      ),
     )

     val lmt = toLM(
       2, 4,
       Array[Double](
         1, 3, 5, 7,
-        2, 4, 6, 8),
+        2, 4, 6, 8,
+      ),
     )

     val m = toBM(lm)
@@ -984,7 +1008,8 @@ class BlockMatrixSuite extends HailSuite {
         1, 2, 3, 4,
         5, 6, 7, 8,
         9, 10, 11, 12,
-        13, 14, 15, 16),
+        13, 14, 15, 16,
+      ),
     )

     val bm = toBM(lm, blockSize = 2)
@@ -1031,7 +1056,8 @@ class BlockMatrixSuite extends HailSuite {
         1, 2, 3, 4,
         5, 6, 7, 8,
         9, 10, 11, 12,
-        13, 14, 15, 16),
+        13, 14, 15, 16,
+      ),
     )

     val bm = toBM(lm, blockSize = 2)
@@ -1084,7 +1110,8 @@ class BlockMatrixSuite extends HailSuite {
         1, 2, 3, 4,
         5, 6, 7, 8,
         9, 10, 11, 12,
-        13, 14, 15, 16),
+        13, 14, 15, 16,
+      ),
     )

     val bm = toBM(lm, blockSize = 2)
@@ -1238,7 +1265,8 @@ class BlockMatrixSuite extends HailSuite {
         1, 2, 3, 4,
         5, 6, 7, 8,
         9, 10, 11, 12,
-        13, 14, 15, 16),
+        13, 14, 15, 16,
+      ),
     )

     val bm = toBM(lm, blockSize = 2)
diff --git a/hail/hail/test/src/is/hail/utils/UtilsSuite.scala b/hail/hail/test/src/is/hail/utils/UtilsSuite.scala
index e1330acb38b..46840c06724 100644
--- a/hail/hail/test/src/is/hail/utils/UtilsSuite.scala
+++ b/hail/hail/test/src/is/hail/utils/UtilsSuite.scala
@@ -103,7 +103,8 @@ class UtilsSuite extends HailSuite {
       "NONE", "DISK_ONLY", "DISK_ONLY_2", "MEMORY_ONLY", "MEMORY_ONLY_2", "MEMORY_ONLY_SER",
       "MEMORY_ONLY_SER_2", "MEMORY_AND_DISK", "MEMORY_AND_DISK_2", "MEMORY_AND_DISK_SER",
       "MEMORY_AND_DISK_SER_2",
-      "OFF_HEAP")
+      "OFF_HEAP",
+    )

     sls.foreach(sl => StorageLevel.fromString(sl))
   }