diff --git a/hail/build.mill b/hail/build.mill index a0c8cb1750f8..f047defbf22e 100644 --- a/hail/build.mill +++ b/hail/build.mill @@ -107,7 +107,7 @@ trait HailScalaModule extends ScalafmtModule with ScalafixModule { outer => override def bspCompileClasspath: T[Agg[UnresolvedPath]] = super.bspCompileClasspath() ++ resources().map(p => UnresolvedPath.ResolvedPath(p.path)) - trait HailTests extends ScalafmtModule with ScalafixModule with TestNg { + trait HailTests extends ScalaTests with TestNg with ScalafmtModule with ScalafixModule { override def forkArgs: T[Seq[String]] = Seq("-Xss4m", "-Xmx4096M") override def ivyDeps: T[Agg[Dep]] = @@ -179,6 +179,10 @@ object `package` extends RootModule with SbtModule with HailScalaModule { outer buildInfo(), ) + override def generatedSources: T[Seq[PathRef]] = Task { + Seq(`ir-gen`.generate()) + } + override def unmanagedClasspath: T[Agg[PathRef]] = Agg(shadedazure.assembly()) @@ -250,6 +254,19 @@ object `package` extends RootModule with SbtModule with HailScalaModule { outer PathRef(T.dest) } + object `ir-gen` extends HailScalaModule { + def ivyDeps = Agg( + ivy"com.lihaoyi::mainargs:0.6.2", + ivy"com.lihaoyi::os-lib:0.10.7", + ivy"com.lihaoyi::sourcecode:0.4.2", + ) + + def generate: T[PathRef] = Task { + runForkedTask(finalMainClass, Task.Anon { Args("--path", T.dest) })() + PathRef(T.dest) + } + } + object memory extends JavaModule { // with CrossValue { override def zincIncrementalCompilation: T[Boolean] = false diff --git a/hail/modules/ir-gen/src/Main.scala b/hail/modules/ir-gen/src/Main.scala new file mode 100644 index 000000000000..a0ec0f4e581c --- /dev/null +++ b/hail/modules/ir-gen/src/Main.scala @@ -0,0 +1,65 @@ +import mainargs.{ParserForMethods, main} + +sealed abstract class Trait(val name: String) + +object Trivial extends Trait("TrivialIR") + +sealed abstract class AttOrChild { + val name: String + def generateDeclaration: String +} + +final case class Att(name: String, typ: String) extends AttOrChild { + override def generateDeclaration: String = s"$name: $typ" +} + +final case class Child(name: String) extends AttOrChild { + override def generateDeclaration: String = s"$name: IR" +} + +case class IR( + name: String, + attsAndChildren: Seq[AttOrChild], + traits: Seq[Trait] = Seq.empty, + extraMethods: Seq[String] = Seq.empty, +) { + def withTraits(newTraits: Trait*): IR = copy(traits = traits ++ newTraits) + def withMethod(methodDef: String): IR = copy(extraMethods = extraMethods :+ methodDef) + + def generateDef: String = + (s"final case class $name(${attsAndChildren.map(_.generateDeclaration).mkString(",")}) extends IR" + + traits.map(" with " + _.name).mkString + + (if (extraMethods.nonEmpty) + extraMethods.map(" " + _).mkString(" {\n", "\n", "\n}") + else "")) +} + +object Main { + def node(name: String, attsAndChildren: AttOrChild*): IR = IR(name, attsAndChildren) + + def allNodes: Seq[IR] = { + val r = Seq.newBuilder[IR] + + r += node("I32", Att("x", "Int")).withTraits(Trivial) + r += node("I64", Att("x", "Long")).withTraits(Trivial) + r += node("F32", Att("x", "Float")).withTraits(Trivial) + r += node("F64", Att("x", "Double")).withTraits(Trivial) + r += node("Str", Att("x", "String")).withTraits(Trivial) + .withMethod( + "override def toString(): String = s\"\"\"Str(\"${StringEscapeUtils.escapeString(x)}\")\"\"\"" + ) + r += node("True").withTraits(Trivial) + r += node("False").withTraits(Trivial) + r += node("Void").withTraits(Trivial) + + r.result() + } + + @main + def main(path: String) = { + val gen = "package is.hail.expr.ir\n\n" + allNodes.map(_.generateDef).mkString("\n") + os.write(os.Path(path) / "IR_gen.scala", gen) + } + + def main(args: Array[String]): Unit = ParserForMethods(this).runOrExit(args) +} diff --git a/hail/modules/src/main/scala/is/hail/expr/ir/IR.scala b/hail/modules/src/main/scala/is/hail/expr/ir/IR.scala index 5f432ab2dd5c..fbec49f7dade 100644 --- a/hail/modules/src/main/scala/is/hail/expr/ir/IR.scala +++ b/hail/modules/src/main/scala/is/hail/expr/ir/IR.scala @@ -27,7 +27,7 @@ import java.io.OutputStream import org.json4s.{DefaultFormats, Extraction, Formats, JValue, ShortTypeHints} import org.json4s.JsonAST.{JNothing, JString} -sealed trait IR extends BaseIR { +trait IR extends BaseIR { private var _typ: Type = null def typ: Type = { @@ -72,12 +72,12 @@ sealed trait IR extends BaseIR { def unwrap: IR = _unwrap(this) } -sealed trait TypedIR[T <: Type] extends IR { +trait TypedIR[T <: Type] extends IR { override def typ: T = tcoerce[T](super.typ) } // Mark Refs and constants as IRs that are safe to duplicate -sealed trait TrivialIR extends IR +trait TrivialIR extends IR object Literal { def coerce(t: Type, x: Any): IR = { @@ -146,18 +146,18 @@ class WrappedByteArrays(val ba: Array[Array[Byte]]) { } } -final case class I32(x: Int) extends IR with TrivialIR -final case class I64(x: Long) extends IR with TrivialIR -final case class F32(x: Float) extends IR with TrivialIR -final case class F64(x: Double) extends IR with TrivialIR +//final case class I32(x: Int) extends IR with TrivialIR +//final case class I64(x: Long) extends IR with TrivialIR +//final case class F32(x: Float) extends IR with TrivialIR +//final case class F64(x: Double) extends IR with TrivialIR -final case class Str(x: String) extends IR with TrivialIR { - override def toString(): String = s"""Str("${StringEscapeUtils.escapeString(x)}")""" -} +//final case class Str(x: String) extends IR with TrivialIR { +// override def toString(): String = s"""Str("${StringEscapeUtils.escapeString(x)}")""" +//} -final case class True() extends IR with TrivialIR -final case class False() extends IR with TrivialIR -final case class Void() extends IR with TrivialIR +//final case class True() extends IR with TrivialIR +//final case class False() extends IR with TrivialIR +//final case class Void() extends IR with TrivialIR object UUID4 { def apply(): UUID4 = UUID4(genUID())