Skip to content

Commit

Permalink
ir gen mvp
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-schultz committed Jan 3, 2025
1 parent c066895 commit 0fc810b
Show file tree
Hide file tree
Showing 3 changed files with 213 additions and 46 deletions.
17 changes: 17 additions & 0 deletions hail/build.mill
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,10 @@ object hail extends HailModule { outer =>
buildInfo(),
)

override def generatedSources: T[Seq[PathRef]] = Task {
Seq(`ir-gen`.generate())
}

override def unmanagedClasspath: T[Agg[PathRef]] =
Agg(shadedazure.assembly())

Expand Down Expand Up @@ -246,6 +250,19 @@ object hail extends HailModule { outer =>
PathRef(T.dest)
}

object `ir-gen` extends HailModule {
def ivyDeps = Agg(
ivy"com.lihaoyi::mainargs:0.6.2",
ivy"com.lihaoyi::os-lib:0.10.7",
ivy"com.lihaoyi::sourcecode:0.4.2",
)

def generate: T[PathRef] = Task {
runForkedTask(finalMainClass, Task.Anon { Args("--path", T.dest) })()
PathRef(T.dest)
}
}

object memory extends JavaModule { // with CrossValue {
override def zincIncrementalCompilation: T[Boolean] = false

Expand Down
154 changes: 154 additions & 0 deletions hail/hail/ir-gen/src/Main.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import mainargs.{ParserForMethods, main}

sealed abstract class Trait(val name: String)

object Trivial extends Trait("TrivialIR")
object BaseRef extends Trait("BaseRef")

case class NChildren(static: Int = 0, dynamic: String = "") {
def +(other: NChildren): NChildren = NChildren(
static = static + other.static,
dynamic = if (dynamic.isEmpty) other.dynamic else s"$dynamic + ${other.dynamic}",
)
}

sealed abstract class AttOrChild {
val name: String
def generateDeclaration: String
def constraints: Seq[String] = Seq.empty
def nChildren: NChildren = NChildren()
}

final case class Att(name: String, typ: String, isVar: Boolean = false) extends AttOrChild {
override def generateDeclaration: String = s"${if (isVar) "var " else ""}$name: $typ"
}

final case class Child(name: String) extends AttOrChild {
override def generateDeclaration: String = s"$name: IR"
override def nChildren: NChildren = NChildren(static = 1)
}

final case class ChildPlus(name: String) extends AttOrChild {
override def generateDeclaration: String = s"$name: IndexedSeq[IR]"
override def constraints: Seq[String] = Seq(s"$name.nonEmpty")
override def nChildren: NChildren = NChildren(dynamic = "name.size")
}

final case class ChildStar(name: String) extends AttOrChild {
override def generateDeclaration: String = s"$name: IndexedSeq[IR]"
override def nChildren: NChildren = NChildren(dynamic = "name.size")
}

case class IR(
name: String,
attsAndChildren: Seq[AttOrChild],
traits: Seq[Trait] = Seq.empty,
extraMethods: Seq[String] = Seq.empty,
applyMethods: Seq[String] = Seq.empty,
docstring: String = "",
) {
def withTraits(newTraits: Trait*): IR = copy(traits = traits ++ newTraits)
def withMethod(methodDef: String): IR = copy(extraMethods = extraMethods :+ methodDef)
def withApply(methodDef: String): IR = copy(applyMethods = applyMethods :+ methodDef)
def withDocstring(docstring: String): IR = copy(docstring = docstring)

private def nChildren: NChildren = attsAndChildren.foldLeft(NChildren())(_ + _.nChildren)

private def children: String = {
val tmp = attsAndChildren.flatMap {
case _: Att => None
case c: Child => Some(s"FastSeq(${c.name})")
case cs: ChildPlus => Some(cs.name)
case cs: ChildStar => Some(cs.name)
}
if (tmp.isEmpty) "FastSeq.empty" else tmp.mkString(" ++ ")
}

private def paramList = s"$name(${attsAndChildren.map(_.generateDeclaration).mkString(", ")})"

private def classDecl =
s"final case class $paramList extends IR" + traits.map(" with " + _.name).mkString

private def classBody = {
val extraMethods =
this.extraMethods :+ s"override lazy val childrenSeq: IndexedSeq[IR] = $children"
val constraints = attsAndChildren.flatMap(_.constraints)
if (constraints.nonEmpty || extraMethods.nonEmpty) {
(
" {" +
(if (constraints.nonEmpty)
constraints.map(c => s" require($c)").mkString("\n", "\n", "\n")
else "")
+ (
if (extraMethods.nonEmpty)
extraMethods.map(" " + _).mkString("\n", "\n", "\n")
else ""
)
+ "}"
)
} else ""
}

private def classDef =
(if (docstring.nonEmpty) s"\n// $docstring\n" else "") + classDecl + classBody

private def companionBody = applyMethods.map(" " + _).mkString("\n")

private def companionDef =
if (companionBody.isEmpty) "" else s"object $name {\n$companionBody\n}\n"

def generateDef: String = companionDef + classDef + "\n"
}

object Main {
def node(name: String, attsAndChildren: AttOrChild*): IR = IR(name, attsAndChildren)

def allNodes: Seq[IR] = {
val r = Seq.newBuilder[IR]

r += node("I32", Att("x", "Int")).withTraits(Trivial)
r += node("I64", Att("x", "Long")).withTraits(Trivial)
r += node("F32", Att("x", "Float")).withTraits(Trivial)
r += node("F64", Att("x", "Double")).withTraits(Trivial)
r += node("Str", Att("x", "String")).withTraits(Trivial)
.withMethod(
"override def toString(): String = s\"\"\"Str(\"${StringEscapeUtils.escapeString(x)}\")\"\"\""
)
r += node("True").withTraits(Trivial)
r += node("False").withTraits(Trivial)
r += node("Void").withTraits(Trivial)
r += node("NA", Att("_typ", "Type")).withTraits(Trivial)
r += node("UUID4", Att("id", "String"))
.withDocstring(
"WARNING! This node can only be used when trying to append a one-off, "
+ "random string that will not be reused elsewhere in the pipeline. "
+ "Any other uses will need to write and then read again; this node is non-deterministic "
+ "and will not e.g. exhibit the correct semantics when self-joining on streams."
)
.withApply("def apply(): UUID4 = UUID4(genUID())")
r += node("Cast", Child("v"), Att("_typ", "Type"))
r += node("CastRename", Child("v"), Att("_typ", "Type"))
r += node("IsNA", Child("value"))
r += node("Coalesce", ChildPlus("values"))
r += node("Consume", Child("value"))
r += node("If", Child("cond"), Child("cnsq"), Child("altr"))
r += node("Switch", Child("x"), Child("default"), ChildStar("cases"))
.withMethod("override lazy val size: Int = 2 + cases.length")
r += IR("Ref", Seq(Att("name", "Name"), Att("_typ", "Type", isVar = true)))
.withTraits(BaseRef)

r.result()
}

@main
def main(path: String) = {
val pack = "package is.hail.expr.ir"
val imports = Seq("is.hail.types.virtual.Type", "is.hail.utils.{FastSeq, StringEscapeUtils}")
val gen = pack + "\n\n" + imports.map(i => s"import $i").mkString("\n") + "\n\n" + allNodes.map(
_.generateDef
).mkString("\n")
os.write(os.Path(path) / "IR_gen.scala", gen)
}

def main(args: Array[String]): Unit = ParserForMethods(this).runOrExit(args)
}
88 changes: 42 additions & 46 deletions hail/hail/src/is/hail/expr/ir/IR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,22 @@ import java.io.OutputStream
import org.json4s.{DefaultFormats, Extraction, Formats, JValue, ShortTypeHints}
import org.json4s.JsonAST.{JNothing, JString}

sealed trait IR extends BaseIR {
trait IR extends BaseIR {
private var _typ: Type = null

def typ: Type = {
if (_typ == null)
override def typ: Type = {
if (_typ == null) {
try
_typ = InferType(this)
catch {
case e: Throwable => throw new RuntimeException(s"typ: inference failure:", e)
}
assert(_typ != null)
}
_typ
}

protected lazy val childrenSeq: IndexedSeq[BaseIR] =
override protected lazy val childrenSeq: IndexedSeq[BaseIR] =
Children(this)

override protected def copyWithNewChildren(newChildren: IndexedSeq[BaseIR]): IR =
Expand Down Expand Up @@ -72,12 +74,12 @@ sealed trait IR extends BaseIR {
def unwrap: IR = _unwrap(this)
}

sealed trait TypedIR[T <: Type] extends IR {
trait TypedIR[T <: Type] extends IR {
override def typ: T = tcoerce[T](super.typ)
}

// Mark Refs and constants as IRs that are safe to duplicate
sealed trait TrivialIR extends IR
trait TrivialIR extends IR

object Literal {
def coerce(t: Type, x: Any): IR = {
Expand Down Expand Up @@ -146,48 +148,47 @@ class WrappedByteArrays(val ba: Array[Array[Byte]]) {
}
}

final case class I32(x: Int) extends IR with TrivialIR
final case class I64(x: Long) extends IR with TrivialIR
final case class F32(x: Float) extends IR with TrivialIR
final case class F64(x: Double) extends IR with TrivialIR
//final case class I32(x: Int) extends IR with TrivialIR
//final case class I64(x: Long) extends IR with TrivialIR
//final case class F32(x: Float) extends IR with TrivialIR
//final case class F64(x: Double) extends IR with TrivialIR

final case class Str(x: String) extends IR with TrivialIR {
override def toString(): String = s"""Str("${StringEscapeUtils.escapeString(x)}")"""
}
//final case class Str(x: String) extends IR with TrivialIR {
// override def toString(): String = s"""Str("${StringEscapeUtils.escapeString(x)}")"""
//}

final case class True() extends IR with TrivialIR
final case class False() extends IR with TrivialIR
final case class Void() extends IR with TrivialIR
//final case class True() extends IR with TrivialIR
//final case class False() extends IR with TrivialIR
//final case class Void() extends IR with TrivialIR

object UUID4 {
def apply(): UUID4 = UUID4(genUID())
}
//object UUID4 {
// def apply(): UUID4 = UUID4(genUID())
//}
//
//// WARNING! This node can only be used when trying to append a one-off,
//// random string that will not be reused elsewhere in the pipeline.
//// Any other uses will need to write and then read again; this node is
//// non-deterministic and will not e.g. exhibit the correct semantics when
//// self-joining on streams.
//final case class UUID4(id: String) extends IR

// WARNING! This node can only be used when trying to append a one-off,
// random string that will not be reused elsewhere in the pipeline.
// Any other uses will need to write and then read again; this node is
// non-deterministic and will not e.g. exhibit the correct semantics when
// self-joining on streams.
final case class UUID4(id: String) extends IR
//final case class Cast(v: IR, _typ: Type) extends IR
//final case class CastRename(v: IR, _typ: Type) extends IR

final case class Cast(v: IR, _typ: Type) extends IR
final case class CastRename(v: IR, _typ: Type) extends IR
//final case class NA(_typ: Type) extends IR with TrivialIR
//final case class IsNA(value: IR) extends IR

final case class NA(_typ: Type) extends IR with TrivialIR
final case class IsNA(value: IR) extends IR
//final case class Coalesce(values: Seq[IR]) extends IR {
// require(values.nonEmpty)
//}

final case class Coalesce(values: Seq[IR]) extends IR {
require(values.nonEmpty)
}
//final case class Consume(value: IR) extends IR

final case class Consume(value: IR) extends IR
//final case class If(cond: IR, cnsq: IR, altr: IR) extends IR

final case class If(cond: IR, cnsq: IR, altr: IR) extends IR

final case class Switch(x: IR, default: IR, cases: IndexedSeq[IR]) extends IR {
override lazy val size: Int =
2 + cases.length
}
//final case class Switch(x: IR, default: IR, cases: IndexedSeq[IR]) extends IR {
// override lazy val size: Int = 2 + cases.length
//}

object AggLet {
def apply(name: Name, value: IR, body: IR, isScan: Boolean): IR = {
Expand Down Expand Up @@ -238,17 +239,12 @@ object Block {
}
}

sealed abstract class BaseRef extends IR with TrivialIR {
trait BaseRef extends IR with TrivialIR {
def name: Name
def _typ: Type
}

final case class Ref(name: Name, var _typ: Type) extends BaseRef {
override def typ: Type = {
assert(_typ != null)
_typ
}
}
//final case class Ref(name: Name, var _typ: Type) extends BaseRef

// Recur can't exist outside of loop
// Loops can be nested, but we can't call outer loops in terms of inner loops so there can only be one loop "active" in a given context
Expand Down

0 comments on commit 0fc810b

Please sign in to comment.