diff --git a/hail/python/hail/backend/py4j_backend.py b/hail/python/hail/backend/py4j_backend.py index 10aa71bad18d..b9f986e5834a 100644 --- a/hail/python/hail/backend/py4j_backend.py +++ b/hail/python/hail/backend/py4j_backend.py @@ -276,9 +276,6 @@ def add_liftover(self, name, chain_file, dest_reference_genome): def remove_liftover(self, name, dest_reference_genome): self._jbackend.pyRemoveLiftover(name, dest_reference_genome) - def index_bgen(self, files, index_file_map, referenceGenomeName, contig_recoding, skip_invalid_loci): - self._jbackend.pyIndexBgen(files, index_file_map, referenceGenomeName, contig_recoding, skip_invalid_loci) - def _parse_value_ir(self, code, ref_map={}): return self._jbackend.parse_value_ir( code, diff --git a/hail/src/main/scala/is/hail/HailContext.scala b/hail/src/main/scala/is/hail/HailContext.scala index 3de89ce13cd2..f52efe60ee71 100644 --- a/hail/src/main/scala/is/hail/HailContext.scala +++ b/hail/src/main/scala/is/hail/HailContext.scala @@ -19,6 +19,7 @@ import org.apache.spark.executor.InputMetrics import org.apache.spark.rdd.RDD import org.json4s.Extraction import org.json4s.jackson.JsonMethods +import sourcecode.Enclosing case class FilePartition(index: Int, file: String) extends Partition @@ -41,7 +42,7 @@ object HailContext { def backend: Backend = get.backend - def sparkBackend(op: String): SparkBackend = get.sparkBackend(op) + def sparkBackend(implicit E: Enclosing): SparkBackend = get.backend.asSpark def configureLogging(logFile: String, quiet: Boolean, append: Boolean): Unit = { org.apache.log4j.helpers.LogLog.setInternalDebugging(true) @@ -152,7 +153,7 @@ object HailContext { val fsBc = fs.broadcast - new RDD[T](SparkBackend.sparkContext("readPartition"), Nil) { + new RDD[T](SparkBackend.sparkContext, Nil) { def getPartitions: Array[Partition] = Array.tabulate(nPartitions)(i => FilePartition(i, partFiles(i))) @@ -175,8 +176,6 @@ class HailContext private ( ) { def stop(): Unit = HailContext.stop() - def sparkBackend(op: String): SparkBackend = backend.asSpark(op) - var checkRVDKeys: Boolean = false def version: String = is.hail.HAIL_PRETTY_VERSION @@ -188,7 +187,7 @@ class HailContext private ( maxLines: Int, ): Map[String, Array[WithContext[String]]] = { val regexp = regex.r - SparkBackend.sparkContext("fileAndLineCounts").textFilesLines(fs.globAll(files).map(_.getPath)) + SparkBackend.sparkContext.textFilesLines(fs.globAll(files).map(_.getPath)) .filter(line => regexp.findFirstIn(line.value).isDefined) .take(maxLines) .groupBy(_.source.file) diff --git a/hail/src/main/scala/is/hail/backend/Backend.scala b/hail/src/main/scala/is/hail/backend/Backend.scala index 5d2df1301a4f..bf0f9425c0ca 100644 --- a/hail/src/main/scala/is/hail/backend/Backend.scala +++ b/hail/src/main/scala/is/hail/backend/Backend.scala @@ -1,28 +1,24 @@ package is.hail.backend -import is.hail.asm4s._ -import is.hail.backend.Backend.jsonToBytes +import is.hail.asm4s.HailClassLoader import is.hail.backend.spark.SparkBackend -import is.hail.expr.ir.{IR, IRParser, LoweringAnalyses, SortField, TableIR, TableReader} +import is.hail.expr.ir.{IR, LoweringAnalyses, SortField, TableIR, TableReader} import is.hail.expr.ir.lowering.{TableStage, TableStageDependency} import is.hail.io.{BufferSpec, TypedCodecSpec} -import is.hail.io.fs._ -import is.hail.io.plink.LoadPlink -import is.hail.io.vcf.LoadVCF -import is.hail.types._ +import is.hail.io.fs.FS +import is.hail.types.RTable import is.hail.types.encoded.EType import is.hail.types.physical.PTuple -import is.hail.types.virtual.TFloat64 
-import is.hail.utils._ -import is.hail.variant.ReferenceGenome +import is.hail.utils.ExecutionTimer.Timings +import is.hail.utils.fatal import scala.reflect.ClassTag -import java.io._ +import java.io.{Closeable, OutputStream} import java.nio.charset.StandardCharsets import com.fasterxml.jackson.core.StreamReadConstraints -import org.json4s._ +import org.json4s.JValue import org.json4s.jackson.JsonMethods import sourcecode.Enclosing @@ -39,16 +35,15 @@ object Backend { ctx: ExecuteContext, t: PTuple, off: Long, - bufferSpecString: String, + bufferSpec: BufferSpec, os: OutputStream, ): Unit = { - val bs = BufferSpec.parseOrDefault(bufferSpecString) assert(t.size == 1) val elementType = t.fields(0).typ val codec = TypedCodecSpec( EType.fromPythonTypeEncoding(elementType.virtualType), elementType.virtualType, - bs, + bufferSpec, ) assert(t.isFieldDefined(off, 0)) codec.encode(ctx, elementType, t.loadField(off, 0), os) @@ -96,8 +91,8 @@ abstract class Backend extends Closeable { def close(): Unit - def asSpark(op: String): SparkBackend = - fatal(s"${getClass.getSimpleName}: $op requires SparkBackend") + def asSpark(implicit E: Enclosing): SparkBackend = + fatal(s"${getClass.getSimpleName}: ${E.value} requires SparkBackend") def shouldCacheQueryInfo: Boolean = true @@ -132,70 +127,7 @@ abstract class Backend extends Closeable { def tableToTableStage(ctx: ExecuteContext, inputIR: TableIR, analyses: LoweringAnalyses) : TableStage - def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T - - final def valueType(s: String): Array[Byte] = - withExecuteContext { ctx => - jsonToBytes { - IRParser.parse_value_ir(ctx, s).typ.toJSON - } - } - - final def tableType(s: String): Array[Byte] = - withExecuteContext { ctx => - jsonToBytes { - IRParser.parse_table_ir(ctx, s).typ.toJSON - } - } - - final def matrixTableType(s: String): Array[Byte] = - withExecuteContext { ctx => - jsonToBytes { - IRParser.parse_matrix_ir(ctx, s).typ.toJSON - } - } - - final def blockMatrixType(s: String): Array[Byte] = - withExecuteContext { ctx => - jsonToBytes { - IRParser.parse_blockmatrix_ir(ctx, s).typ.toJSON - } - } - - def loadReferencesFromDataset(path: String): Array[Byte] - - def fromFASTAFile( - name: String, - fastaFile: String, - indexFile: String, - xContigs: Array[String], - yContigs: Array[String], - mtContigs: Array[String], - parInput: Array[String], - ): Array[Byte] = - withExecuteContext { ctx => - jsonToBytes { - ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile, - xContigs, yContigs, mtContigs, parInput).toJSON - } - } - - def parseVCFMetadata(path: String): Array[Byte] = - withExecuteContext { ctx => - jsonToBytes { - val metadata = LoadVCF.parseHeaderMetadata(ctx.fs, Set.empty, TFloat64, path) - implicit val formats = defaultJSONFormats - Extraction.decompose(metadata) - } - } - - def importFam(path: String, isQuantPheno: Boolean, delimiter: String, missingValue: String) - : Array[Byte] = - withExecuteContext { ctx => - LoadPlink.importFamJSON(ctx.fs, path, isQuantPheno, delimiter, missingValue).getBytes( - StandardCharsets.UTF_8 - ) - } + def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): (T, Timings) def execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)] } diff --git a/hail/src/main/scala/is/hail/backend/BackendRpc.scala b/hail/src/main/scala/is/hail/backend/BackendRpc.scala new file mode 100644 index 000000000000..64f15af7be9d --- /dev/null +++ b/hail/src/main/scala/is/hail/backend/BackendRpc.scala @@ -0,0 +1,306 @@ +package 
is.hail.backend + +import is.hail.expr.ir.IRParser +import is.hail.expr.ir.functions.IRFunctionRegistry +import is.hail.io.BufferSpec +import is.hail.io.plink.LoadPlink +import is.hail.io.vcf.LoadVCF +import is.hail.linalg.RowMatrix +import is.hail.services.retryTransientErrors +import is.hail.types.virtual.{Kind, TFloat64, VType} +import is.hail.types.virtual.Kinds._ +import is.hail.utils.{toRichIterable, using, ExecutionTimer} +import is.hail.utils.ExecutionTimer.Timings +import is.hail.variant.ReferenceGenome + +import scala.util.control.NonFatal + +import java.io.ByteArrayOutputStream +import java.nio.charset.StandardCharsets + +import org.json4s.{DefaultFormats, Extraction, Formats, JValue} +import org.json4s.jackson.{JsonMethods, Serialization} + +case class IRTypePayload(ir: String) +case class LoadReferencesFromDatasetPayload(path: String) + +case class FromFASTAFilePayload( + name: String, + fasta_file: String, + index_file: String, + x_contigs: Array[String], + y_contigs: Array[String], + mt_contigs: Array[String], + par: Array[String], +) + +case class ParseVCFMetadataPayload(path: String) +case class ImportFamPayload(path: String, quant_pheno: Boolean, delimiter: String, missing: String) + +case class ExecutePayload( + ir: String, + fs: Array[SerializedIRFunction], + stream_codec: String, +) + +case class SerializedIRFunction( + name: String, + type_parameters: Array[String], + value_parameter_names: Array[String], + value_parameter_types: Array[String], + return_type: String, + rendered_body: String, +) + +trait BackendRpc { + + sealed trait Command extends Product with Serializable + + object Commands { + case class TypeOf(k: Kind[_ <: VType], ir: String) extends Command + case class Execute(ir: String, fs: Array[SerializedIRFunction], codec: String) extends Command + case class ParseVcfMetadata(path: String) extends Command + + case class ImportFam(path: String, isQuantPheno: Boolean, delimiter: String, missing: String) + extends Command + + case class LoadReferencesFromDataset(path: String) extends Command + + case class LoadReferencesFromFASTA( + name: String, + fasta_file: String, + index_file: String, + x_contigs: Array[String], + y_contigs: Array[String], + mt_contigs: Array[String], + par: Array[String], + ) extends Command + + case class ExportBlockMatrix( + pathIn: String, + pathOut: String, + delimiter: String, + header: String, + addIndex: Boolean, + exportType: String, + partitionSize: Int, + entries: String, + ) extends Command + } + + trait Ask[Env] { + def command(env: Env): Command + } + + trait Context[Env] { + def scoped[A](env: Env)(f: ExecuteContext => A): (A, Timings) + } + + trait Write[Env] { + def timings(env: Env)(t: ExecutionTimer.Timings): Unit + def result(env: Env)(r: Array[Byte]): Unit + def error(env: Env)(t: Throwable): Unit + } + + implicit val fmts: Formats = DefaultFormats + + import Commands._ + + final def runRpc[A](env: A)(implicit Ask: Ask[A], Context: Context[A], Write: Write[A]): Unit = + try { + val command = Ask.command(env) + val (result, timings) = retryTransientErrors { + Context.scoped(env) { ctx => + command match { + case TypeOf(kind, s) => + jsonToBytes { + (kind match { + case Value => IRParser.parse_value_ir(ctx, s) + case Table => IRParser.parse_table_ir(ctx, s) + case Matrix => IRParser.parse_matrix_ir(ctx, s) + case BlockMatrix => IRParser.parse_blockmatrix_ir(ctx, s) + }).typ.toJSON + } + + case Execute(s, fns, codec) => + val bufferSpec = BufferSpec.parseOrDefault(codec) + withIRFunctionsReadFromInput(ctx, fns) { + val 
ir = IRParser.parse_value_ir(ctx, s) + val res = ctx.backend.execute(ctx, ir) + res match { + case Left(_) => Array() + case Right((pt, off)) => + using(new ByteArrayOutputStream()) { os => + Backend.encodeToOutputStream(ctx, pt, off, bufferSpec, os) + os.toByteArray + } + } + } + + case ParseVcfMetadata(path) => + jsonToBytes { + val metadata = LoadVCF.parseHeaderMetadata(ctx.fs, Set.empty, TFloat64, path) + Extraction.decompose(metadata) + } + + case ImportFam(path, isQuantPheno, delimiter, missing) => + LoadPlink + .importFamJSON(ctx.fs, path, isQuantPheno, delimiter, missing) + .getBytes(StandardCharsets.UTF_8) + + case LoadReferencesFromDataset(path) => + val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path) + ctx.References ++= rgs.map(rg => rg.name -> rg) + Serialization.write(rgs.map(_.toJSON).toFastSeq).getBytes(StandardCharsets.UTF_8) + + case LoadReferencesFromFASTA(name, fasta, index, xContigs, yContigs, mtContigs, par) => + jsonToBytes { + val rg = ReferenceGenome.fromFASTAFile( + ctx, + name, + fasta, + index, + xContigs, + yContigs, + mtContigs, + par, + ) + ctx.References += rg.name -> rg + rg.toJSON + } + + case ExportBlockMatrix(pathIn, pathOut, delimiter, header, addIndex, exportType, + partitionSize, entries) => + val rm = RowMatrix.readBlockMatrix(ctx.fs, pathIn, partitionSize) + entries match { + case "full" => + rm.export(ctx, pathOut, delimiter, Option(header), addIndex, exportType) + case "lower" => + rm.exportLowerTriangle( + ctx, + pathOut, + delimiter, + Option(header), + addIndex, + exportType, + ) + case "strict_lower" => + rm.exportStrictLowerTriangle( + ctx, + pathOut, + delimiter, + Option(header), + addIndex, + exportType, + ) + case "upper" => + rm.exportUpperTriangle( + ctx, + pathOut, + delimiter, + Option(header), + addIndex, + exportType, + ) + case "strict_upper" => + rm.exportStrictUpperTriangle( + ctx, + pathOut, + delimiter, + Option(header), + addIndex, + exportType, + ) + } + Array() + } + } + } + + Write.result(env)(result.asInstanceOf[Array[Byte]]) + Write.timings(env)(timings) + } catch { + case NonFatal(error) => Write.error(env)(error) + } + + def jsonToBytes(v: JValue): Array[Byte] = + JsonMethods.compact(v).getBytes(StandardCharsets.UTF_8) + + private[this] def withIRFunctionsReadFromInput[A]( + ctx: ExecuteContext, + serializedFunctions: Array[SerializedIRFunction], + )( + body: => A + ): A = { + try { + serializedFunctions.foreach { func => + IRFunctionRegistry.registerIR( + ctx, + func.name, + func.type_parameters, + func.value_parameter_names, + func.value_parameter_types, + func.return_type, + func.rendered_body, + ) + } + + body + } finally + IRFunctionRegistry.clearUserFunctions() + } +} + +trait HttpLikeBackendRpc[A] extends BackendRpc { + + import Commands._ + + trait Routing extends Ask[A] { + + sealed trait Route extends Product with Serializable + + object Routes { + case class TypeOf(kind: Kind[_ <: VType]) extends Route + case object Execute extends Route + case object ParseVcfMetadata extends Route + case object ImportFam extends Route + case object LoadReferencesFromDataset extends Route + case object LoadReferencesFromFASTA extends Route + } + + def route(a: A): Route + def payload(a: A): JValue + + final override def command(a: A): Command = + route(a) match { + case Routes.TypeOf(k) => + TypeOf(k, payload(a).extract[IRTypePayload].ir) + case Routes.Execute => + val ExecutePayload(ir, fs, codec) = payload(a).extract[ExecutePayload] + Execute(ir, fs, codec) + case Routes.ParseVcfMetadata => + 
ParseVcfMetadata(payload(a).extract[ParseVCFMetadataPayload].path) + case Routes.ImportFam => + val config = payload(a).extract[ImportFamPayload] + ImportFam(config.path, config.quant_pheno, config.delimiter, config.missing) + case Routes.LoadReferencesFromDataset => + val path = payload(a).extract[LoadReferencesFromDatasetPayload].path + LoadReferencesFromDataset(path) + case Routes.LoadReferencesFromFASTA => + val config = payload(a).extract[FromFASTAFilePayload] + LoadReferencesFromFASTA( + config.name, + config.fasta_file, + config.index_file, + config.x_contigs, + config.y_contigs, + config.mt_contigs, + config.par, + ) + } + } + + implicit protected def Ask: Routing + implicit protected def Write: Write[A] + implicit protected def Context: Context[A] +} diff --git a/hail/src/main/scala/is/hail/backend/BackendServer.scala b/hail/src/main/scala/is/hail/backend/BackendServer.scala index 24e7f7b98755..db23bcb310fe 100644 --- a/hail/src/main/scala/is/hail/backend/BackendServer.scala +++ b/hail/src/main/scala/is/hail/backend/BackendServer.scala @@ -1,41 +1,21 @@ package is.hail.backend -import is.hail.expr.ir.IRParser +import is.hail.types.virtual.Kinds.{BlockMatrix, Matrix, Table, Value} import is.hail.utils._ - -import scala.util.control.NonFatal +import is.hail.utils.ExecutionTimer.Timings import java.io.Closeable import java.net.InetSocketAddress -import java.nio.charset.StandardCharsets import java.util.concurrent._ +import com.google.api.client.http.HttpStatusCodes import com.sun.net.httpserver.{HttpExchange, HttpHandler, HttpServer} import org.json4s._ -import org.json4s.jackson.JsonMethods -import org.json4s.jackson.JsonMethods.compact - -case class IRTypePayload(ir: String) -case class LoadReferencesFromDatasetPayload(path: String) - -case class FromFASTAFilePayload( - name: String, - fasta_file: String, - index_file: String, - x_contigs: Array[String], - y_contigs: Array[String], - mt_contigs: Array[String], - par: Array[String], -) - -case class ParseVCFMetadataPayload(path: String) -case class ImportFamPayload(path: String, quant_pheno: Boolean, delimiter: String, missing: String) -case class ExecutePayload(ir: String, stream_codec: String, timed: Boolean) +import org.json4s.jackson.{JsonMethods, Serialization} class BackendServer(backend: Backend) extends Closeable { // 0 => let the OS pick an available port private[this] val httpServer = HttpServer.create(new InetSocketAddress(0), 10) - private[this] val handler = new BackendHttpHandler(backend) private[this] val thread = { // This HTTP server *must not* start non-daemon threads because such threads keep the JVM @@ -59,7 +39,7 @@ class BackendServer(backend: Backend) extends Closeable { /* Source: * https://docs.oracle.com/en/java/javase/11/docs/api/jdk.httpserver/com/sun/net/httpserver/HttpServer.html#setExecutor(java.util.concurrent.Executor) */ // - httpServer.createContext("/", handler) + httpServer.createContext("/", Handler) httpServer.setExecutor(null) val t = Executors.defaultThreadFactory().newThread(new Runnable() { def run(): Unit = @@ -69,85 +49,75 @@ class BackendServer(backend: Backend) extends Closeable { t } - def port = httpServer.getAddress.getPort + def port: Int = httpServer.getAddress.getPort def start(): Unit = thread.start() override def close(): Unit = httpServer.stop(10) -} -class BackendHttpHandler(backend: Backend) extends HttpHandler { - def handle(exchange: HttpExchange): Unit = { - implicit val formats: Formats = DefaultFormats - - try { - val body = 
using(exchange.getRequestBody)(JsonMethods.parse(_)) - if (exchange.getRequestURI.getPath == "/execute") { - val ExecutePayload(irStr, streamCodec, timed) = body.extract[ExecutePayload] - backend.withExecuteContext { ctx => - val (res, timings) = ExecutionTimer.time { timer => - ctx.local(timer = timer) { ctx => - val irData = IRParser.parse_value_ir(ctx, irStr) - backend.execute(ctx, irData) - } - } - - if (timed) { - exchange.getResponseHeaders.add("X-Hail-Timings", compact(timings.toJSON)) - } - - res match { - case Left(_) => exchange.sendResponseHeaders(200, -1L) - case Right((t, off)) => - exchange.sendResponseHeaders(200, 0L) // 0 => an arbitrarily long response body - using(exchange.getResponseBody) { os => - Backend.encodeToOutputStream(ctx, t, off, streamCodec, os) - } - } + private case class Request(exchange: HttpExchange, payload: JValue) + + private[this] object Handler extends HttpHandler with HttpLikeBackendRpc[Request] { + + override def handle(exchange: HttpExchange): Unit = { + val payload = using(exchange.getRequestBody)(JsonMethods.parse(_)) + runRpc(Request(exchange, payload)) + } + + implicit override protected object Ask extends Routing { + + import Routes._ + + override def route(a: Request): Route = + a.exchange.getRequestURI.getPath match { + case "/value/type" => TypeOf(Value) + case "/table/type" => TypeOf(Table) + case "/matrixtable/type" => TypeOf(Matrix) + case "/blockmatrix/type" => TypeOf(BlockMatrix) + case "/execute" => Execute + case "/vcf/metadata/parse" => ParseVcfMetadata + case "/fam/import" => ImportFam + case "/references/load" => LoadReferencesFromDataset + case "/references/from_fasta" => LoadReferencesFromFASTA } - return - } - val response: Array[Byte] = exchange.getRequestURI.getPath match { - case "/value/type" => backend.valueType(body.extract[IRTypePayload].ir) - case "/table/type" => backend.tableType(body.extract[IRTypePayload].ir) - case "/matrixtable/type" => backend.matrixTableType(body.extract[IRTypePayload].ir) - case "/blockmatrix/type" => backend.blockMatrixType(body.extract[IRTypePayload].ir) - case "/references/load" => - backend.loadReferencesFromDataset(body.extract[LoadReferencesFromDatasetPayload].path) - case "/references/from_fasta" => - val config = body.extract[FromFASTAFilePayload] - backend.fromFASTAFile( - config.name, - config.fasta_file, - config.index_file, - config.x_contigs, - config.y_contigs, - config.mt_contigs, - config.par, - ) - case "/vcf/metadata/parse" => - backend.parseVCFMetadata(body.extract[ParseVCFMetadataPayload].path) - case "/fam/import" => - val config = body.extract[ImportFamPayload] - backend.importFam(config.path, config.quant_pheno, config.delimiter, config.missing) + override def payload(a: Request): JValue = a.payload + } + + implicit override protected object Write extends Write[Request] with ErrorHandling { + + override def timings(req: Request)(t: Timings): Unit = { + val ts = Serialization.write(Map("timings" -> t)) + req.exchange.getResponseHeaders.add("X-Hail-Timings", ts) } - exchange.sendResponseHeaders(200, response.length) - using(exchange.getResponseBody())(_.write(response)) - } catch { - case NonFatal(t) => - val (shortMessage, expandedMessage, errorId) = handleForPython(t) - val errorJson = JObject( - "short" -> JString(shortMessage), - "expanded" -> JString(expandedMessage), - "error_id" -> JInt(errorId), + override def result(req: Request)(result: Array[Byte]): Unit = + respond(req)(HttpStatusCodes.STATUS_CODE_OK, result) + + override def error(req: Request)(t: Throwable): 
Unit = + respond(req)( + HttpStatusCodes.STATUS_CODE_SERVER_ERROR, + jsonToBytes { + val (shortMessage, expandedMessage, errorId) = handleForPython(t) + JObject( + "short" -> JString(shortMessage), + "expanded" -> JString(expandedMessage), + "error_id" -> JInt(errorId), + ) + }, ) - val errorBytes = JsonMethods.compact(errorJson).getBytes(StandardCharsets.UTF_8) - exchange.sendResponseHeaders(500, errorBytes.length) - using(exchange.getResponseBody())(_.write(errorBytes)) + + private[this] def respond(req: Request)(code: Int, payload: Array[Byte]): Unit = { + req.exchange.sendResponseHeaders(code, payload.length) + using(req.exchange.getResponseBody)(_.write(payload)) + } + } + + implicit override protected object Context extends Context[Request] { + override def scoped[A](req: Request)(f: ExecuteContext => A): (A, Timings) = + backend.withExecuteContext(f) } } } diff --git a/hail/src/main/scala/is/hail/backend/ExecuteContext.scala b/hail/src/main/scala/is/hail/backend/ExecuteContext.scala index 299c597c2373..94219a72166d 100644 --- a/hail/src/main/scala/is/hail/backend/ExecuteContext.scala +++ b/hail/src/main/scala/is/hail/backend/ExecuteContext.scala @@ -54,18 +54,15 @@ object NonOwningTempFileManager { } object ExecuteContext { - def scoped[T](f: ExecuteContext => T)(implicit E: Enclosing): T = { - val result = HailContext.sparkBackend("ExecuteContext.scoped").withExecuteContext( + def scoped[T](f: ExecuteContext => T)(implicit E: Enclosing): T = + HailContext.sparkBackend.withExecuteContext( selfContainedExecution = false )(f) - result - } def scoped[T]( tmpdir: String, localTmpdir: String, backend: Backend, - references: Map[String, ReferenceGenome], fs: FS, timer: ExecutionTimer, tempFileManager: TempFileManager, @@ -73,6 +70,7 @@ object ExecuteContext { flags: HailFeatureFlags, backendContext: BackendContext, irMetadata: IrMetadata, + references: mutable.Map[String, ReferenceGenome], blockMatrixCache: mutable.Map[String, BlockMatrix], codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]], irCache: mutable.Map[Int, BaseIR], @@ -85,7 +83,6 @@ object ExecuteContext { tmpdir, localTmpdir, backend, - references, fs, region, timer, @@ -94,6 +91,7 @@ object ExecuteContext { flags, backendContext, irMetadata, + references, blockMatrixCache, codeCache, irCache, @@ -117,7 +115,6 @@ class ExecuteContext( val tmpdir: String, val localTmpdir: String, val backend: Backend, - val references: Map[String, ReferenceGenome], val fs: FS, val r: Region, val timer: ExecutionTimer, @@ -126,6 +123,7 @@ class ExecuteContext( val flags: HailFeatureFlags, val backendContext: BackendContext, var irMetadata: IrMetadata, + val References: mutable.Map[String, ReferenceGenome], val BlockMatrixCache: mutable.Map[String, BlockMatrix], val CodeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]], val IrCache: mutable.Map[Int, BaseIR], @@ -142,7 +140,8 @@ class ExecuteContext( ) } - val stateManager = HailStateManager(references) + def stateManager: HailStateManager = + HailStateManager(References.toMap) val tempFileManager: TempFileManager = if (_tempFileManager != null) _tempFileManager else new OwningTempFileManager(fs) @@ -168,7 +167,7 @@ class ExecuteContext( def getFlag(name: String): String = flags.get(name) - def getReference(name: String): ReferenceGenome = references(name) + def getReference(name: String): ReferenceGenome = References(name) def shouldWriteIRFiles(): Boolean = getFlag("write_ir_files") != null @@ -188,7 +187,6 @@ class ExecuteContext( tmpdir: String = this.tmpdir, localTmpdir: String = 
this.localTmpdir, backend: Backend = this.backend, - references: Map[String, ReferenceGenome] = this.references, fs: FS = this.fs, r: Region = this.r, timer: ExecutionTimer = this.timer, @@ -196,6 +194,7 @@ class ExecuteContext( theHailClassLoader: HailClassLoader = this.theHailClassLoader, flags: HailFeatureFlags = this.flags, backendContext: BackendContext = this.backendContext, + references: mutable.Map[String, ReferenceGenome] = this.References, irMetadata: IrMetadata = this.irMetadata, blockMatrixCache: mutable.Map[String, BlockMatrix] = this.BlockMatrixCache, codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]] = this.CodeCache, @@ -207,7 +206,6 @@ class ExecuteContext( tmpdir, localTmpdir, backend, - references, fs, r, timer, @@ -216,6 +214,7 @@ class ExecuteContext( flags, backendContext, irMetadata, + references, blockMatrixCache, codeCache, irCache, diff --git a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala index 1a1e4af00d9c..892c71ef08be 100644 --- a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala +++ b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala @@ -16,6 +16,7 @@ import is.hail.types.physical.PTuple import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType import is.hail.types.virtual.TVoid import is.hail.utils._ +import is.hail.utils.ExecutionTimer.Timings import is.hail.variant.ReferenceGenome import scala.collection.mutable @@ -84,14 +85,13 @@ class LocalBackend( // flags can be set after construction from python def fs: FS = RouterFS.buildRoutes(CloudStorageFSConfig.fromFlagsAndEnv(None, flags)) - override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T = - ExecutionTimer.logTime { timer => + override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): (T, Timings) = + ExecutionTimer.time { timer => val fs = this.fs ExecuteContext.scoped( tmpdir, tmpdir, this, - references.toMap, fs, timer, null, @@ -102,6 +102,7 @@ class LocalBackend( ExecutionCache.fromFlags(flags, fs, tmpdir) }, IrMetadata(None), + references, ImmutableMap.empty, codeCache, persistedIR, diff --git a/hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala b/hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala index e2d017c1a380..65bf65f1adbb 100644 --- a/hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala +++ b/hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala @@ -14,7 +14,7 @@ import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver} import is.hail.linalg.RowMatrix import is.hail.types.physical.PStruct import is.hail.types.virtual.{TArray, TInterval} -import is.hail.utils.{defaultJSONFormats, fatal, log, toRichIterable, HailException, Interval} +import is.hail.utils.{fatal, log, toRichIterable, HailException, Interval} import is.hail.variant.ReferenceGenome import scala.collection.mutable @@ -22,13 +22,11 @@ import scala.jdk.CollectionConverters.{ asScalaBufferConverter, mapAsScalaMapConverter, seqAsJavaListConverter, } -import java.nio.charset.StandardCharsets import java.util import org.apache.spark.sql.DataFrame import org.json4s -import org.json4s.Formats -import org.json4s.jackson.{JsonMethods, Serialization} +import org.json4s.jackson.JsonMethods import sourcecode.Enclosing trait Py4JBackendExtensions { @@ -139,7 +137,7 @@ trait Py4JBackendExtensions { val field = GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0) addJavaIR(ctx, field) } - } 
+ }._1 def pyFromDF(df: DataFrame, jKey: java.util.List[String]): (Int, String) = { val key = jKey.asScala.toArray.toFastSeq @@ -165,7 +163,7 @@ trait Py4JBackendExtensions { backend.withExecuteContext { ctx => val tir = IRParser.parse_table_ir(ctx, s) Interpret(tir, ctx).toDF() - } + }._1 def pyReadMultipleMatrixTables(jsonQuery: String): util.List[MatrixIR] = backend.withExecuteContext { ctx => @@ -191,7 +189,7 @@ trait Py4JBackendExtensions { } log.info("pyReadMultipleMatrixTables: returning N matrix tables") matrixReaders.asJava - } + }._1 def pyAddReference(jsonConfig: String): Unit = addReference(ReferenceGenome.fromJSON(jsonConfig)) @@ -234,7 +232,7 @@ trait Py4JBackendExtensions { Name(n) -> IRParser.parseType(t) }.toSeq: _*), ) - } + }._1 def parse_table_ir(s: String): TableIR = withExecuteContext(selfContainedExecution = false)(ctx => IRParser.parse_table_ir(ctx, s)) @@ -247,15 +245,6 @@ trait Py4JBackendExtensions { IRParser.parse_blockmatrix_ir(ctx, s) } - def loadReferencesFromDataset(path: String): Array[Byte] = - backend.withExecuteContext { ctx => - val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path) - rgs.foreach(addReference) - - implicit val formats: Formats = defaultJSONFormats - Serialization.write(rgs.map(_.toJSON).toFastSeq).getBytes(StandardCharsets.UTF_8) - } - def withExecuteContext[T]( selfContainedExecution: Boolean = true )( @@ -266,5 +255,5 @@ trait Py4JBackendExtensions { val tempFileManager = longLifeTempFileManager if (selfContainedExecution && tempFileManager != null) f(ctx) else ctx.local(tempFileManager = NonOwningTempFileManager(tempFileManager))(f) - } + }._1 } diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala index 0f02d16e5800..c7da5b3c579b 100644 --- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala +++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala @@ -4,26 +4,26 @@ import is.hail.{CancellingExecutorService, HailContext, HailFeatureFlags} import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend._ +import is.hail.backend.service.ServiceBackend.MaxAvailableGcsConnections import is.hail.expr.Validate import is.hail.expr.ir.{ - IR, IRParser, IRSize, LoweringAnalyses, MakeTuple, SortField, TableIR, TableReader, TypeCheck, + IR, IRSize, LoweringAnalyses, MakeTuple, SortField, TableIR, TableReader, TypeCheck, } import is.hail.expr.ir.analyses.SemanticHash import is.hail.expr.ir.compile.Compile -import is.hail.expr.ir.functions.IRFunctionRegistry import is.hail.expr.ir.lowering._ import is.hail.io.fs._ import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver} -import is.hail.services.{BatchClient, _} +import is.hail.services._ import is.hail.services.JobGroupStates.Failure import is.hail.types._ import is.hail.types.physical._ import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType -import is.hail.types.virtual._ +import is.hail.types.virtual.{Kinds, TVoid} import is.hail.utils._ +import is.hail.utils.ExecutionTimer.Timings import is.hail.variant.ReferenceGenome -import scala.annotation.switch import scala.collection.mutable import scala.reflect.ClassTag @@ -32,110 +32,35 @@ import java.nio.charset.StandardCharsets import java.nio.file.Path import java.util.concurrent._ -import org.apache.log4j.Logger -import org.json4s.{DefaultFormats, Formats} import org.json4s.JsonAST._ -import org.json4s.jackson.{JsonMethods, Serialization} +import org.json4s.jackson.JsonMethods import 
sourcecode.Enclosing -class ServiceBackendContext( - val billingProject: String, - val remoteTmpDir: String, - val workerCores: String, - val workerMemory: String, - val storageRequirement: String, - val regions: Array[String], - val cloudfuseConfig: Array[CloudfuseConfig], - val profile: Boolean, - val executionCache: ExecutionCache, -) extends BackendContext with Serializable {} +case class ServiceBackendContext( + remoteTmpDir: String, + jobConfig: BatchJobConfig, + override val executionCache: ExecutionCache, +) extends BackendContext with Serializable object ServiceBackend { - - def apply( - jarLocation: String, - name: String, - theHailClassLoader: HailClassLoader, - batchClient: BatchClient, - batchId: Option[Int], - jobGroupId: Option[Int], - scratchDir: String = sys.env.getOrElse("HAIL_WORKER_SCRATCH_DIR", ""), - rpcConfig: ServiceBackendRPCPayload, - env: Map[String, String], - ): ServiceBackend = { - - val flags = HailFeatureFlags.fromEnv(rpcConfig.flags) - val shouldProfile = flags.get("profile") != null - val fs = RouterFS.buildRoutes( - CloudStorageFSConfig.fromFlagsAndEnv( - Some(Path.of(scratchDir, "secrets/gsa-key/key.json")), - flags, - env, - ) - ) - - val backendContext = new ServiceBackendContext( - rpcConfig.billing_project, - rpcConfig.remote_tmpdir, - rpcConfig.worker_cores, - rpcConfig.worker_memory, - rpcConfig.storage, - rpcConfig.regions, - rpcConfig.cloudfuse_configs, - shouldProfile, - ExecutionCache.fromFlags(flags, fs, rpcConfig.remote_tmpdir), - ) - - val references = mutable.Map.empty[String, ReferenceGenome] - references ++= ReferenceGenome.builtinReferences() - rpcConfig.custom_references.map(ReferenceGenome.fromJSON).foreach { r => - references += (r.name -> r) - } - - rpcConfig.liftovers.foreach { case (sourceGenome, liftoversForSource) => - liftoversForSource.foreach { case (destGenome, chainFile) => - references(sourceGenome).addLiftover(references(destGenome), LiftOver(fs, chainFile)) - } - } - rpcConfig.sequences.foreach { case (rg, seq) => - references(rg).addSequence(IndexedFastaSequenceFile(fs, seq.fasta, seq.index)) - } - - new ServiceBackend( - jarLocation, - name, - theHailClassLoader, - references.toMap, - batchClient, - batchId, - jobGroupId, - flags, - rpcConfig.tmp_dir, - fs, - backendContext, - scratchDir, - ) - } + val MaxAvailableGcsConnections = 1000 } class ServiceBackend( - val jarLocation: String, - var name: String, - val theHailClassLoader: HailClassLoader, - val references: Map[String, ReferenceGenome], - val batchClient: BatchClient, - val curBatchId: Option[Int], - val curJobGroupId: Option[Int], - val flags: HailFeatureFlags, - val tmpdir: String, + val name: String, + batchClient: BatchClient, + jarLocation: String, + theHailClassLoader: HailClassLoader, + batchConfig: Option[BatchConfig], + rpcConfig: ServiceBackendRPCPayload, + jobConfig: BatchJobConfig, + flags: HailFeatureFlags, val fs: FS, - val serviceBackendContext: ServiceBackendContext, - val scratchDir: String, + references: mutable.Map[String, ReferenceGenome], ) extends Backend with Logging { private[this] var stageCount = 0 - private[this] val MAX_AVAILABLE_GCS_CONNECTIONS = 1000 - private[this] val executor = Executors.newFixedThreadPool(MAX_AVAILABLE_GCS_CONNECTIONS) + private[this] val executor = Executors.newFixedThreadPool(MaxAvailableGcsConnections) override def shouldCacheQueryInfo: Boolean = false @@ -169,10 +94,11 @@ class ServiceBackend( stageIdentifier: String, f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte], ): (String, String, 
Int) = { - val backendContext = _backendContext.asInstanceOf[ServiceBackendContext] + val ServiceBackendContext(remoteTmp, jobConfig, _) = + _backendContext.asInstanceOf[ServiceBackendContext] val n = collection.length val token = tokenUrlSafe - val root = s"${backendContext.remoteTmpDir}/parallelizeAndComputeWithIndex/$token" + val root = s"$remoteTmp/parallelizeAndComputeWithIndex/$token" log.info(s"parallelizeAndComputeWithIndex: $token: nPartitions $n") log.info(s"parallelizeAndComputeWithIndex: $token: writing f and contexts") @@ -202,7 +128,7 @@ class ServiceBackend( val jobGroup = JobGroupRequest( job_group_id = 1, // QoB creates an update for every new stage - absolute_parent_id = curJobGroupId.getOrElse(0), + absolute_parent_id = batchConfig.map(_.jobGroupId).getOrElse(0), attributes = Map("name" -> stageIdentifier), ) @@ -222,13 +148,13 @@ class ServiceBackend( resources = Some( JobResources( preemptible = true, - cpu = Some(backendContext.workerCores).filter(_ != "None"), - memory = Some(backendContext.workerMemory).filter(_ != "None"), - storage = Some(backendContext.storageRequirement).filter(_ != "0Gi"), + cpu = Some(jobConfig.worker_cores).filter(_ != "None"), + memory = Some(jobConfig.worker_memory).filter(_ != "None"), + storage = Some(jobConfig.storage).filter(_ != "0Gi"), ) ), - regions = Some(backendContext.regions).filter(_.nonEmpty), - cloudfuse = Some(backendContext.cloudfuseConfig).filter(_.nonEmpty), + regions = Some(jobConfig.regions).filter(_.nonEmpty), + cloudfuse = Some(jobConfig.cloudfuse_configs).filter(_.nonEmpty), attributes = Map("name" -> s"${name}_stage${stageCount}_${stageIdentifier}_job$i"), ) } @@ -238,10 +164,10 @@ class ServiceBackend( log.info(s"parallelizeAndComputeWithIndex: $token: running job") - val batchId = curBatchId.getOrElse { + val batchId = batchConfig.map(_.batchId).getOrElse { batchClient.newBatch( BatchRequest( - billing_project = backendContext.billingProject, + billing_project = jobConfig.billing_project, n_jobs = 0, token = token, attributes = Map("name" -> name), @@ -371,39 +297,43 @@ class ServiceBackend( : TableStage = LowerTableIR.applyTable(inputIR, DArrayLowering.All, ctx, analyses) - override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T = - ExecutionTimer.logTime { timer => + override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): (T, Timings) = + ExecutionTimer.time { timer => ExecuteContext.scoped( - tmpdir, - "file:///tmp", + rpcConfig.tmp_dir, + rpcConfig.remote_tmpdir, this, - references, fs, timer, null, theHailClassLoader, flags, - serviceBackendContext, + ServiceBackendContext( + rpcConfig.remote_tmpdir, + jobConfig, + ExecutionCache.fromFlags(flags, fs, rpcConfig.remote_tmpdir), + ), IrMetadata(None), + references, ImmutableMap.empty, mutable.Map.empty, ImmutableMap.empty, )(f) } - - override def loadReferencesFromDataset(path: String): Array[Byte] = - withExecuteContext { ctx => - val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path) - implicit val formats: Formats = defaultJSONFormats - Serialization.write(rgs.map(_.toJSON).toFastSeq).getBytes(StandardCharsets.UTF_8) - } } class EndOfInputException extends RuntimeException class HailBatchFailure(message: String) extends RuntimeException(message) -object ServiceBackendAPI { - private[this] val log = Logger.getLogger(getClass.getName()) +case class Request( + backend: ServiceBackend, + fs: FS, + outputUrl: String, + action: Int, + payload: JValue, +) + +object ServiceBackendAPI extends 
HttpLikeBackendRpc[Request] with Logging { def main(argv: Array[String]): Unit = { assert(argv.length == 7, argv.toFastSeq) @@ -417,42 +347,67 @@ object ServiceBackendAPI { val inputURL = argv(5) val outputURL = argv(6) - val fs = RouterFS.buildRoutes( + val deployConfig = DeployConfig.fromConfigFile("/deploy-config/deploy-config.json") + DeployConfig.set(deployConfig) + sys.env.get("HAIL_SSL_CONFIG_DIR").foreach(tls.setSSLConfigFromDir) + + var fs = RouterFS.buildRoutes( CloudStorageFSConfig.fromFlagsAndEnv( Some(Path.of(scratchDir, "secrets/gsa-key/key.json")), HailFeatureFlags.fromEnv(), ) ) - val deployConfig = DeployConfig.fromConfigFile("/deploy-config/deploy-config.json") - DeployConfig.set(deployConfig) - sys.env.get("HAIL_SSL_CONFIG_DIR").foreach(tls.setSSLConfigFromDir) - val batchClient = BatchClient(deployConfig, Path.of(scratchDir, "secrets/gsa-key/key.json")) - log.info("BatchClient allocated.") + val (rpcConfig, jobConfig, action, payload) = + using(fs.openNoCompression(inputURL)) { is => + val input = JsonMethods.parse(is) + ( + (input \ "config").extract[ServiceBackendRPCPayload], + (input \ "job_config").extract[BatchJobConfig], + (input \ "action").extract[Int], + input \ "payload", + ) + } + + // requester pays config is conveyed in feature flags currently + val featureFlags = HailFeatureFlags.fromEnv(rpcConfig.flags) + fs = RouterFS.buildRoutes( + CloudStorageFSConfig.fromFlagsAndEnv( + Some(Path.of(scratchDir, "secrets/gsa-key/key.json")), + featureFlags, + ) + ) - val batchConfig = - BatchConfig.fromConfigFile(Path.of(scratchDir, "batch-config/batch-config.json")) - val batchId = batchConfig.map(_.batchId) - val jobGroupId = batchConfig.map(_.jobGroupId) - log.info("BatchConfig parsed.") + val references = mutable.Map[String, ReferenceGenome]() + references ++= ReferenceGenome.builtinReferences() + rpcConfig.custom_references.toFastSeq.view.map(ReferenceGenome.fromJSON).foreach { rg => + references += rg.name -> rg + } - implicit val formats: Formats = DefaultFormats + rpcConfig.liftovers.foreach { case (sourceGenome, liftoversForSource) => + liftoversForSource.foreach { case (destGenome, chainFile) => + references(sourceGenome).addLiftover(references(destGenome), LiftOver(fs, chainFile)) + } + } - val input = using(fs.openNoCompression(inputURL))(JsonMethods.parse(_)) - val rpcConfig = (input \ "config").extract[ServiceBackendRPCPayload] + rpcConfig.sequences.foreach { case (rg, seq) => + references(rg).addSequence(IndexedFastaSequenceFile(fs, seq.fasta, seq.index)) + } // FIXME: when can the classloader be shared? (optimizer benefits!) 
- val backend = ServiceBackend( - jarLocation, + val backend = new ServiceBackend( name, - new HailClassLoader(getClass().getClassLoader()), - batchClient, - batchId, - jobGroupId, - scratchDir, + BatchClient(deployConfig, Path.of(scratchDir, "secrets/gsa-key/key.json")), + jarLocation, + new HailClassLoader(getClass.getClassLoader), + BatchConfig.fromConfigFile(Path.of(scratchDir, "batch-config/batch-config.json")), rpcConfig, - sys.env, + jobConfig, + featureFlags, + fs, + references, ) + log.info("ServiceBackend allocated.") if (HailContext.isInitialized) { HailContext.get.backend = backend @@ -462,9 +417,76 @@ object ServiceBackendAPI { log.info("HailContexet initialized.") } - val action = (input \ "action").extract[Int] - val payload = (input \ "payload") - new ServiceBackendAPI(backend, fs, outputURL).executeOneCommand(action, payload) + runRpc(Request(backend, fs, outputURL, action, payload)) + } + + implicit override protected object Ask extends Routing { + import Routes._ + + override def route(a: Request): Route = + a.action match { + case 2 => TypeOf(Kinds.Value) + case 3 => TypeOf(Kinds.Table) + case 4 => TypeOf(Kinds.Matrix) + case 5 => TypeOf(Kinds.BlockMatrix) + case 6 => Execute + case 7 => ParseVcfMetadata + case 8 => ImportFam + case 1 => LoadReferencesFromDataset + case 9 => LoadReferencesFromFASTA + } + + override def payload(a: Request): JValue = a.payload + } + + implicit override protected object Write extends Write[Request] { + + // service backend doesn't support sending timings back to the python client + override def timings(env: Request)(t: Timings): Unit = + () + + override def result(env: Request)(result: Array[Byte]): Unit = + retryTransientErrors { + using(env.fs.createNoCompression(env.outputUrl)) { outputStream => + val output = new HailSocketAPIOutputStream(outputStream) + output.writeBool(true) + output.writeBytes(result) + } + } + + override def error(env: Request)(t: Throwable): Unit = + retryTransientErrors { + val (shortMessage, expandedMessage, errorId) = + t match { + case t: HailWorkerException => + log.error( + "A worker failed. The exception was written for Python but we will also throw an exception to fail this driver job.", + t, + ) + (t.shortMessage, t.expandedMessage, t.errorId) + case _ => + log.error( + "An exception occurred in the driver. 
The exception was written for Python but we will re-throw to fail this driver job.", + t, + ) + handleForPython(t) + } + + using(env.fs.createNoCompression(env.outputUrl)) { outputStream => + val output = new HailSocketAPIOutputStream(outputStream) + output.writeBool(false) + output.writeString(shortMessage) + output.writeString(expandedMessage) + output.writeInt(errorId) + } + + throw t + } + } + + implicit override protected object Context extends Context[Request] { + override def scoped[A](env: Request)(f: ExecuteContext => A): (A, Timings) = + env.backend.withExecuteContext(f) } } @@ -506,173 +528,17 @@ case class SequenceConfig(fasta: String, index: String) case class ServiceBackendRPCPayload( tmp_dir: String, remote_tmpdir: String, - billing_project: String, - worker_cores: String, - worker_memory: String, - storage: String, - cloudfuse_configs: Array[CloudfuseConfig], - regions: Array[String], flags: Map[String, String], custom_references: Array[String], liftovers: Map[String, Map[String, String]], sequences: Map[String, SequenceConfig], ) -case class ServiceBackendExecutePayload( - functions: Array[SerializedIRFunction], - idempotency_token: String, - payload: ExecutePayload, -) - -case class SerializedIRFunction( - name: String, - type_parameters: Array[String], - value_parameter_names: Array[String], - value_parameter_types: Array[String], - return_type: String, - rendered_body: String, +case class BatchJobConfig( + billing_project: String, + worker_cores: String, + worker_memory: String, + storage: String, + cloudfuse_configs: Array[CloudfuseConfig], + regions: Array[String], ) - -class ServiceBackendAPI( - private[this] val backend: ServiceBackend, - private[this] val fs: FS, - private[this] val outputURL: String, -) extends Thread { - private[this] val LOAD_REFERENCES_FROM_DATASET = 1 - private[this] val VALUE_TYPE = 2 - private[this] val TABLE_TYPE = 3 - private[this] val MATRIX_TABLE_TYPE = 4 - private[this] val BLOCK_MATRIX_TYPE = 5 - private[this] val EXECUTE = 6 - private[this] val PARSE_VCF_METADATA = 7 - private[this] val IMPORT_FAM = 8 - private[this] val FROM_FASTA_FILE = 9 - - private[this] val log = Logger.getLogger(getClass.getName()) - - private[this] def doAction(action: Int, payload: JValue): Array[Byte] = retryTransientErrors { - implicit val formats: Formats = DefaultFormats - (action: @switch) match { - case LOAD_REFERENCES_FROM_DATASET => - val path = payload.extract[LoadReferencesFromDatasetPayload].path - backend.loadReferencesFromDataset(path) - case VALUE_TYPE => - val ir = payload.extract[IRTypePayload].ir - backend.valueType(ir) - case TABLE_TYPE => - val ir = payload.extract[IRTypePayload].ir - backend.tableType(ir) - case MATRIX_TABLE_TYPE => - val ir = payload.extract[IRTypePayload].ir - backend.matrixTableType(ir) - case BLOCK_MATRIX_TYPE => - val ir = payload.extract[IRTypePayload].ir - backend.blockMatrixType(ir) - case EXECUTE => - val qobExecutePayload = payload.extract[ServiceBackendExecutePayload] - val bufferSpecString = qobExecutePayload.payload.stream_codec - val code = qobExecutePayload.payload.ir - backend.withExecuteContext { ctx => - withIRFunctionsReadFromInput(qobExecutePayload.functions, ctx) { () => - val ir = IRParser.parse_value_ir(ctx, code) - backend.execute(ctx, ir) match { - case Left(()) => - Array() - case Right((pt, off)) => - using(new ByteArrayOutputStream()) { os => - Backend.encodeToOutputStream(ctx, pt, off, bufferSpecString, os) - os.toByteArray - } - } - } - } - case PARSE_VCF_METADATA => - val path = 
payload.extract[ParseVCFMetadataPayload].path - backend.parseVCFMetadata(path) - case IMPORT_FAM => - val famPayload = payload.extract[ImportFamPayload] - val path = famPayload.path - val quantPheno = famPayload.quant_pheno - val delimiter = famPayload.delimiter - val missing = famPayload.missing - backend.importFam(path, quantPheno, delimiter, missing) - case FROM_FASTA_FILE => - val fastaPayload = payload.extract[FromFASTAFilePayload] - backend.fromFASTAFile( - fastaPayload.name, - fastaPayload.fasta_file, - fastaPayload.index_file, - fastaPayload.x_contigs, - fastaPayload.y_contigs, - fastaPayload.mt_contigs, - fastaPayload.par, - ) - } - } - - private[this] def withIRFunctionsReadFromInput( - serializedFunctions: Array[SerializedIRFunction], - ctx: ExecuteContext, - )( - body: () => Array[Byte] - ): Array[Byte] = { - try { - serializedFunctions.foreach { func => - IRFunctionRegistry.registerIR( - ctx, - func.name, - func.type_parameters, - func.value_parameter_names, - func.value_parameter_types, - func.return_type, - func.rendered_body, - ) - } - body() - } finally - IRFunctionRegistry.clearUserFunctions() - } - - def executeOneCommand(action: Int, payload: JValue): Unit = { - try { - val result = doAction(action, payload) - retryTransientErrors { - using(fs.createNoCompression(outputURL)) { outputStream => - val output = new HailSocketAPIOutputStream(outputStream) - output.writeBool(true) - output.writeBytes(result) - } - } - } catch { - case exc: HailWorkerException => - retryTransientErrors { - using(fs.createNoCompression(outputURL)) { outputStream => - val output = new HailSocketAPIOutputStream(outputStream) - output.writeBool(false) - output.writeString(exc.shortMessage) - output.writeString(exc.expandedMessage) - output.writeInt(exc.errorId) - } - } - log.error( - "A worker failed. The exception was written for Python but we will also throw an exception to fail this driver job." - ) - throw exc - case t: Throwable => - val (shortMessage, expandedMessage, errorId) = handleForPython(t) - retryTransientErrors { - using(fs.createNoCompression(outputURL)) { outputStream => - val output = new HailSocketAPIOutputStream(outputStream) - output.writeBool(false) - output.writeString(shortMessage) - output.writeString(expandedMessage) - output.writeInt(errorId) - } - } - log.error( - "An exception occurred in the driver. The exception was written for Python but we will re-throw to fail this driver job." 
- ) - throw t - } - } -} diff --git a/hail/src/main/scala/is/hail/backend/service/Worker.scala b/hail/src/main/scala/is/hail/backend/service/Worker.scala index e40364c8fdaa..cc461dcc5cad 100644 --- a/hail/src/main/scala/is/hail/backend/service/Worker.scala +++ b/hail/src/main/scala/is/hail/backend/service/Worker.scala @@ -166,40 +166,23 @@ object Worker { timer.end("readInputs") timer.start("executeFunction") - if (HailContext.isInitialized) { - HailContext.get.backend = new ServiceBackend( - null, - null, - new HailClassLoader(getClass().getClassLoader()), - null, - null, - None, - None, - null, - null, - null, - null, - scratchDir, - ) - } else { - HailContext( - // FIXME: workers should not have backends, but some things do need hail contexts - new ServiceBackend( - null, - null, - new HailClassLoader(getClass().getClassLoader()), - null, - null, - None, - None, - null, - null, - null, - null, - scratchDir, - ) - ) - } + + // FIXME: workers should not have backends, but some things do need hail contexts + val backend = new ServiceBackend( + null, + null, + null, + new HailClassLoader(getClass().getClassLoader()), + None, + null, + null, + null, + null, + null, + ) + + if (HailContext.isInitialized) HailContext.get.backend = backend + else HailContext(backend) val result = using(new ServiceTaskContext(i)) { htc => try diff --git a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala index 1b8f3117305f..7e752855e69d 100644 --- a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala +++ b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala @@ -19,6 +19,7 @@ import is.hail.types.physical.{PStruct, PTuple} import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType import is.hail.types.virtual._ import is.hail.utils._ +import is.hail.utils.ExecutionTimer.Timings import is.hail.variant.ReferenceGenome import scala.collection.mutable @@ -76,7 +77,7 @@ object SparkBackend { private var theSparkBackend: SparkBackend = _ - def sparkContext(op: String): SparkContext = HailContext.sparkBackend(op).sc + def sparkContext(implicit E: Enclosing): SparkContext = HailContext.sparkBackend.sc def checkSparkCompatibility(jarVersion: String, sparkVersion: String): Unit = { def majorMinor(version: String): String = version.split("\\.", 3).take(2).mkString(".") @@ -346,17 +347,15 @@ class SparkBackend( def createExecuteContextForTests( timer: ExecutionTimer, region: Region, - selfContainedExecution: Boolean = true, ): ExecuteContext = new ExecuteContext( tmpdir, localTmpdir, this, - references.toMap, fs, region, timer, - if (selfContainedExecution) null else NonOwningTempFileManager(longLifeTempFileManager), + null, theHailClassLoader, flags, new BackendContext { @@ -364,18 +363,18 @@ class SparkBackend( ExecutionCache.forTesting }, IrMetadata(None), + references, ImmutableMap.empty, mutable.Map.empty, ImmutableMap.empty, ) - override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T = - ExecutionTimer.logTime { timer => + override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): (T, Timings) = + ExecutionTimer.time { timer => ExecuteContext.scoped( tmpdir, localTmpdir, this, - references.toMap, fs, timer, null, @@ -386,6 +385,7 @@ class SparkBackend( ExecutionCache.fromFlags(flags, fs, tmpdir) }, IrMetadata(None), + references, bmCache, codeCache, persistedIr, @@ -450,7 +450,7 @@ class SparkBackend( def defaultParallelism: Int = sc.defaultParallelism - override def 
asSpark(op: String): SparkBackend = this + override def asSpark(implicit E: Enclosing): SparkBackend = this def close(): Unit = { bmCache.close() diff --git a/hail/src/main/scala/is/hail/expr/ir/GenericLines.scala b/hail/src/main/scala/is/hail/expr/ir/GenericLines.scala index 8adb4fe75eb5..000348f272fa 100644 --- a/hail/src/main/scala/is/hail/expr/ir/GenericLines.scala +++ b/hail/src/main/scala/is/hail/expr/ir/GenericLines.scala @@ -449,7 +449,7 @@ class GenericLinesRDDPartition(val index: Int, val context: Any) extends Partiti class GenericLinesRDD( @(transient @param) contexts: IndexedSeq[Any], body: (Any) => CloseableIterator[GenericLine], -) extends RDD[GenericLine](SparkBackend.sparkContext("GenericLinesRDD"), Seq()) { +) extends RDD[GenericLine](SparkBackend.sparkContext, Seq()) { protected def getPartitions: Array[Partition] = contexts.iterator.zipWithIndex.map { case (c, i) => diff --git a/hail/src/main/scala/is/hail/expr/ir/TableIR.scala b/hail/src/main/scala/is/hail/expr/ir/TableIR.scala index 4c528d79a989..f50b07d597a6 100644 --- a/hail/src/main/scala/is/hail/expr/ir/TableIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/TableIR.scala @@ -728,7 +728,7 @@ case class PartitionRVDReader(rvd: RVD, uidFieldName: String) extends PartitionR val upcastF = mb.genFieldThisRef[AsmFunction2RegionLongLong]("rvdreader_upcast") val broadcastRVD = - mb.getObject[BroadcastRVD](new BroadcastRVD(ctx.backend.asSpark("RVDReader"), rvd)) + mb.getObject[BroadcastRVD](new BroadcastRVD(ctx.backend.asSpark, rvd)) val producer = new StreamProducer { override def method: EmitMethodBuilder[_] = mb @@ -3213,7 +3213,7 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR { fileStack += filesToMerge filesToMerge = - ContextRDD.weaken(SparkBackend.sparkContext("TableMapRows.execute").parallelize( + ContextRDD.weaken(SparkBackend.sparkContext.parallelize( 0 until nToMerge, nToMerge, )) diff --git a/hail/src/main/scala/is/hail/expr/ir/TableValue.scala b/hail/src/main/scala/is/hail/expr/ir/TableValue.scala index 8b5441db48a0..a1dd38c72d13 100644 --- a/hail/src/main/scala/is/hail/expr/ir/TableValue.scala +++ b/hail/src/main/scala/is/hail/expr/ir/TableValue.scala @@ -188,7 +188,7 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow } def toDF(): DataFrame = - HailContext.sparkBackend("toDF").sparkSession.createDataFrame( + HailContext.sparkBackend.sparkSession.createDataFrame( rvd.toRows, typ.rowType.schema.asInstanceOf[StructType], ) diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/RVDToTableStage.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/RVDToTableStage.scala index e0af240efc56..3230023b7da3 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/RVDToTableStage.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/RVDToTableStage.scala @@ -123,9 +123,7 @@ object TableStageToRVD { partition = { ctx: Ref => _ts.partition(GetField(ctx, "context")) }, ) - val sparkContext = ctx.backend - .asSpark("TableStageToRVD") - .sc + val sparkContext = ctx.backend.asSpark.sc val globalsAndBroadcastVals = Let( diff --git a/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala b/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala index d8193581821d..876bc831fc93 100644 --- a/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala +++ b/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala @@ -1685,7 +1685,7 @@ class PartitionedVCFRDD( file: String, @(transient @param) reverseContigMapping: Map[String, String], @(transient @param) _partitions: Array[Partition], -) extends 
diff --git a/hail/src/main/scala/is/hail/linalg/BlockMatrix.scala b/hail/src/main/scala/is/hail/linalg/BlockMatrix.scala
index 25328be4c110..d9e0e40f6299 100644
--- a/hail/src/main/scala/is/hail/linalg/BlockMatrix.scala
+++ b/hail/src/main/scala/is/hail/linalg/BlockMatrix.scala
@@ -44,7 +44,7 @@ case class CollectMatricesRDDPartition(
 }

 class CollectMatricesRDD(@transient var bms: IndexedSeq[BlockMatrix])
-    extends RDD[BDM[Double]](SparkBackend.sparkContext("CollectMatricesRDD"), Nil) {
+    extends RDD[BDM[Double]](SparkBackend.sparkContext, Nil) {
  private val nBlocks = bms.map(_.blocks.getNumPartitions)
  private val firstPartition = nBlocks.scan(0)(_ + _).init

@@ -114,7 +114,7 @@ object BlockMatrix {
  def apply(gp: GridPartitioner, piBlock: (GridPartitioner, Int) => ((Int, Int), BDM[Double]))
    : BlockMatrix =
    new BlockMatrix(
-      new RDD[((Int, Int), BDM[Double])](SparkBackend.sparkContext("BlockMatrix.apply"), Nil) {
+      new RDD[((Int, Int), BDM[Double])](SparkBackend.sparkContext, Nil) {
        override val partitioner = Some(gp)

        protected def getPartitions: Array[Partition] =
          Array.tabulate(gp.numPartitions)(pi =>
@@ -2199,7 +2199,7 @@ class WriteBlocksRDD(
  parentPartStarts: Array[Long],
  entryField: String,
  gp: GridPartitioner,
-) extends RDD[(Int, String)](SparkBackend.sparkContext("WriteBlocksRDD"), Nil) {
+) extends RDD[(Int, String)](SparkBackend.sparkContext, Nil) {

  require(gp.nRows == parentPartStarts.last)

@@ -2369,7 +2369,7 @@ class BlockMatrixReadRowBlockedRDD(
  metadata: BlockMatrixMetadata,
  maybeMaximumCacheMemoryInBytes: Option[Int],
 ) extends RDD[RVDContext => Iterator[Long]](
-      SparkBackend.sparkContext("BlockMatrixReadRowBlockedRDD"),
+      SparkBackend.sparkContext,
      Nil,
    ) {
  import BlockMatrixReadRowBlockedRDD._
diff --git a/hail/src/main/scala/is/hail/linalg/RowMatrix.scala b/hail/src/main/scala/is/hail/linalg/RowMatrix.scala
index 6adbaca171ab..969ef2d4916b 100644
--- a/hail/src/main/scala/is/hail/linalg/RowMatrix.scala
+++ b/hail/src/main/scala/is/hail/linalg/RowMatrix.scala
@@ -286,7 +286,7 @@ class ReadBlocksAsRowsRDD(
  partFiles: IndexedSeq[String],
  partitionCounts: Array[Long],
  gp: GridPartitioner,
-) extends RDD[(Long, Array[Double])](SparkBackend.sparkContext("ReadBlocksAsRowsRDD"), Nil) {
+) extends RDD[(Long, Array[Double])](SparkBackend.sparkContext, Nil) {

  private val partitionStarts = partitionCounts.scanLeft(0L)(_ + _)
diff --git a/hail/src/main/scala/is/hail/methods/MatrixExportEntriesByCol.scala b/hail/src/main/scala/is/hail/methods/MatrixExportEntriesByCol.scala
index 43ade7d14184..5db999bbbba4 100644
--- a/hail/src/main/scala/is/hail/methods/MatrixExportEntriesByCol.scala
+++ b/hail/src/main/scala/is/hail/methods/MatrixExportEntriesByCol.scala
@@ -197,7 +197,7 @@ case class MatrixExportEntriesByCol(
    // clean up temporary files
    val temps = tempFolders.result()
    val fsBc = fs.broadcast
-    SparkBackend.sparkContext("MatrixExportEntriesByCol.execute").parallelize(
+    SparkBackend.sparkContext.parallelize(
      temps,
      (temps.length / 32).max(1),
    ).foreach(path => fsBc.value.delete(path, recursive = true))
diff --git a/hail/src/main/scala/is/hail/rvd/RVD.scala b/hail/src/main/scala/is/hail/rvd/RVD.scala
index 9c463915722e..2cd97c3df035 100644
--- a/hail/src/main/scala/is/hail/rvd/RVD.scala
+++ b/hail/src/main/scala/is/hail/rvd/RVD.scala
@@ -1377,7 +1377,7 @@ object RVD {
      )
    }

-    val sc = SparkBackend.sparkContext("writeRowsSplitFiles")
+    val sc = SparkBackend.sparkContext
    val localTmpdir = execCtx.localTmpdir
    val fs = execCtx.fs
    val fsBc = fs.broadcast
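The `MatrixExportEntriesByCol` hunk above keeps the existing cleanup strategy: the temp-folder paths are parallelized (one task per roughly 32 paths, never fewer than one partition) and deleted on the executors. A rough standalone analogue, using local file deletion in place of Hail's broadcast filesystem handle:

```scala
import java.nio.file.{Files, Paths}

import org.apache.spark.sql.SparkSession

object CleanupDemo extends App {
  val spark = SparkSession.builder().master("local[2]").appName("cleanup-demo").getOrCreate()
  val temps = Seq("/tmp/export-part-0", "/tmp/export-part-1", "/tmp/export-part-2")

  // Same partition-sizing rule as the hunk above: temps.length / 32, floored at 1.
  spark.sparkContext
    .parallelize(temps, (temps.length / 32).max(1))
    .foreach(p => Files.deleteIfExists(Paths.get(p)))

  spark.stop()
}
```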
diff --git a/hail/src/main/scala/is/hail/sparkextras/ContextRDD.scala b/hail/src/main/scala/is/hail/sparkextras/ContextRDD.scala
index fe9c4d4e4aca..9ecefc16e133 100644
--- a/hail/src/main/scala/is/hail/sparkextras/ContextRDD.scala
+++ b/hail/src/main/scala/is/hail/sparkextras/ContextRDD.scala
@@ -86,7 +86,7 @@ object ContextRDD {

  def empty[T: ClassTag](): ContextRDD[T] =
    new ContextRDD(
-      SparkBackend.sparkContext("ContextRDD.empty").emptyRDD[RVDContext => Iterator[T]]
+      SparkBackend.sparkContext.emptyRDD[RVDContext => Iterator[T]]
    )

  def union[T: ClassTag](
@@ -117,7 +117,7 @@ object ContextRDD {
    filterAndReplace: TextInputFilterAndReplace,
  ): ContextRDD[WithContext[String]] =
    ContextRDD.weaken(
-      SparkBackend.sparkContext("ContxtRDD.textFilesLines").textFilesLines(
+      SparkBackend.sparkContext.textFilesLines(
        files,
        nPartitions,
      )
@@ -129,12 +129,12 @@ object ContextRDD {
    weaken(sc.parallelize(data, nPartitions.getOrElse(sc.defaultMinPartitions))).map(x => x)

  def parallelize[T: ClassTag](data: Seq[T], numSlices: Int): ContextRDD[T] =
-    weaken(SparkBackend.sparkContext("ContextRDD.parallelize").parallelize(data, numSlices)).map {
+    weaken(SparkBackend.sparkContext.parallelize(data, numSlices)).map {
      x => x
    }

  def parallelize[T: ClassTag](data: Seq[T]): ContextRDD[T] =
-    weaken(SparkBackend.sparkContext("ContextRDD.parallelize").parallelize(data)).map(x => x)
+    weaken(SparkBackend.sparkContext.parallelize(data)).map(x => x)

  type ElementType[T] = RVDContext => Iterator[T]
diff --git a/hail/src/main/scala/is/hail/sparkextras/IndexReadRDD.scala b/hail/src/main/scala/is/hail/sparkextras/IndexReadRDD.scala
index d7316f3d7369..94e00942a2f5 100644
--- a/hail/src/main/scala/is/hail/sparkextras/IndexReadRDD.scala
+++ b/hail/src/main/scala/is/hail/sparkextras/IndexReadRDD.scala
@@ -15,7 +15,7 @@ class IndexReadRDD[T: ClassTag](
  @transient val partFiles: Array[String],
  @transient val intervalBounds: Option[Array[Interval]],
  f: (IndexedFilePartition, TaskContext) => T,
-) extends RDD[T](SparkBackend.sparkContext("IndexReadRDD"), Nil) {
+) extends RDD[T](SparkBackend.sparkContext, Nil) {
  def getPartitions: Array[Partition] =
    Array.tabulate(partFiles.length) { i =>
      IndexedFilePartition(i, partFiles(i), intervalBounds.map(_(i)))
diff --git a/hail/src/main/scala/is/hail/types/virtual/package.scala b/hail/src/main/scala/is/hail/types/virtual/package.scala
new file mode 100644
index 000000000000..032c862bafb4
--- /dev/null
+++ b/hail/src/main/scala/is/hail/types/virtual/package.scala
@@ -0,0 +1,12 @@
+package is.hail.types
+
+package virtual {
+  sealed abstract class Kind[T <: VType] extends Product with Serializable
+
+  object Kinds {
+    case object Value extends Kind[Type]
+    case object Table extends Kind[TableType]
+    case object Matrix extends Kind[MatrixType]
+    case object BlockMatrix extends Kind[BlockMatrixType]
+  }
+}
diff --git a/hail/src/main/scala/is/hail/utils/SpillingCollectIterator.scala b/hail/src/main/scala/is/hail/utils/SpillingCollectIterator.scala
index a8bb6d8babbd..cbf9833ed151 100644
--- a/hail/src/main/scala/is/hail/utils/SpillingCollectIterator.scala
+++ b/hail/src/main/scala/is/hail/utils/SpillingCollectIterator.scala
@@ -16,7 +16,7 @@ object SpillingCollectIterator {
    val nPartitions = rdd.partitions.length
    val x = new SpillingCollectIterator(localTmpdir, fs, nPartitions, sizeLimit)
    val ctc = classTag[T]
-    SparkBackend.sparkContext("SpillingCollectIterator.apply").runJob(
+    SparkBackend.sparkContext.runJob(
      rdd,
      (_, it: Iterator[T]) => it.toArray(ctc),
      0 until nPartitions,
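The new `is.hail.types.virtual` package object above introduces a small sealed `Kind` hierarchy. A self-contained analogue is sketched below (unparameterised, since `VType`, `TableType` and friends are Hail types), showing the exhaustive dispatch such an ADT enables:

```scala
sealed abstract class Kind extends Product with Serializable

object Kinds {
  case object Value extends Kind
  case object Table extends Kind
  case object Matrix extends Kind
  case object BlockMatrix extends Kind
}

object KindDemo extends App {
  import Kinds._

  // The compiler checks this match is exhaustive because Kind is sealed.
  def describe(k: Kind): String = k match {
    case Value       => "single value"
    case Table       => "table"
    case Matrix      => "matrix table"
    case BlockMatrix => "block matrix"
  }

  Seq(Value, Table, Matrix, BlockMatrix).map(describe).foreach(println)
}
```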
diff --git a/hail/src/test/scala/is/hail/HailSuite.scala b/hail/src/test/scala/is/hail/HailSuite.scala
index c4994aa4be84..f00b588f8a84 100644
--- a/hail/src/test/scala/is/hail/HailSuite.scala
+++ b/hail/src/test/scala/is/hail/HailSuite.scala
@@ -43,7 +43,7 @@ object HailSuite {
  lazy val hc: HailContext = {
    val hc = withSparkBackend()
-    hc.sparkBackend("HailSuite.hc").flags.set("lower", "1")
+    hc.backend.asSpark.flags.set("lower", "1")
    hc.checkRVDKeys = true
    hc
  }
@@ -56,7 +56,7 @@ class HailSuite extends TestNGSuite {
  @BeforeClass def ensureHailContextInitialized(): Unit = hc

-  def backend: SparkBackend = hc.sparkBackend("HailSuite.backend")
+  def backend: SparkBackend = hc.backend.asSpark

  def sc: SparkContext = backend.sc
diff --git a/hail/src/test/scala/is/hail/TestUtils.scala b/hail/src/test/scala/is/hail/TestUtils.scala
index ff7a4ed48682..71ab9445a4d4 100644
--- a/hail/src/test/scala/is/hail/TestUtils.scala
+++ b/hail/src/test/scala/is/hail/TestUtils.scala
@@ -119,9 +119,9 @@ object TestUtils {
    if (agg.isDefined || !env.isEmpty || !args.isEmpty)
      throw new LowererUnsupportedOperation("can't test with aggs or user defined args/env")

-    HailContext.sparkBackend("TestUtils.loweredExecute")
-      .jvmLowerAndExecute(ctx, x, optimize = false, lowerTable = true, lowerBM = true,
-        print = bytecodePrinter)
+    HailContext.sparkBackend.jvmLowerAndExecute(ctx, x, optimize = false, lowerTable = true,
+      lowerBM = true,
+      print = bytecodePrinter)
  }

  def eval(x: IR): Any = ExecuteContext.scoped { ctx =>
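The test changes above (`hc.backend.asSpark`, label-free `HailContext.sparkBackend`) rely on `asSpark` being a checked downcast that names its caller when the active backend is not Spark. A rough standalone analogue, again assuming only the `sourcecode` library; `SparkLikeBackend` and `LocalLikeBackend` are stand-ins, not Hail classes:

```scala
import sourcecode.Enclosing

trait Backend {
  // Succeeds for the Spark implementation, otherwise fails with the caller's name.
  def asSpark(implicit E: Enclosing): SparkLikeBackend = this match {
    case b: SparkLikeBackend => b
    case other => throw new IllegalStateException(s"${E.value} requires the Spark backend, got $other")
  }
}

final class SparkLikeBackend extends Backend
final class LocalLikeBackend extends Backend

object AsSparkDemo extends App {
  println((new SparkLikeBackend: Backend).asSpark) // succeeds
  // (new LocalLikeBackend: Backend).asSpark       // would throw, naming AsSparkDemo
}
```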
diff --git a/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala b/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala
index 83a5fa9d5937..631a95c07159 100644
--- a/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala
+++ b/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala
@@ -2,12 +2,15 @@ package is.hail.backend

 import is.hail.HailFeatureFlags
 import is.hail.asm4s.HailClassLoader
-import is.hail.backend.service.{ServiceBackend, ServiceBackendContext, ServiceBackendRPCPayload}
+import is.hail.backend.service.{
+  BatchJobConfig, ServiceBackend, ServiceBackendContext, ServiceBackendRPCPayload,
+}
 import is.hail.io.fs.{CloudStorageFSConfig, RouterFS}
 import is.hail.services._
 import is.hail.services.JobGroupStates.Success
 import is.hail.utils.{tokenUrlSafe, using}

+import scala.collection.mutable
 import scala.reflect.io.{Directory, Path}
 import scala.util.Random

@@ -24,9 +27,9 @@ import org.testng.annotations.Test
 class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito with OptionValues {

  @Test def testCreateJobPayload(): Unit =
-    withMockDriverContext { rpcConfig =>
+    withMockDriverContext { case (rpcConfig, jobConfig) =>
      val batchClient = mock[BatchClient]
-      using(ServiceBackend(batchClient, rpcConfig)) { backend =>
+      using(ServiceBackend(batchClient, rpcConfig, jobConfig)) { backend =>
        val contexts = Array.tabulate(1)(_.toString.getBytes)

        // verify that the service backend
@@ -38,7 +41,7 @@ class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito with OptionV
        when(batchClient.newBatch(any[BatchRequest])) thenAnswer {
          (batchRequest: BatchRequest) =>
-            batchRequest.billing_project shouldEqual rpcConfig.billing_project
+            batchRequest.billing_project shouldEqual jobConfig.billing_project
            batchRequest.n_jobs shouldBe 0
            batchRequest.attributes.get("name").value shouldBe backend.name
            batchId
@@ -56,12 +59,12 @@ class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito with OptionV
            jobGroup.absolute_parent_id shouldBe 0
            jobs.length shouldEqual contexts.length
            jobs.foreach { payload =>
-              payload.regions.value shouldBe rpcConfig.regions
+              payload.regions.value shouldBe jobConfig.regions
              payload.resources.value shouldBe JobResources(
                preemptible = true,
-                cpu = Some(rpcConfig.worker_cores),
-                memory = Some(rpcConfig.worker_memory),
-                storage = Some(rpcConfig.storage),
+                cpu = Some(jobConfig.worker_cores),
+                memory = Some(jobConfig.worker_memory),
+                storage = Some(jobConfig.storage),
              )
            }
@@ -76,7 +79,7 @@ class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito with OptionV
            jobGroupId shouldEqual 1

        val resultsDir =
-          Path(backend.serviceBackendContext.remoteTmpDir) /
+          Path(rpcConfig.remote_tmpdir) /
            "parallelizeAndComputeWithIndex" /
            tokenUrlSafe
@@ -97,7 +100,11 @@ class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito with OptionV

        val (failure, _) =
          backend.parallelizeAndComputeWithIndex(
-            backend.serviceBackendContext,
+            ServiceBackendContext(
+              remoteTmpDir = rpcConfig.remote_tmpdir,
+              jobConfig = jobConfig,
+              executionCache = ExecutionCache.noCache,
+            ),
            backend.fs,
            contexts,
            "stage1",
@@ -110,59 +117,52 @@ class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito with OptionV
    }
  }

-  def ServiceBackend(client: BatchClient, rpcConfig: ServiceBackendRPCPayload): ServiceBackend = {
+  def ServiceBackend(
+    client: BatchClient,
+    rpcConfig: ServiceBackendRPCPayload,
+    jobConfig: BatchJobConfig,
+  ): ServiceBackend = {
    val flags = HailFeatureFlags.fromEnv()
    val fs = RouterFS.buildRoutes(CloudStorageFSConfig())
    new ServiceBackend(
-      jarLocation = "us-docker.pkg.dev/hail-vdc/hail/hailgenetics/hail@sha256:fake",
      name = "name",
-      theHailClassLoader = new HailClassLoader(getClass.getClassLoader),
-      references = Map.empty,
      batchClient = client,
-      curBatchId = None,
-      curJobGroupId = None,
+      jarLocation = "us-docker.pkg.dev/hail-vdc/hail/hailgenetics/hail@sha256:fake",
+      theHailClassLoader = new HailClassLoader(getClass.getClassLoader),
+      batchConfig = None,
+      rpcConfig = rpcConfig,
+      jobConfig = jobConfig,
      flags = flags,
-      tmpdir = rpcConfig.tmp_dir,
      fs = fs,
-      serviceBackendContext =
-        new ServiceBackendContext(
-          rpcConfig.billing_project,
-          rpcConfig.remote_tmpdir,
-          rpcConfig.worker_cores,
-          rpcConfig.worker_memory,
-          rpcConfig.storage,
-          rpcConfig.regions,
-          rpcConfig.cloudfuse_configs,
-          profile = false,
-          ExecutionCache.fromFlags(flags, fs, rpcConfig.remote_tmpdir),
-        ),
-      scratchDir = rpcConfig.remote_tmpdir,
+      references = mutable.Map.empty,
    )
  }

-  def withMockDriverContext(test: ServiceBackendRPCPayload => Any): Any =
+  def withMockDriverContext(test: (ServiceBackendRPCPayload, BatchJobConfig) => Any): Any =
    using(LocalTmpFolder) { tmp =>
      withObjectSpied[is.hail.utils.UtilsType] {
        // not obvious how to pull out `tokenUrlSafe` and inject this directory
        // using a spy is a hack and i don't particularly like it.
        when(is.hail.utils.tokenUrlSafe) thenAnswer "TOKEN"

-        test {
+        test(
          ServiceBackendRPCPayload(
            tmp_dir = tmp.path,
            remote_tmpdir = tmp.path,
+            flags = Map(),
+            custom_references = Array(),
+            liftovers = Map(),
+            sequences = Map(),
+          ),
+          BatchJobConfig(
            billing_project = "fancy",
            worker_cores = "128",
            worker_memory = "a lot.",
            storage = "a big ssd?",
            cloudfuse_configs = Array(),
            regions = Array("lunar1"),
-            flags = Map(),
-            custom_references = Array(),
-            liftovers = Map(),
-            sequences = Map(),
-          )
-        }
+          ),
+        )
      }
    }
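The suite now builds the driver payload in two pieces, and `withMockDriverContext` hands both to the test body. A simplified sketch of that loan-fixture shape; the case classes below are stand-ins for `ServiceBackendRPCPayload` and `BatchJobConfig`, with made-up fields:

```scala
object FixtureDemo extends App {
  // Stand-ins for ServiceBackendRPCPayload and BatchJobConfig; the fields are invented.
  final case class RpcConfig(tmpDir: String, remoteTmpdir: String)
  final case class JobConfig(billingProject: String, workerCores: String, regions: Seq[String])

  // Loan-style fixture: build both payloads and hand them to the test body,
  // the way the updated withMockDriverContext does.
  def withMockDriverContext[A](test: (RpcConfig, JobConfig) => A): A =
    test(
      RpcConfig(tmpDir = "/tmp/demo", remoteTmpdir = "/tmp/demo"),
      JobConfig(billingProject = "fancy", workerCores = "128", regions = Seq("lunar1")),
    )

  withMockDriverContext { case (rpc, job) =>
    println(s"billing project ${job.billingProject}, results under ${rpc.remoteTmpdir}")
  }
}
```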
diff --git a/hail/src/test/scala/is/hail/expr/ir/StagedMinHeapSuite.scala b/hail/src/test/scala/is/hail/expr/ir/StagedMinHeapSuite.scala
index 81145a20bcaf..52a901176cd1 100644
--- a/hail/src/test/scala/is/hail/expr/ir/StagedMinHeapSuite.scala
+++ b/hail/src/test/scala/is/hail/expr/ir/StagedMinHeapSuite.scala
@@ -15,6 +15,8 @@ import is.hail.types.physical.stypes.primitives.{SInt32, SInt32Value}
 import is.hail.utils.{using, FastSeq}
 import is.hail.variant.{Locus, ReferenceGenome}

+import scala.collection.mutable
+
 import org.scalatest.matchers.should.Matchers.{be, convertToAnyShouldWrapper}
 import org.testng.annotations.Test

@@ -60,7 +62,7 @@ class StagedMinHeapSuite extends HailSuite {
  @Test def testLocus(): Unit =
    forAll(loci) { case (rg: ReferenceGenome, loci: IndexedSeq[Locus]) =>
-      ctx.local(references = Map(rg.name -> rg)) { ctx =>
+      ctx.local(references = mutable.Map(rg.name -> rg)) { ctx =>
        implicit val coercions: StagedCoercions[Locus] = stagedLocusCoercions(rg)
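`StagedMinHeapSuite` above (and `ReferenceGenomeSuite` below) now passes `scala.collection.mutable.Map` for the reference-genome overrides, copying the existing entries whenever local mutation must not leak back. The two constructions, in plain standard-library terms:

```scala
import scala.collection.mutable

object MutableMapDemo extends App {
  // A fresh override map, as in ctx.local(references = mutable.Map(rg.name -> rg)).
  val overrides: mutable.Map[String, Int] = mutable.Map("GRCh38-like" -> 1)

  // Copying an existing map entry by entry before mutating, as in
  // mutable.Map(ctx.References.mapValues(_.copy()).toSeq: _*).
  val existing: Map[String, Int] = Map("GRCh37-like" -> 0, "GRCh38-like" -> 1)
  val scratch: mutable.Map[String, Int] = mutable.Map(existing.toSeq: _*)
  scratch("GRCh38-like") = 2

  println(existing("GRCh38-like")) // still 1; the copy isolated the mutation
}
```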
diff --git a/hail/src/test/scala/is/hail/variant/ReferenceGenomeSuite.scala b/hail/src/test/scala/is/hail/variant/ReferenceGenomeSuite.scala
index 41ac9aa65e32..f97bce884364 100644
--- a/hail/src/test/scala/is/hail/variant/ReferenceGenomeSuite.scala
+++ b/hail/src/test/scala/is/hail/variant/ReferenceGenomeSuite.scala
@@ -1,7 +1,7 @@
 package is.hail.variant

 import is.hail.{HailSuite, TestUtils}
-import is.hail.backend.{ExecuteContext, HailStateManager}
+import is.hail.backend.HailStateManager
 import is.hail.check.Prop._
 import is.hail.check.Properties
 import is.hail.expr.ir.EmitFunctionBuilder
@@ -9,18 +9,18 @@ import is.hail.io.reference.{FASTAReader, FASTAReaderConfig, LiftOver}
 import is.hail.types.virtual.TLocus
 import is.hail.utils._

+import scala.collection.mutable
+
 import htsjdk.samtools.reference.ReferenceSequenceFileFactory
 import org.testng.annotations.Test

 class ReferenceGenomeSuite extends HailSuite {
-  def hasReference(name: String) = ctx.stateManager.referenceGenomes.contains(name)
-
-  def getReference(name: String) = ctx.getReference(name)
+  def hasReference(name: String) = ctx.References.contains(name)

  @Test def testGRCh37(): Unit = {
    assert(hasReference(ReferenceGenome.GRCh37))
-    val grch37 = getReference(ReferenceGenome.GRCh37)
+    val grch37 = ctx.References(ReferenceGenome.GRCh37)

    assert(grch37.inX("X") && grch37.inY("Y") && grch37.isMitochondrial("MT"))
    assert(grch37.contigLength("1") == 249250621)
@@ -36,7 +36,7 @@ class ReferenceGenomeSuite extends HailSuite {
  @Test def testGRCh38(): Unit = {
    assert(hasReference(ReferenceGenome.GRCh38))
-    val grch38 = getReference(ReferenceGenome.GRCh38)
+    val grch38 = ctx.References(ReferenceGenome.GRCh38)

    assert(grch38.inX("chrX") && grch38.inY("chrY") && grch38.isMitochondrial("chrM"))
    assert(grch38.contigLength("chr1") == 248956422)
@@ -104,12 +104,12 @@ class ReferenceGenomeSuite extends HailSuite {
  @Test def testContigRemap(): Unit = {
    val mapping = Map("23" -> "foo")
    TestUtils.interceptFatal("have remapped contigs in reference genome")(
-      getReference(ReferenceGenome.GRCh37).validateContigRemap(mapping)
+      ctx.References(ReferenceGenome.GRCh37).validateContigRemap(mapping)
    )
  }

  @Test def testComparisonOps(): Unit = {
-    val rg = getReference(ReferenceGenome.GRCh37)
+    val rg = ctx.References(ReferenceGenome.GRCh37)

    // Test contigs
    assert(rg.compare("3", "18") < 0)
@@ -128,7 +128,7 @@ class ReferenceGenomeSuite extends HailSuite {
  @Test def testWriteToFile(): Unit = {
    val tmpFile = ctx.createTmpPath("grWrite", "json")

-    val rg = getReference(ReferenceGenome.GRCh37)
+    val rg = ctx.References(ReferenceGenome.GRCh37)
    rg.copy(name = "GRCh37_2").write(fs, tmpFile)
    val gr2 = ReferenceGenome.fromFile(fs, tmpFile)
@@ -222,23 +222,22 @@ class ReferenceGenomeSuite extends HailSuite {
  }

  @Test def testSerializeOnFB(): Unit = {
-    ExecuteContext.scoped { ctx =>
-      val grch38 = ctx.getReference(ReferenceGenome.GRCh38)
-      val fb = EmitFunctionBuilder[String, Boolean](ctx, "serialize_rg")
-      val rgfield = fb.getReferenceGenome(grch38.name)
-      fb.emit(rgfield.invoke[String, Boolean]("isValidContig", fb.getCodeParam[String](1)))
-
-      val f = fb.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, ctx.r)
+    val grch38 = ctx.References(ReferenceGenome.GRCh38)
+    val fb = EmitFunctionBuilder[String, Boolean](ctx, "serialize_rg")
+    val rgfield = fb.getReferenceGenome(grch38.name)
+    fb.emit(rgfield.invoke[String, Boolean]("isValidContig", fb.getCodeParam[String](1)))
+    ctx.scopedExecution { (cl, fs, tc, r) =>
+      val f = fb.resultWithIndex()(cl, fs, tc, r)
      assert(f("X") == grch38.isValidContig("X"))
    }
  }

-  @Test def testSerializeWithLiftoverOnFB(): Unit = {
-    ExecuteContext.scoped { ctx =>
-      val grch37 = ctx.getReference(ReferenceGenome.GRCh37)
+  @Test def testSerializeWithLiftoverOnFB(): Unit =
+    ctx.local(references = mutable.Map(ctx.References.mapValues(_.copy()).toSeq: _*)) { ctx =>
+      val grch37 = ctx.References(ReferenceGenome.GRCh37)
      val liftoverFile = "src/test/resources/grch37_to_grch38_chr20.over.chain.gz"

-      grch37.addLiftover(ctx.references("GRCh38"), LiftOver(ctx.fs, liftoverFile))
+      grch37.addLiftover(ctx.References("GRCh38"), LiftOver(ctx.fs, liftoverFile))

      val fb =
        EmitFunctionBuilder[String, Locus, Double, (Locus, Boolean)](ctx, "serialize_with_liftover")
@@ -250,13 +249,13 @@ class ReferenceGenomeSuite extends HailSuite {
        fb.getCodeParam[Double](3),
      ))

-      val f = fb.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, ctx.r)
-      assert(f("GRCh38", Locus("20", 60001), 0.95) == grch37.liftoverLocus(
-        "GRCh38",
-        Locus("20", 60001),
-        0.95,
-      ))
-      grch37.removeLiftover("GRCh38")
+      ctx.scopedExecution { (cl, fs, tc, r) =>
+        val f = fb.resultWithIndex()(cl, fs, tc, r)
+        assert(f("GRCh38", Locus("20", 60001), 0.95) == grch37.liftoverLocus(
+          "GRCh38",
+          Locus("20", 60001),
+          0.95,
+        ))
+      }
    }
-  }
 }
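The rewritten `ReferenceGenomeSuite` tests run compiled functions inside `ctx.scopedExecution { (cl, fs, tc, r) => ... }` instead of opening a whole `ExecuteContext`. A self-contained sketch of that loan pattern, where the types are stand-ins for the class loader, filesystem, task context and region that Hail actually lends out:

```scala
object ScopedExecutionDemo extends App {
  final class Region { var open = true; def close(): Unit = open = false }

  // Lend out the resources needed to run compiled code; reclaim them when the body returns.
  def scopedExecution[A](body: (ClassLoader, Region) => A): A = {
    val region = new Region
    try body(getClass.getClassLoader, region)
    finally region.close()
  }

  val msg = scopedExecution { (cl, r) =>
    s"loaded by ${cl.getClass.getSimpleName}, region open = ${r.open}"
  }
  println(msg) // the region is closed once scopedExecution returns
}
```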