Skip to content

Commit

Permalink
ir gen mvp
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-schultz committed Dec 17, 2024
1 parent 2e1deb3 commit 2558f9a
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 14 deletions.
19 changes: 18 additions & 1 deletion hail/build.mill
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ trait HailScalaModule extends ScalafmtModule with ScalafixModule { outer =>
override def bspCompileClasspath: T[Agg[UnresolvedPath]] =
super.bspCompileClasspath() ++ resources().map(p => UnresolvedPath.ResolvedPath(p.path))

trait HailTests extends ScalafmtModule with ScalafixModule with TestNg {
trait HailTests extends ScalaTests with TestNg with ScalafmtModule with ScalafixModule {
override def forkArgs: T[Seq[String]] = Seq("-Xss4m", "-Xmx4096M")

override def ivyDeps: T[Agg[Dep]] =
Expand Down Expand Up @@ -179,6 +179,10 @@ object `package` extends RootModule with SbtModule with HailScalaModule { outer
buildInfo(),
)

override def generatedSources: T[Seq[PathRef]] = Task {
Seq(`ir-gen`.generate())
}

override def unmanagedClasspath: T[Agg[PathRef]] =
Agg(shadedazure.assembly())

Expand Down Expand Up @@ -250,6 +254,19 @@ object `package` extends RootModule with SbtModule with HailScalaModule { outer
PathRef(T.dest)
}

object `ir-gen` extends HailScalaModule {
def ivyDeps = Agg(
ivy"com.lihaoyi::mainargs:0.6.2",
ivy"com.lihaoyi::os-lib:0.10.7",
ivy"com.lihaoyi::sourcecode:0.4.2",
)

def generate: T[PathRef] = Task {
runForkedTask(finalMainClass, Task.Anon { Args("--path", T.dest) })()
PathRef(T.dest)
}
}

object memory extends JavaModule { // with CrossValue {
override def zincIncrementalCompilation: T[Boolean] = false

Expand Down
65 changes: 65 additions & 0 deletions hail/modules/ir-gen/src/Main.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import mainargs.{ParserForMethods, main}

sealed abstract class Trait(val name: String)

object Trivial extends Trait("TrivialIR")

sealed abstract class AttOrChild {
val name: String
def generateDeclaration: String
}

final case class Att(name: String, typ: String) extends AttOrChild {
override def generateDeclaration: String = s"$name: $typ"
}

final case class Child(name: String) extends AttOrChild {
override def generateDeclaration: String = s"$name: IR"
}

case class IR(
name: String,
attsAndChildren: Seq[AttOrChild],
traits: Seq[Trait] = Seq.empty,
extraMethods: Seq[String] = Seq.empty,
) {
def withTraits(newTraits: Trait*): IR = copy(traits = traits ++ newTraits)
def withMethod(methodDef: String): IR = copy(extraMethods = extraMethods :+ methodDef)

def generateDef: String =
(s"final case class $name(${attsAndChildren.map(_.generateDeclaration).mkString(",")}) extends IR"
+ traits.map(" with " + _.name).mkString
+ (if (extraMethods.nonEmpty)
extraMethods.map(" " + _).mkString(" {\n", "\n", "\n}")
else ""))
}

object Main {
def node(name: String, attsAndChildren: AttOrChild*): IR = IR(name, attsAndChildren)

def allNodes: Seq[IR] = {
val r = Seq.newBuilder[IR]

r += node("I32", Att("x", "Int")).withTraits(Trivial)
r += node("I64", Att("x", "Long")).withTraits(Trivial)
r += node("F32", Att("x", "Float")).withTraits(Trivial)
r += node("F64", Att("x", "Double")).withTraits(Trivial)
r += node("Str", Att("x", "String")).withTraits(Trivial)
.withMethod(
"override def toString(): String = s\"\"\"Str(\"${StringEscapeUtils.escapeString(x)}\")\"\"\""
)
r += node("True").withTraits(Trivial)
r += node("False").withTraits(Trivial)
r += node("Void").withTraits(Trivial)

r.result()
}

@main
def main(path: String) = {
val gen = "package is.hail.expr.ir\n\n" + allNodes.map(_.generateDef).mkString("\n")
os.write(os.Path(path) / "IR_gen.scala", gen)
}

def main(args: Array[String]): Unit = ParserForMethods(this).runOrExit(args)
}
26 changes: 13 additions & 13 deletions hail/modules/src/main/scala/is/hail/expr/ir/IR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import java.io.OutputStream
import org.json4s.{DefaultFormats, Extraction, Formats, JValue, ShortTypeHints}
import org.json4s.JsonAST.{JNothing, JString}

sealed trait IR extends BaseIR {
trait IR extends BaseIR {
private var _typ: Type = null

def typ: Type = {
Expand Down Expand Up @@ -72,12 +72,12 @@ sealed trait IR extends BaseIR {
def unwrap: IR = _unwrap(this)
}

sealed trait TypedIR[T <: Type] extends IR {
trait TypedIR[T <: Type] extends IR {
override def typ: T = tcoerce[T](super.typ)
}

// Mark Refs and constants as IRs that are safe to duplicate
sealed trait TrivialIR extends IR
trait TrivialIR extends IR

object Literal {
def coerce(t: Type, x: Any): IR = {
Expand Down Expand Up @@ -146,18 +146,18 @@ class WrappedByteArrays(val ba: Array[Array[Byte]]) {
}
}

final case class I32(x: Int) extends IR with TrivialIR
final case class I64(x: Long) extends IR with TrivialIR
final case class F32(x: Float) extends IR with TrivialIR
final case class F64(x: Double) extends IR with TrivialIR
//final case class I32(x: Int) extends IR with TrivialIR
//final case class I64(x: Long) extends IR with TrivialIR
//final case class F32(x: Float) extends IR with TrivialIR
//final case class F64(x: Double) extends IR with TrivialIR

final case class Str(x: String) extends IR with TrivialIR {
override def toString(): String = s"""Str("${StringEscapeUtils.escapeString(x)}")"""
}
//final case class Str(x: String) extends IR with TrivialIR {
// override def toString(): String = s"""Str("${StringEscapeUtils.escapeString(x)}")"""
//}

final case class True() extends IR with TrivialIR
final case class False() extends IR with TrivialIR
final case class Void() extends IR with TrivialIR
//final case class True() extends IR with TrivialIR
//final case class False() extends IR with TrivialIR
//final case class Void() extends IR with TrivialIR

object UUID4 {
def apply(): UUID4 = UUID4(genUID())
Expand Down

0 comments on commit 2558f9a

Please sign in to comment.