Skip to content

Commit

Permalink
[query] unify backend rpc
Browse files Browse the repository at this point in the history
  • Loading branch information
ehigham committed Sep 18, 2024
1 parent d86b932 commit 99e06b5
Show file tree
Hide file tree
Showing 29 changed files with 711 additions and 658 deletions.
3 changes: 0 additions & 3 deletions hail/python/hail/backend/py4j_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,9 +276,6 @@ def add_liftover(self, name, chain_file, dest_reference_genome):
# Forwards `name` and `dest_reference_genome` unchanged to the JVM backend's `pyRemoveLiftover`.
def remove_liftover(self, name, dest_reference_genome):
self._jbackend.pyRemoveLiftover(name, dest_reference_genome)

# Forwards all five arguments, in order, to the JVM backend's `pyIndexBgen`.
# NOTE(review): the header for this file reports 0 additions and 3 deletions, so this method
# appears to be the code removed by the commit — confirm against the raw diff.
def index_bgen(self, files, index_file_map, referenceGenomeName, contig_recoding, skip_invalid_loci):
self._jbackend.pyIndexBgen(files, index_file_map, referenceGenomeName, contig_recoding, skip_invalid_loci)

def _parse_value_ir(self, code, ref_map={}):
return self._jbackend.parse_value_ir(
code,
Expand Down
9 changes: 4 additions & 5 deletions hail/src/main/scala/is/hail/HailContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import org.apache.spark.executor.InputMetrics
import org.apache.spark.rdd.RDD
import org.json4s.Extraction
import org.json4s.jackson.JsonMethods
import sourcecode.Enclosing

/** A `Partition` that pairs a partition `index` with a `file` string.
  * NOTE(review): `Partition` is presumably `org.apache.spark.Partition`; its import is not visible here.
  */
case class FilePartition(index: Int, file: String) extends Partition

Expand All @@ -41,7 +42,7 @@ object HailContext {

/** Shorthand for `get.backend`. */
def backend: Backend = get.backend

// NOTE(review): the next two definitions look like the before/after pair of one diff hunk
// (the +/- markers did not survive extraction) — confirm against the raw diff.
// Variant taking an explicit operation name `op`, forwarded to `get.sparkBackend(op)`.
def sparkBackend(op: String): SparkBackend = get.sparkBackend(op)
// Variant taking an implicit `sourcecode.Enclosing` instead of `op`; delegates to `get.backend.asSpark`.
def sparkBackend(implicit E: Enclosing): SparkBackend = get.backend.asSpark

def configureLogging(logFile: String, quiet: Boolean, append: Boolean): Unit = {
org.apache.log4j.helpers.LogLog.setInternalDebugging(true)
Expand Down Expand Up @@ -152,7 +153,7 @@ object HailContext {

val fsBc = fs.broadcast

new RDD[T](SparkBackend.sparkContext("readPartition"), Nil) {
new RDD[T](SparkBackend.sparkContext, Nil) {
def getPartitions: Array[Partition] =
Array.tabulate(nPartitions)(i => FilePartition(i, partFiles(i)))

Expand All @@ -175,8 +176,6 @@ class HailContext private (
) {
/** Instance-level convenience that delegates to the companion's `HailContext.stop()`. */
def stop(): Unit = HailContext.stop()

// Forwards the operation name `op` to `backend.asSpark(op)`.
// NOTE(review): the hunk header shows this region shrinking from 8 to 6 lines, so this member
// appears to be removed by the commit — confirm against the raw diff.
def sparkBackend(op: String): SparkBackend = backend.asSpark(op)

// Mutable flag, initially false. Its readers and writers are not visible in this excerpt.
var checkRVDKeys: Boolean = false

/** Returns the constant `is.hail.HAIL_PRETTY_VERSION`. */
def version: String = is.hail.HAIL_PRETTY_VERSION
Expand All @@ -188,7 +187,7 @@ class HailContext private (
maxLines: Int,
): Map[String, Array[WithContext[String]]] = {
val regexp = regex.r
SparkBackend.sparkContext("fileAndLineCounts").textFilesLines(fs.globAll(files).map(_.getPath))
SparkBackend.sparkContext.textFilesLines(fs.globAll(files).map(_.getPath))
.filter(line => regexp.findFirstIn(line.value).isDefined)
.take(maxLines)
.groupBy(_.source.file)
Expand Down
94 changes: 13 additions & 81 deletions hail/src/main/scala/is/hail/backend/Backend.scala
Original file line number Diff line number Diff line change
@@ -1,28 +1,24 @@
package is.hail.backend

import is.hail.asm4s._
import is.hail.backend.Backend.jsonToBytes
import is.hail.asm4s.HailClassLoader
import is.hail.backend.spark.SparkBackend
import is.hail.expr.ir.{IR, IRParser, LoweringAnalyses, SortField, TableIR, TableReader}
import is.hail.expr.ir.{IR, LoweringAnalyses, SortField, TableIR, TableReader}
import is.hail.expr.ir.lowering.{TableStage, TableStageDependency}
import is.hail.io.{BufferSpec, TypedCodecSpec}
import is.hail.io.fs._
import is.hail.io.plink.LoadPlink
import is.hail.io.vcf.LoadVCF
import is.hail.types._
import is.hail.io.fs.FS
import is.hail.types.RTable
import is.hail.types.encoded.EType
import is.hail.types.physical.PTuple
import is.hail.types.virtual.TFloat64
import is.hail.utils._
import is.hail.variant.ReferenceGenome
import is.hail.utils.ExecutionTimer.Timings
import is.hail.utils.fatal

import scala.reflect.ClassTag

import java.io._
import java.io.{Closeable, OutputStream}
import java.nio.charset.StandardCharsets

import com.fasterxml.jackson.core.StreamReadConstraints
import org.json4s._
import org.json4s.JValue
import org.json4s.jackson.JsonMethods
import sourcecode.Enclosing

Expand All @@ -39,16 +35,15 @@ object Backend {
ctx: ExecuteContext,
t: PTuple,
off: Long,
bufferSpecString: String,
bufferSpec: BufferSpec,
os: OutputStream,
): Unit = {
val bs = BufferSpec.parseOrDefault(bufferSpecString)
assert(t.size == 1)
val elementType = t.fields(0).typ
val codec = TypedCodecSpec(
EType.fromPythonTypeEncoding(elementType.virtualType),
elementType.virtualType,
bs,
bufferSpec,
)
assert(t.isFieldDefined(off, 0))
codec.encode(ctx, elementType, t.loadField(off, 0), os)
Expand Down Expand Up @@ -96,8 +91,8 @@ abstract class Backend extends Closeable {

/** Abstract; concrete backends supply the implementation. The enclosing class extends `Closeable`. */
def close(): Unit

// NOTE(review): the two definitions below look like the before/after pair of one diff hunk
// (the +/- markers did not survive extraction) — confirm against the raw diff.
// Variant taking an explicit operation name: calls `fatal` with a message naming this
// class and `op`. NOTE(review): presumably overridden by SparkBackend — not visible here.
def asSpark(op: String): SparkBackend =
fatal(s"${getClass.getSimpleName}: $op requires SparkBackend")
// Variant taking an implicit `sourcecode.Enclosing`: same message, with `E.value` in place of `op`.
def asSpark(implicit E: Enclosing): SparkBackend =
fatal(s"${getClass.getSimpleName}: ${E.value} requires SparkBackend")

/** Defaults to `true`. NOTE(review): subclasses may override; none are visible in this excerpt. */
def shouldCacheQueryInfo: Boolean = true

Expand Down Expand Up @@ -132,70 +127,7 @@ abstract class Backend extends Closeable {
/** Abstract: given an execution context, a `TableIR` and `LoweringAnalyses`, yields a `TableStage`. */
def tableToTableStage(ctx: ExecuteContext, inputIR: TableIR, analyses: LoweringAnalyses)
: TableStage

// Abstract: takes a function of an `ExecuteContext` plus an implicit `sourcecode.Enclosing`;
// this variant's result type is the bare `T`.
def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T

/** Parses `s` with `IRParser.parse_value_ir` inside `withExecuteContext`, then passes the
  * JSON form of the resulting IR's type (`typ.toJSON`) through `jsonToBytes`.
  */
final def valueType(s: String): Array[Byte] =
withExecuteContext { ctx =>
jsonToBytes {
IRParser.parse_value_ir(ctx, s).typ.toJSON
}
}

/** Parses `s` with `IRParser.parse_table_ir` inside `withExecuteContext`, then passes the
  * JSON form of the resulting IR's type (`typ.toJSON`) through `jsonToBytes`.
  */
final def tableType(s: String): Array[Byte] =
withExecuteContext { ctx =>
jsonToBytes {
IRParser.parse_table_ir(ctx, s).typ.toJSON
}
}

/** Parses `s` with `IRParser.parse_matrix_ir` inside `withExecuteContext`, then passes the
  * JSON form of the resulting IR's type (`typ.toJSON`) through `jsonToBytes`.
  */
final def matrixTableType(s: String): Array[Byte] =
withExecuteContext { ctx =>
jsonToBytes {
IRParser.parse_matrix_ir(ctx, s).typ.toJSON
}
}

/** Parses `s` with `IRParser.parse_blockmatrix_ir` inside `withExecuteContext`, then passes
  * the JSON form of the resulting IR's type (`typ.toJSON`) through `jsonToBytes`.
  */
final def blockMatrixType(s: String): Array[Byte] =
withExecuteContext { ctx =>
jsonToBytes {
IRParser.parse_blockmatrix_ir(ctx, s).typ.toJSON
}
}

// Abstract: takes a dataset `path` and returns bytes.
// NOTE(review): presumably JSON-encoded like the sibling methods — confirm in implementations.
def loadReferencesFromDataset(path: String): Array[Byte]

/** Calls `ReferenceGenome.fromFASTAFile` inside `withExecuteContext`, passing the context
  * followed by every argument in declaration order, and passes the result's `toJSON`
  * through `jsonToBytes`.
  */
def fromFASTAFile(
name: String,
fastaFile: String,
indexFile: String,
xContigs: Array[String],
yContigs: Array[String],
mtContigs: Array[String],
parInput: Array[String],
): Array[Byte] =
withExecuteContext { ctx =>
jsonToBytes {
ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile,
xContigs, yContigs, mtContigs, parInput).toJSON
}
}

/** Reads header metadata for `path` via `LoadVCF.parseHeaderMetadata` (using the context's
  * file system, an empty set and `TFloat64`), converts it to JSON with json4s
  * `Extraction.decompose`, and passes that through `jsonToBytes`.
  */
def parseVCFMetadata(path: String): Array[Byte] =
withExecuteContext { ctx =>
jsonToBytes {
val metadata = LoadVCF.parseHeaderMetadata(ctx.fs, Set.empty, TFloat64, path)
// `Extraction.decompose` needs implicit json4s formats in scope.
implicit val formats = defaultJSONFormats
Extraction.decompose(metadata)
}
}

/** Calls `LoadPlink.importFamJSON` with the context's file system and the given arguments,
  * then encodes the returned string as UTF-8 bytes.
  */
def importFam(path: String, isQuantPheno: Boolean, delimiter: String, missingValue: String)
: Array[Byte] =
withExecuteContext { ctx =>
LoadPlink.importFamJSON(ctx.fs, path, isQuantPheno, delimiter, missingValue).getBytes(
StandardCharsets.UTF_8
)
}
// Abstract: takes a function of an `ExecuteContext` plus an implicit `sourcecode.Enclosing`;
// this variant returns the pair `(T, Timings)`.
// NOTE(review): presumably the `Timings` describe the run of `f` — confirm in implementations.
def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): (T, Timings)

// Abstract: evaluates `ir` under `ctx`, yielding either `Unit` or a `(PTuple, Long)` pair.
// NOTE(review): the pair looks like a result tuple type plus an offset (cf. the `t: PTuple`,
// `off: Long` parameters in `object Backend`) — confirm against implementations.
def execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)]
}
Loading

0 comments on commit 99e06b5

Please sign in to comment.