[query] unify backend rpc #14693

Status: Open
Wants to merge 1 commit into base: ehigham/move-gen-into-tests
9 changes: 4 additions & 5 deletions hail/hail/src/is/hail/HailContext.scala
@@ -19,6 +19,7 @@ import org.apache.spark.executor.InputMetrics
import org.apache.spark.rdd.RDD
import org.json4s.Extraction
import org.json4s.jackson.JsonMethods
import sourcecode.Enclosing

case class FilePartition(index: Int, file: String) extends Partition

@@ -41,7 +42,7 @@ object HailContext {

def backend: Backend = get.backend

def sparkBackend(op: String): SparkBackend = get.sparkBackend(op)
def sparkBackend(implicit E: Enclosing): SparkBackend = get.backend.asSpark

def configureLogging(logFile: String, quiet: Boolean, append: Boolean): Unit = {
org.apache.log4j.helpers.LogLog.setInternalDebugging(true)
@@ -152,7 +153,7 @@ object HailContext {

val fsBc = fs.broadcast

new RDD[T](SparkBackend.sparkContext("readPartition"), Nil) {
new RDD[T](SparkBackend.sparkContext, Nil) {
def getPartitions: Array[Partition] =
Array.tabulate(nPartitions)(i => FilePartition(i, partFiles(i)))

@@ -175,8 +176,6 @@ class HailContext private (
) {
def stop(): Unit = HailContext.stop()

def sparkBackend(op: String): SparkBackend = backend.asSpark(op)

var checkRVDKeys: Boolean = false

def version: String = is.hail.HAIL_PRETTY_VERSION
@@ -188,7 +187,7 @@
maxLines: Int,
): Map[String, Array[WithContext[String]]] = {
val regexp = regex.r
SparkBackend.sparkContext("fileAndLineCounts").textFilesLines(fs.globAll(files).map(_.getPath))
SparkBackend.sparkContext.textFilesLines(fs.globAll(files).map(_.getPath))
.filter(line => regexp.findFirstIn(line.value).isDefined)
.take(maxLines)
.groupBy(_.source.file)
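Note on the HailContext.scala changes above: the explicit `op: String` argument that callers previously passed to `sparkBackend` and `SparkBackend.sparkContext` is replaced by an implicit `sourcecode.Enclosing`, which captures the name of the calling definition at compile time. Below is a minimal sketch of that mechanism; the names `EnclosingSketch`, `requireSpark`, and `readPartition` are hypothetical and not part of this PR.

```scala
// Minimal sketch (not from this PR): how sourcecode.Enclosing replaces an
// explicit `op: String` argument. All names here are hypothetical.
import sourcecode.Enclosing

object EnclosingSketch {
  // Before: callers named the operation themselves, e.g. requireSpark("readPartition").
  // After: the implicit Enclosing is materialized at the call site, and E.value
  // holds the fully qualified name of the enclosing definition.
  def requireSpark(implicit E: Enclosing): Nothing =
    throw new UnsupportedOperationException(s"${E.value} requires SparkBackend")

  def readPartition(): Unit =
    requireSpark // E.value is roughly "EnclosingSketch.readPartition"
}
```

This mirrors how `HailContext.sparkBackend` and `Backend.asSpark` now report which caller required a Spark backend without threading an operation-name string through every call.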
98 changes: 12 additions & 86 deletions hail/hail/src/is/hail/backend/Backend.scala
@@ -1,28 +1,21 @@
package is.hail.backend

import is.hail.asm4s._
import is.hail.backend.Backend.jsonToBytes
import is.hail.asm4s.HailClassLoader
import is.hail.backend.spark.SparkBackend
import is.hail.expr.ir.{IR, IRParser, LoweringAnalyses, SortField, TableIR, TableReader}
import is.hail.expr.ir.{IR, LoweringAnalyses, SortField, TableIR, TableReader}
import is.hail.expr.ir.lowering.{TableStage, TableStageDependency}
import is.hail.io.{BufferSpec, TypedCodecSpec}
import is.hail.io.fs._
import is.hail.io.plink.LoadPlink
import is.hail.io.vcf.LoadVCF
import is.hail.types._
import is.hail.io.fs.FS
import is.hail.types.RTable
import is.hail.types.encoded.EType
import is.hail.types.physical.PTuple
import is.hail.types.virtual.TFloat64
import is.hail.utils._
import is.hail.variant.ReferenceGenome
import is.hail.utils.ExecutionTimer.Timings
import is.hail.utils.fatal

import scala.reflect.ClassTag

import java.io._
import java.nio.charset.StandardCharsets
import java.io.{Closeable, OutputStream}

import org.json4s._
import org.json4s.jackson.JsonMethods
import sourcecode.Enclosing

object Backend {
@@ -38,23 +31,19 @@ object Backend {
ctx: ExecuteContext,
t: PTuple,
off: Long,
bufferSpecString: String,
bufferSpec: BufferSpec,
os: OutputStream,
): Unit = {
val bs = BufferSpec.parseOrDefault(bufferSpecString)
assert(t.size == 1)
val elementType = t.fields(0).typ
val codec = TypedCodecSpec(
EType.fromPythonTypeEncoding(elementType.virtualType),
elementType.virtualType,
bs,
bufferSpec,
)
assert(t.isFieldDefined(off, 0))
codec.encode(ctx, elementType, t.loadField(off, 0), os)
}

def jsonToBytes(f: => JValue): Array[Byte] =
JsonMethods.compact(f).getBytes(StandardCharsets.UTF_8)
}

abstract class BroadcastValue[T] { def value: T }
@@ -82,8 +71,8 @@ abstract class Backend extends Closeable {
f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte]
): (Option[Throwable], IndexedSeq[(Array[Byte], Int)])

def asSpark(op: String): SparkBackend =
fatal(s"${getClass.getSimpleName}: $op requires SparkBackend")
def asSpark(implicit E: Enclosing): SparkBackend =
fatal(s"${getClass.getSimpleName}: ${E.value} requires SparkBackend")

def lowerDistributedSort(
ctx: ExecuteContext,
@@ -116,70 +105,7 @@ abstract class Backend extends Closeable {
def tableToTableStage(ctx: ExecuteContext, inputIR: TableIR, analyses: LoweringAnalyses)
: TableStage

def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T

final def valueType(s: String): Array[Byte] =
withExecuteContext { ctx =>
jsonToBytes {
IRParser.parse_value_ir(ctx, s).typ.toJSON
}
}

final def tableType(s: String): Array[Byte] =
withExecuteContext { ctx =>
jsonToBytes {
IRParser.parse_table_ir(ctx, s).typ.toJSON
}
}

final def matrixTableType(s: String): Array[Byte] =
withExecuteContext { ctx =>
jsonToBytes {
IRParser.parse_matrix_ir(ctx, s).typ.toJSON
}
}

final def blockMatrixType(s: String): Array[Byte] =
withExecuteContext { ctx =>
jsonToBytes {
IRParser.parse_blockmatrix_ir(ctx, s).typ.toJSON
}
}

def loadReferencesFromDataset(path: String): Array[Byte]

def fromFASTAFile(
name: String,
fastaFile: String,
indexFile: String,
xContigs: Array[String],
yContigs: Array[String],
mtContigs: Array[String],
parInput: Array[String],
): Array[Byte] =
withExecuteContext { ctx =>
jsonToBytes {
ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile,
xContigs, yContigs, mtContigs, parInput).toJSON
}
}

def parseVCFMetadata(path: String): Array[Byte] =
withExecuteContext { ctx =>
jsonToBytes {
Extraction.decompose {
LoadVCF.parseHeaderMetadata(ctx.fs, Set.empty, TFloat64, path)
}(defaultJSONFormats)
}
}

def importFam(path: String, isQuantPheno: Boolean, delimiter: String, missingValue: String)
: Array[Byte] =
withExecuteContext { ctx =>
LoadPlink.importFamJSON(ctx.fs, path, isQuantPheno, delimiter, missingValue).getBytes(
StandardCharsets.UTF_8
)
}
def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): (T, Timings)

def execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)]
}
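Net effect of the Backend.scala changes: the JSON-returning helpers (`valueType`, `tableType`, `matrixTableType`, `blockMatrixType`, `fromFASTAFile`, `parseVCFMetadata`, `importFam`, `loadReferencesFromDataset`) are removed from `Backend`, `encodeToOutputStream` now takes a parsed `BufferSpec` instead of a spec string, and `withExecuteContext` returns the result paired with `ExecutionTimer.Timings`. The sketch below shows how a caller might adapt; it assumes `ExecuteContext` lives in `is.hail.backend` as the surrounding code suggests, and the object and method names are hypothetical.

```scala
// Hypothetical caller sketch (not from this PR): adapting to the new
// Backend.withExecuteContext, which returns (T, Timings) instead of T.
import is.hail.backend.{Backend, ExecuteContext}
import is.hail.io.BufferSpec
import is.hail.utils.ExecutionTimer.Timings

object BackendRpcSketch {
  // The caller now receives the timings alongside the result and can forward
  // them to whatever RPC response it is building.
  def run[A](backend: Backend)(body: ExecuteContext => A): (A, Timings) =
    backend.withExecuteContext(body)

  // encodeToOutputStream now takes a BufferSpec, so the string form is parsed
  // by the caller, e.g. with BufferSpec.parseOrDefault as the old code did.
  def parseSpec(bufferSpecString: String): BufferSpec =
    BufferSpec.parseOrDefault(bufferSpecString)
}
```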