From 5300b18b020d7e959299b504c320fe70763b1cf9 Mon Sep 17 00:00:00 2001
From: Edmund Higham
Date: Wed, 18 Sep 2024 16:46:37 -0400
Subject: [PATCH] [query] unify backend rpc

---
 hail/python/hail/backend/backend.py           |  80 ++-
 hail/python/hail/backend/local_backend.py     |   2 +-
 hail/python/hail/backend/py4j_backend.py      |  21 +-
 hail/python/hail/backend/service_backend.py   | 169 ++-----
 hail/python/test/hail/test_ir.py              |  18 +-
 hail/src/main/scala/is/hail/HailContext.scala |   9 +-
 .../main/scala/is/hail/backend/Backend.scala  |  98 +---
 .../scala/is/hail/backend/BackendRpc.scala    | 254 ++++++++++
 .../scala/is/hail/backend/BackendServer.scala | 154 +++---
 .../is/hail/backend/ExecuteContext.scala      |  21 +-
 .../is/hail/backend/local/LocalBackend.scala  |   7 +-
 .../backend/py4j/Py4JBackendExtensions.scala  |  49 +-
 .../hail/backend/service/ServiceBackend.scala | 473 +++++++-----------
 .../is/hail/backend/service/Worker.scala      |  49 +-
 .../is/hail/backend/spark/SparkBackend.scala  |  16 +-
 .../scala/is/hail/expr/ir/GenericLines.scala  |   2 +-
 .../main/scala/is/hail/expr/ir/TableIR.scala  |   4 +-
 .../scala/is/hail/expr/ir/TableValue.scala    |   2 +-
 .../is/hail/expr/ir/functions/Functions.scala |  58 ++-
 .../expr/ir/lowering/RVDToTableStage.scala    |   4 +-
 .../scala/is/hail/io/plink/LoadPlink.scala    |   6 +-
 .../main/scala/is/hail/io/vcf/LoadVCF.scala   |   2 +-
 .../scala/is/hail/linalg/BlockMatrix.scala    |   8 +-
 .../main/scala/is/hail/linalg/RowMatrix.scala |   2 +-
 .../methods/MatrixExportEntriesByCol.scala    |   2 +-
 hail/src/main/scala/is/hail/rvd/RVD.scala     |   2 +-
 .../is/hail/sparkextras/ContextRDD.scala      |   8 +-
 .../is/hail/sparkextras/IndexReadRDD.scala    |   2 +-
 .../scala/is/hail/types/virtual/package.scala |  12 +
 .../hail/utils/SpillingCollectIterator.scala  |   2 +-
 hail/src/test/scala/is/hail/HailSuite.scala   |   4 +-
 hail/src/test/scala/is/hail/TestUtils.scala   |   6 +-
 .../is/hail/backend/ServiceBackendSuite.scala |  72 +--
 .../is/hail/expr/ir/StagedMinHeapSuite.scala  |   4 +-
 .../hail/variant/ReferenceGenomeSuite.scala   |  55 +-
 35 files changed, 822 insertions(+), 855 deletions(-)
 create mode 100644 hail/src/main/scala/is/hail/backend/BackendRpc.scala
 create mode 100644 hail/src/main/scala/is/hail/types/virtual/package.scala

diff --git a/hail/python/hail/backend/backend.py b/hail/python/hail/backend/backend.py
index 995d0b668c4..e7e14bbe90a 100644
--- a/hail/python/hail/backend/backend.py
+++ b/hail/python/hail/backend/backend.py
@@ -3,7 +3,7 @@
 import zipfile
 from dataclasses import dataclass
 from enum import Enum
-from typing import AbstractSet, Any, ClassVar, Dict, List, Mapping, Optional, Tuple, TypeVar, Union
+from typing import AbstractSet, Any, ClassVar, Dict, List, Mapping, Optional, Set, Tuple, TypeVar, Union
 
 import orjson
 
@@ -68,15 +68,45 @@ def local_jar_information() -> LocalJarInformation:
         raise ValueError(f'Hail requires either {hail_jar} or {hail_all_spark_jar}.')
 
 
+class IRFunction:
+    def __init__(
+        self,
+        name: str,
+        type_parameters: Union[Tuple[HailType, ...], List[HailType]],
+        value_parameter_names: Union[Tuple[str, ...], List[str]],
+        value_parameter_types: Union[Tuple[HailType, ...], List[HailType]],
+        return_type: HailType,
+        body: Expression,
+    ):
+        assert len(value_parameter_names) == len(value_parameter_types)
+        render = CSERenderer()
+        self._name = name
+        self._type_parameters = type_parameters
+        self._value_parameter_names = value_parameter_names
+        self._value_parameter_types = value_parameter_types
+        self._return_type = return_type
+        self._rendered_body = render(finalize_randomness(body._ir))
+
+    def to_dataclass(self):
return SerializedIRFunction( + name=self._name, + type_parameters=[tp._parsable_string() for tp in self._type_parameters], + value_parameter_names=list(self._value_parameter_names), + value_parameter_types=[vpt._parsable_string() for vpt in self._value_parameter_types], + return_type=self._return_type._parsable_string(), + rendered_body=self._rendered_body, + ) + + class ActionTag(Enum): - LOAD_REFERENCES_FROM_DATASET = 1 - VALUE_TYPE = 2 - TABLE_TYPE = 3 - MATRIX_TABLE_TYPE = 4 - BLOCK_MATRIX_TYPE = 5 - EXECUTE = 6 - PARSE_VCF_METADATA = 7 - IMPORT_FAM = 8 + VALUE_TYPE = 1 + TABLE_TYPE = 2 + MATRIX_TABLE_TYPE = 3 + BLOCK_MATRIX_TYPE = 4 + EXECUTE = 5 + PARSE_VCF_METADATA = 6 + IMPORT_FAM = 7 + LOAD_REFERENCES_FROM_DATASET = 8 FROM_FASTA_FILE = 9 @@ -90,11 +120,21 @@ class IRTypePayload(ActionPayload): ir: str +@dataclass +class SerializedIRFunction: + name: str + type_parameters: List[str] + value_parameter_names: List[str] + value_parameter_types: List[str] + return_type: str + rendered_body: str + + @dataclass class ExecutePayload(ActionPayload): ir: str + fns: List[SerializedIRFunction] stream_codec: str - timed: bool @dataclass @@ -164,6 +204,8 @@ def _valid_flags(self) -> AbstractSet[str]: def __init__(self): self._persisted_locations = dict() self._references = {} + self.functions: List[IRFunction] = [] + self._registered_ir_function_names: Set[str] = set() @abc.abstractmethod def validate_file(self, uri: str): @@ -171,10 +213,15 @@ def validate_file(self, uri: str): @abc.abstractmethod def stop(self): - pass + self.functions = [] + self._registered_ir_function_names = set() def execute(self, ir: BaseIR, timed: bool = False) -> Any: - payload = ExecutePayload(self._render_ir(ir), '{"name":"StreamBufferSpec"}', timed) + payload = ExecutePayload( + self._render_ir(ir), + fns=[fn.to_dataclass() for fn in self.functions], + stream_codec='{"name":"StreamBufferSpec"}', + ) try: result, timings = self._rpc(ActionTag.EXECUTE, payload) except FatalError as e: @@ -300,7 +347,6 @@ def unpersist(self, dataset: Dataset) -> Dataset: tempfile.__exit__(None, None, None) return unpersisted - @abc.abstractmethod def register_ir_function( self, name: str, @@ -310,11 +356,13 @@ def register_ir_function( return_type: HailType, body: Expression, ): - pass + self._registered_ir_function_names.add(name) + self.functions.append( + IRFunction(name, type_parameters, value_parameter_names, value_parameter_types, return_type, body) + ) - @abc.abstractmethod def _is_registered_ir_function_name(self, name: str) -> bool: - pass + return name in self._registered_ir_function_names @abc.abstractmethod def persist_expression(self, expr: Expression) -> Expression: diff --git a/hail/python/hail/backend/local_backend.py b/hail/python/hail/backend/local_backend.py index 708eacaef2d..7bcb1145259 100644 --- a/hail/python/hail/backend/local_backend.py +++ b/hail/python/hail/backend/local_backend.py @@ -119,7 +119,7 @@ def register_ir_function( ) def stop(self): - super().stop() + super(Py4JBackend, self).stop() self._exit_stack.close() uninstall_exception_handler() diff --git a/hail/python/hail/backend/py4j_backend.py b/hail/python/hail/backend/py4j_backend.py index b9f986e5834..697269b152b 100644 --- a/hail/python/hail/backend/py4j_backend.py +++ b/hail/python/hail/backend/py4j_backend.py @@ -4,7 +4,7 @@ import socketserver import sys from threading import Thread -from typing import Mapping, Optional, Set, Tuple +from typing import Mapping, Optional, Tuple import orjson import py4j @@ -192,8 +192,6 @@ def 
decode_bytearray(encoded): self._backend_server.start() self._requests_session = requests.Session() - self._registered_ir_function_names: Set[str] = set() - # This has to go after creating the SparkSession. Unclear why. # Maybe it does its own patch? install_exception_handler() @@ -239,9 +237,6 @@ def persist_expression(self, expr): t = expr.dtype return construct_expr(JavaIR(t, self._jbackend.pyExecuteLiteral(self._render_ir(expr._ir))), t) - def _is_registered_ir_function_name(self, name: str) -> bool: - return name in self._registered_ir_function_names - def set_flags(self, **flags: Mapping[str, str]): available = self._jbackend.pyAvailableFlags() invalid = [] @@ -276,12 +271,6 @@ def add_liftover(self, name, chain_file, dest_reference_genome): def remove_liftover(self, name, dest_reference_genome): self._jbackend.pyRemoveLiftover(name, dest_reference_genome) - def _parse_value_ir(self, code, ref_map={}): - return self._jbackend.parse_value_ir( - code, - {k: t._parsable_string() for k, t in ref_map.items()}, - ) - def _register_ir_function(self, name, type_parameters, argument_names, argument_types, return_type, code): self._registered_ir_function_names.add(name) self._jbackend.pyRegisterIR( @@ -293,12 +282,6 @@ def _register_ir_function(self, name, type_parameters, argument_names, argument_ code, ) - def _parse_table_ir(self, code): - return self._jbackend.parse_table_ir(code) - - def _parse_matrix_ir(self, code): - return self._jbackend.parse_matrix_ir(code) - def _parse_blockmatrix_ir(self, code): return self._jbackend.parse_blockmatrix_ir(code) @@ -310,5 +293,5 @@ def stop(self): self._jbackend.close() self._jhc.stop() self._jhc = None - self._registered_ir_function_names = set() uninstall_exception_handler() + super().stop() diff --git a/hail/python/hail/backend/service_backend.py b/hail/python/hail/backend/service_backend.py index 5a5a2457173..b56bde18497 100644 --- a/hail/python/hail/backend/service_backend.py +++ b/hail/python/hail/backend/service_backend.py @@ -12,10 +12,6 @@ import hailtop.aiotools.fs as afs from hail.context import TemporaryDirectory, TemporaryFilename, revision, tmp_dir, version from hail.experimental import read_expression, write_expression -from hail.expr.expressions.base_expression import Expression -from hail.expr.types import HailType -from hail.ir import finalize_randomness -from hail.ir.renderer import CSERenderer from hail.utils import FatalError from hailtop import yamlx from hailtop.aiocloud.aiogoogle import GCSRequesterPaysConfiguration, get_gcs_requester_pays_configuration @@ -32,7 +28,7 @@ from ..builtin_references import BUILTIN_REFERENCES from ..utils import ANY_REGION -from .backend import ActionPayload, ActionTag, Backend, ExecutePayload, fatal_error_from_java_error_triplet +from .backend import ActionPayload, ActionTag, Backend, fatal_error_from_java_error_triplet ReferenceGenomeConfig = Dict[str, Any] @@ -66,53 +62,6 @@ async def read_str(strm: afs.ReadableStream) -> str: return b.decode('utf-8') -@dataclass -class SerializedIRFunction: - name: str - type_parameters: List[str] - value_parameter_names: List[str] - value_parameter_types: List[str] - return_type: str - rendered_body: str - - -class IRFunction: - def __init__( - self, - name: str, - type_parameters: Union[Tuple[HailType, ...], List[HailType]], - value_parameter_names: Union[Tuple[str, ...], List[str]], - value_parameter_types: Union[Tuple[HailType, ...], List[HailType]], - return_type: HailType, - body: Expression, - ): - assert len(value_parameter_names) == 
len(value_parameter_types) - render = CSERenderer() - self._name = name - self._type_parameters = type_parameters - self._value_parameter_names = value_parameter_names - self._value_parameter_types = value_parameter_types - self._return_type = return_type - self._rendered_body = render(finalize_randomness(body._ir)) - - def to_dataclass(self): - return SerializedIRFunction( - name=self._name, - type_parameters=[tp._parsable_string() for tp in self._type_parameters], - value_parameter_names=list(self._value_parameter_names), - value_parameter_types=[vpt._parsable_string() for vpt in self._value_parameter_types], - return_type=self._return_type._parsable_string(), - rendered_body=self._rendered_body, - ) - - -@dataclass -class ServiceBackendExecutePayload(ActionPayload): - functions: List[SerializedIRFunction] - idempotency_token: str - payload: ExecutePayload - - @dataclass class CloudfuseConfig: bucket: str @@ -129,17 +78,21 @@ class SequenceConfig: @dataclass class ServiceBackendRPCConfig: tmp_dir: str - remote_tmpdir: str + flags: Dict[str, str] + custom_references: List[str] + liftovers: Dict[str, Dict[str, str]] + sequences: Dict[str, SequenceConfig] + + +@dataclass +class BatchJobConfig: + token: str billing_project: str worker_cores: str worker_memory: str storage: str cloudfuse_configs: List[CloudfuseConfig] regions: List[str] - flags: Dict[str, str] - custom_references: List[str] - liftovers: Dict[str, Dict[str, str]] - sequences: Dict[str, SequenceConfig] class ServiceBackend(Backend): @@ -148,14 +101,14 @@ class ServiceBackend(Backend): DRIVER = "driver" # is.hail.backend.service.ServiceBackendSocketAPI2 protocol - LOAD_REFERENCES_FROM_DATASET = 1 - VALUE_TYPE = 2 - TABLE_TYPE = 3 - MATRIX_TABLE_TYPE = 4 - BLOCK_MATRIX_TYPE = 5 - EXECUTE = 6 - PARSE_VCF_METADATA = 7 - IMPORT_FAM = 8 + VALUE_TYPE = 1 + TABLE_TYPE = 2 + MATRIX_TABLE_TYPE = 3 + BLOCK_MATRIX_TYPE = 4 + EXECUTE = 5 + PARSE_VCF_METADATA = 6 + IMPORT_FAM = 7 + LOAD_REFERENCES_FROM_DATASET = 8 FROM_FASTA_FILE = 9 @staticmethod @@ -288,7 +241,6 @@ def __init__( self.batch_attributes = batch_attributes self.remote_tmpdir = remote_tmpdir self.flags: Dict[str, str] = {} - self.functions: List[IRFunction] = [] self._registered_ir_function_names: Set[str] = set() self.driver_cores = driver_cores self.driver_memory = driver_memory @@ -333,16 +285,16 @@ def logger(self): def stop(self): hail_event_loop().run_until_complete(self._stop()) + super().stop() async def _stop(self): await self._async_exit_stack.aclose() - self.functions = [] - self._registered_ir_function_names = set() async def _run_on_batch( self, name: str, service_backend_config: ServiceBackendRPCConfig, + job_config: BatchJobConfig, action: ActionTag, payload: ActionPayload, *, @@ -356,7 +308,8 @@ async def _run_on_batch( async with await self._async_fs.create(iodir + '/in') as infile: await infile.write( orjson.dumps({ - 'config': service_backend_config, + 'rpc_config': service_backend_config, + 'job_config': job_config, 'action': action.value, 'payload': payload, }) @@ -374,8 +327,8 @@ async def _run_on_batch( elif self.driver_memory is not None: resources['memory'] = str(self.driver_memory) - if service_backend_config.storage != '0Gi': - resources['storage'] = service_backend_config.storage + if job_config.storage != '0Gi': + resources['storage'] = job_config.storage j = self._batch.create_jvm_job( jar_spec=self.jar_spec, @@ -389,7 +342,7 @@ async def _run_on_batch( resources=resources, attributes={'name': name + '_driver'}, regions=self.regions, - 
cloudfuse=[(c.bucket, c.mount_path, c.read_only) for c in service_backend_config.cloudfuse_configs], + cloudfuse=[(c.bucket, c.mount_path, c.read_only) for c in job_config.cloudfuse_configs], profile=self.flags['profile'] is not None, ) await self._batch.submit(disable_progress_bar=True) @@ -466,11 +419,6 @@ def _rpc(self, action: ActionTag, payload: ActionPayload) -> Tuple[bytes, Option return self._cancel_on_ctrl_c(self._async_rpc(action, payload)) async def _async_rpc(self, action: ActionTag, payload: ActionPayload): - if isinstance(payload, ExecutePayload): - payload = ServiceBackendExecutePayload( - [f.to_dataclass() for f in self.functions], self._batch.token, payload - ) - storage_requirement_bytes = 0 readonly_fuse_buckets: Set[str] = set() @@ -485,31 +433,37 @@ async def _async_rpc(self, action: ActionTag, payload: ActionPayload): readonly_fuse_buckets.add(bucket) storage_requirement_bytes += await (await self._async_fs.statfile(blob)).size() sequence_file_mounts[rg_name] = SequenceConfig( - f'/cloudfuse/{fasta_bucket}/{fasta_path}', f'/cloudfuse/{index_bucket}/{index_path}' + f'/cloudfuse/{fasta_bucket}/{fasta_path}', + f'/cloudfuse/{index_bucket}/{index_path}', ) - storage_gib_str = f'{math.ceil(storage_requirement_bytes / 1024 / 1024 / 1024)}Gi' - qob_config = ServiceBackendRPCConfig( - tmp_dir=tmp_dir(), - remote_tmpdir=self.remote_tmpdir, - billing_project=self.billing_project, - worker_cores=str(self.worker_cores), - worker_memory=str(self.worker_memory), - storage=storage_gib_str, - cloudfuse_configs=[ - CloudfuseConfig(bucket, f'/cloudfuse/{bucket}', True) for bucket in readonly_fuse_buckets - ], - regions=self.regions, - flags=self.flags, - custom_references=[ - orjson.dumps(rg._config).decode('utf-8') - for rg in self._references.values() - if rg.name not in BUILTIN_REFERENCES - ], - liftovers={rg.name: rg._liftovers for rg in self._references.values() if len(rg._liftovers) > 0}, - sequences=sequence_file_mounts, + return await self._run_on_batch( + name=f'{action.name.lower()}(...)', + service_backend_config=ServiceBackendRPCConfig( + tmp_dir=self.remote_tmpdir, + flags=self.flags, + custom_references=[ + orjson.dumps(rg._config).decode('utf-8') + for rg in self._references.values() + if rg.name not in BUILTIN_REFERENCES + ], + liftovers={rg.name: rg._liftovers for rg in self._references.values() if len(rg._liftovers) > 0}, + sequences=sequence_file_mounts, + ), + job_config=BatchJobConfig( + token=self._batch.token, + billing_project=self.billing_project, + worker_cores=str(self.worker_cores), + worker_memory=str(self.worker_memory), + storage=f'{math.ceil(storage_requirement_bytes / 1024 / 1024 / 1024)}Gi', + cloudfuse_configs=[ + CloudfuseConfig(bucket, f'/cloudfuse/{bucket}', True) for bucket in readonly_fuse_buckets + ], + regions=self.regions, + ), + action=action, + payload=payload, ) - return await self._run_on_batch(f'{action.name.lower()}(...)', qob_config, action, payload) # Sequence and liftover information is stored on the ReferenceGenome # and there is no persistent backend to keep in sync. 
@@ -532,23 +486,6 @@ def add_liftover(self, name: str, chain_file: str, dest_reference_genome: str): def remove_liftover(self, name, dest_reference_genome): # pylint: disable=unused-argument pass - def register_ir_function( - self, - name: str, - type_parameters: Union[Tuple[HailType, ...], List[HailType]], - value_parameter_names: Union[Tuple[str, ...], List[str]], - value_parameter_types: Union[Tuple[HailType, ...], List[HailType]], - return_type: HailType, - body: Expression, - ): - self._registered_ir_function_names.add(name) - self.functions.append( - IRFunction(name, type_parameters, value_parameter_names, value_parameter_types, return_type, body) - ) - - def _is_registered_ir_function_name(self, name: str) -> bool: - return name in self._registered_ir_function_names - def persist_expression(self, expr): # FIXME: should use context manager to clean up persisted resources fname = TemporaryFilename(prefix='persist_expression').name diff --git a/hail/python/test/hail/test_ir.py b/hail/python/test/hail/test_ir.py index 81f5fb3b855..cdb1c874c36 100644 --- a/hail/python/test/hail/test_ir.py +++ b/hail/python/test/hail/test_ir.py @@ -2,7 +2,7 @@ import random import re import unittest -from test.hail.helpers import resource, skip_unless_spark_backend, skip_when_service_backend +from test.hail.helpers import resource, skip_unless_spark_backend import numpy as np import pytest @@ -189,12 +189,6 @@ def value_ir(value_irs, request): return value_irs[request.param] -@skip_when_service_backend() -def test_ir_parses(value_ir): - env = value_irs_env() - Env.backend()._parse_value_ir(str(value_ir), env) - - def test_ir_value_type(value_ir): env = value_irs_env() typ = Env.backend().value_type( @@ -313,11 +307,6 @@ def table_ir(table_irs, request): return table_irs[request.param] -@skip_when_service_backend() -def test_table_ir_parses(table_ir): - Env.backend()._parse_table_ir(str(table_ir)) - - def test_table_ir_table_type(table_ir): typ = Env.backend().table_type(table_ir) assert table_ir.typ == typ @@ -420,11 +409,6 @@ def matrix_ir(matrix_irs, request): return matrix_irs[request.param] -@skip_when_service_backend() -def test_matrix_ir_parses(matrix_ir): - Env.backend()._parse_matrix_ir(str(matrix_ir)) - - def test_matrix_ir_matrix_type(matrix_ir): typ = Env.backend().matrix_type(matrix_ir) assert typ == matrix_ir.typ diff --git a/hail/src/main/scala/is/hail/HailContext.scala b/hail/src/main/scala/is/hail/HailContext.scala index 3de89ce13cd..f52efe60ee7 100644 --- a/hail/src/main/scala/is/hail/HailContext.scala +++ b/hail/src/main/scala/is/hail/HailContext.scala @@ -19,6 +19,7 @@ import org.apache.spark.executor.InputMetrics import org.apache.spark.rdd.RDD import org.json4s.Extraction import org.json4s.jackson.JsonMethods +import sourcecode.Enclosing case class FilePartition(index: Int, file: String) extends Partition @@ -41,7 +42,7 @@ object HailContext { def backend: Backend = get.backend - def sparkBackend(op: String): SparkBackend = get.sparkBackend(op) + def sparkBackend(implicit E: Enclosing): SparkBackend = get.backend.asSpark def configureLogging(logFile: String, quiet: Boolean, append: Boolean): Unit = { org.apache.log4j.helpers.LogLog.setInternalDebugging(true) @@ -152,7 +153,7 @@ object HailContext { val fsBc = fs.broadcast - new RDD[T](SparkBackend.sparkContext("readPartition"), Nil) { + new RDD[T](SparkBackend.sparkContext, Nil) { def getPartitions: Array[Partition] = Array.tabulate(nPartitions)(i => FilePartition(i, partFiles(i))) @@ -175,8 +176,6 @@ class HailContext private ( ) { 
def stop(): Unit = HailContext.stop() - def sparkBackend(op: String): SparkBackend = backend.asSpark(op) - var checkRVDKeys: Boolean = false def version: String = is.hail.HAIL_PRETTY_VERSION @@ -188,7 +187,7 @@ class HailContext private ( maxLines: Int, ): Map[String, Array[WithContext[String]]] = { val regexp = regex.r - SparkBackend.sparkContext("fileAndLineCounts").textFilesLines(fs.globAll(files).map(_.getPath)) + SparkBackend.sparkContext.textFilesLines(fs.globAll(files).map(_.getPath)) .filter(line => regexp.findFirstIn(line.value).isDefined) .take(maxLines) .groupBy(_.source.file) diff --git a/hail/src/main/scala/is/hail/backend/Backend.scala b/hail/src/main/scala/is/hail/backend/Backend.scala index b593138426d..ce62b34ae97 100644 --- a/hail/src/main/scala/is/hail/backend/Backend.scala +++ b/hail/src/main/scala/is/hail/backend/Backend.scala @@ -1,28 +1,21 @@ package is.hail.backend -import is.hail.asm4s._ -import is.hail.backend.Backend.jsonToBytes +import is.hail.asm4s.HailClassLoader import is.hail.backend.spark.SparkBackend -import is.hail.expr.ir.{IR, IRParser, LoweringAnalyses, SortField, TableIR, TableReader} +import is.hail.expr.ir.{IR, LoweringAnalyses, SortField, TableIR, TableReader} import is.hail.expr.ir.lowering.{TableStage, TableStageDependency} import is.hail.io.{BufferSpec, TypedCodecSpec} -import is.hail.io.fs._ -import is.hail.io.plink.LoadPlink -import is.hail.io.vcf.LoadVCF -import is.hail.types._ +import is.hail.io.fs.FS +import is.hail.types.RTable import is.hail.types.encoded.EType import is.hail.types.physical.PTuple -import is.hail.types.virtual.TFloat64 -import is.hail.utils._ -import is.hail.variant.ReferenceGenome +import is.hail.utils.ExecutionTimer.Timings +import is.hail.utils.fatal import scala.reflect.ClassTag -import java.io._ -import java.nio.charset.StandardCharsets +import java.io.{Closeable, OutputStream} -import org.json4s._ -import org.json4s.jackson.JsonMethods import sourcecode.Enclosing object Backend { @@ -38,23 +31,19 @@ object Backend { ctx: ExecuteContext, t: PTuple, off: Long, - bufferSpecString: String, + bufferSpec: BufferSpec, os: OutputStream, ): Unit = { - val bs = BufferSpec.parseOrDefault(bufferSpecString) assert(t.size == 1) val elementType = t.fields(0).typ val codec = TypedCodecSpec( EType.fromPythonTypeEncoding(elementType.virtualType), elementType.virtualType, - bs, + bufferSpec, ) assert(t.isFieldDefined(off, 0)) codec.encode(ctx, elementType, t.loadField(off, 0), os) } - - def jsonToBytes(f: => JValue): Array[Byte] = - JsonMethods.compact(f).getBytes(StandardCharsets.UTF_8) } abstract class BroadcastValue[T] { def value: T } @@ -82,8 +71,8 @@ abstract class Backend extends Closeable { f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte] ): (Option[Throwable], IndexedSeq[(Array[Byte], Int)]) - def asSpark(op: String): SparkBackend = - fatal(s"${getClass.getSimpleName}: $op requires SparkBackend") + def asSpark(implicit E: Enclosing): SparkBackend = + fatal(s"${getClass.getSimpleName}: ${E.value} requires SparkBackend") def shouldCacheQueryInfo: Boolean = true @@ -118,70 +107,7 @@ abstract class Backend extends Closeable { def tableToTableStage(ctx: ExecuteContext, inputIR: TableIR, analyses: LoweringAnalyses) : TableStage - def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T - - final def valueType(s: String): Array[Byte] = - withExecuteContext { ctx => - jsonToBytes { - IRParser.parse_value_ir(ctx, s).typ.toJSON - } - } - - final def tableType(s: String): Array[Byte] = - 
withExecuteContext { ctx => - jsonToBytes { - IRParser.parse_table_ir(ctx, s).typ.toJSON - } - } - - final def matrixTableType(s: String): Array[Byte] = - withExecuteContext { ctx => - jsonToBytes { - IRParser.parse_matrix_ir(ctx, s).typ.toJSON - } - } - - final def blockMatrixType(s: String): Array[Byte] = - withExecuteContext { ctx => - jsonToBytes { - IRParser.parse_blockmatrix_ir(ctx, s).typ.toJSON - } - } - - def loadReferencesFromDataset(path: String): Array[Byte] - - def fromFASTAFile( - name: String, - fastaFile: String, - indexFile: String, - xContigs: Array[String], - yContigs: Array[String], - mtContigs: Array[String], - parInput: Array[String], - ): Array[Byte] = - withExecuteContext { ctx => - jsonToBytes { - ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile, - xContigs, yContigs, mtContigs, parInput).toJSON - } - } - - def parseVCFMetadata(path: String): Array[Byte] = - withExecuteContext { ctx => - jsonToBytes { - Extraction.decompose { - LoadVCF.parseHeaderMetadata(ctx.fs, Set.empty, TFloat64, path) - }(defaultJSONFormats) - } - } - - def importFam(path: String, isQuantPheno: Boolean, delimiter: String, missingValue: String) - : Array[Byte] = - withExecuteContext { ctx => - LoadPlink.importFamJSON(ctx.fs, path, isQuantPheno, delimiter, missingValue).getBytes( - StandardCharsets.UTF_8 - ) - } + def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): (T, Timings) def execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)] } diff --git a/hail/src/main/scala/is/hail/backend/BackendRpc.scala b/hail/src/main/scala/is/hail/backend/BackendRpc.scala new file mode 100644 index 00000000000..bcd749c624b --- /dev/null +++ b/hail/src/main/scala/is/hail/backend/BackendRpc.scala @@ -0,0 +1,254 @@ +package is.hail.backend + +import is.hail.expr.ir.IRParser +import is.hail.expr.ir.functions.IRFunctionRegistry +import is.hail.expr.ir.functions.IRFunctionRegistry.UserDefinedFnKey +import is.hail.io.BufferSpec +import is.hail.io.plink.LoadPlink +import is.hail.io.vcf.LoadVCF +import is.hail.services.retryTransientErrors +import is.hail.types.virtual.{Kind, TFloat64, VType} +import is.hail.types.virtual.Kinds._ +import is.hail.utils.{using, BoxedArrayBuilder, ExecutionTimer, FastSeq} +import is.hail.utils.ExecutionTimer.Timings +import is.hail.variant.ReferenceGenome + +import scala.util.control.NonFatal + +import java.io.ByteArrayOutputStream +import java.nio.charset.StandardCharsets + +import org.json4s.{DefaultFormats, Extraction, Formats, JArray, JValue} +import org.json4s.jackson.JsonMethods + +case class IRTypePayload(ir: String) +case class LoadReferencesFromDatasetPayload(path: String) + +case class FromFASTAFilePayload( + name: String, + fasta_file: String, + index_file: String, + x_contigs: Array[String], + y_contigs: Array[String], + mt_contigs: Array[String], + par: Array[String], +) + +case class ParseVCFMetadataPayload(path: String) +case class ImportFamPayload(path: String, quant_pheno: Boolean, delimiter: String, missing: String) + +case class ExecutePayload( + ir: String, + fns: Array[SerializedIRFunction], + stream_codec: String, +) + +case class SerializedIRFunction( + name: String, + type_parameters: Array[String], + value_parameter_names: Array[String], + value_parameter_types: Array[String], + return_type: String, + rendered_body: String, +) + +trait BackendRpc { + + sealed trait Command extends Product with Serializable + + object Commands { + case class TypeOf(k: Kind[_ <: VType], ir: String) extends Command + case class 
Execute(ir: String, fs: Array[SerializedIRFunction], codec: String) extends Command + case class ParseVcfMetadata(path: String) extends Command + + case class ImportFam(path: String, isQuantPheno: Boolean, delimiter: String, missing: String) + extends Command + + case class LoadReferencesFromDataset(path: String) extends Command + + case class LoadReferencesFromFASTA( + name: String, + fasta_file: String, + index_file: String, + x_contigs: Array[String], + y_contigs: Array[String], + mt_contigs: Array[String], + par: Array[String], + ) extends Command + } + + trait Ask[Env] { + def command(env: Env): Command + } + + trait Context[Env] { + def scoped[A](env: Env)(f: ExecuteContext => A): (A, Timings) + } + + trait Write[Env] { + def timings(env: Env)(t: ExecutionTimer.Timings): Unit + def result(env: Env)(r: Array[Byte]): Unit + def error(env: Env)(t: Throwable): Unit + } + + implicit val fmts: Formats = DefaultFormats + + import Commands._ + + final def runRpc[A](env: A)(implicit Ask: Ask[A], Context: Context[A], Write: Write[A]): Unit = + try { + val command = Ask.command(env) + val (result, timings) = retryTransientErrors { + Context.scoped(env) { ctx => + command match { + case TypeOf(kind, s) => + jsonToBytes { + (kind match { + case Value => IRParser.parse_value_ir(ctx, s) + case Table => IRParser.parse_table_ir(ctx, s) + case Matrix => IRParser.parse_matrix_ir(ctx, s) + case BlockMatrix => IRParser.parse_blockmatrix_ir(ctx, s) + }).typ.toJSON + } + + case Execute(s, fns, codec) => + val bufferSpec = BufferSpec.parseOrDefault(codec) + withIRFunctionsReadFromInput(ctx, fns) { + val ir = IRParser.parse_value_ir(ctx, s) + val res = ctx.backend.execute(ctx, ir) + res match { + case Left(_) => Array.empty[Byte] + case Right((pt, off)) => + using(new ByteArrayOutputStream()) { os => + Backend.encodeToOutputStream(ctx, pt, off, bufferSpec, os) + os.toByteArray + } + } + } + + case ParseVcfMetadata(path) => + jsonToBytes { + val metadata = LoadVCF.parseHeaderMetadata(ctx.fs, Set.empty, TFloat64, path) + Extraction.decompose(metadata) + } + + case ImportFam(path, isQuantPheno, delimiter, missing) => + jsonToBytes { + LoadPlink.importFamJSON(ctx.fs, path, isQuantPheno, delimiter, missing) + } + + case LoadReferencesFromDataset(path) => + jsonToBytes { + val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path) + ReferenceGenome.addFatalOnCollision(ctx.References, rgs) + JArray(rgs.map(_.toJSON).toList) + } + + case LoadReferencesFromFASTA(name, fasta, index, xContigs, yContigs, mtContigs, par) => + jsonToBytes { + val rg = ReferenceGenome.fromFASTAFile( + ctx, + name, + fasta, + index, + xContigs, + yContigs, + mtContigs, + par, + ) + ReferenceGenome.addFatalOnCollision(ctx.References, FastSeq(rg)) + rg.toJSON + } + } + } + } + + Write.result(env)(result) + Write.timings(env)(timings) + } catch { + case NonFatal(error) => Write.error(env)(error) + } + + def jsonToBytes(v: JValue): Array[Byte] = + JsonMethods.compact(v).getBytes(StandardCharsets.UTF_8) + + private[this] def withIRFunctionsReadFromInput[A]( + ctx: ExecuteContext, + serializedFunctions: Array[SerializedIRFunction], + )( + body: => A + ): A = { + val fns = new BoxedArrayBuilder[UserDefinedFnKey](serializedFunctions.length) + try { + for (func <- serializedFunctions) { + fns += IRFunctionRegistry.registerIR( + ctx, + func.name, + func.type_parameters, + func.value_parameter_names, + func.value_parameter_types, + func.return_type, + func.rendered_body, + ) + } + + body + } finally + for (i <- 0 until fns.length) + 
IRFunctionRegistry.unregisterIr(fns(i)) + } +} + +trait HttpLikeBackendRpc[A] extends BackendRpc { + + import Commands._ + + trait Routing extends Ask[A] { + + sealed trait Route extends Product with Serializable + + object Routes { + case class TypeOf(kind: Kind[_ <: VType]) extends Route + case object Execute extends Route + case object ParseVcfMetadata extends Route + case object ImportFam extends Route + case object LoadReferencesFromDataset extends Route + case object LoadReferencesFromFASTA extends Route + } + + def route(a: A): Route + def payload(a: A): JValue + + final override def command(a: A): Command = + route(a) match { + case Routes.TypeOf(k) => + TypeOf(k, payload(a).extract[IRTypePayload].ir) + case Routes.Execute => + val ExecutePayload(ir, fns, codec) = payload(a).extract[ExecutePayload] + Execute(ir, fns, codec) + case Routes.ParseVcfMetadata => + ParseVcfMetadata(payload(a).extract[ParseVCFMetadataPayload].path) + case Routes.ImportFam => + val config = payload(a).extract[ImportFamPayload] + ImportFam(config.path, config.quant_pheno, config.delimiter, config.missing) + case Routes.LoadReferencesFromDataset => + val path = payload(a).extract[LoadReferencesFromDatasetPayload].path + LoadReferencesFromDataset(path) + case Routes.LoadReferencesFromFASTA => + val config = payload(a).extract[FromFASTAFilePayload] + LoadReferencesFromFASTA( + config.name, + config.fasta_file, + config.index_file, + config.x_contigs, + config.y_contigs, + config.mt_contigs, + config.par, + ) + } + } + + implicit protected def Ask: Routing + implicit protected def Write: Write[A] + implicit protected def Context: Context[A] +} diff --git a/hail/src/main/scala/is/hail/backend/BackendServer.scala b/hail/src/main/scala/is/hail/backend/BackendServer.scala index 24e7f7b9875..db23bcb310f 100644 --- a/hail/src/main/scala/is/hail/backend/BackendServer.scala +++ b/hail/src/main/scala/is/hail/backend/BackendServer.scala @@ -1,41 +1,21 @@ package is.hail.backend -import is.hail.expr.ir.IRParser +import is.hail.types.virtual.Kinds.{BlockMatrix, Matrix, Table, Value} import is.hail.utils._ - -import scala.util.control.NonFatal +import is.hail.utils.ExecutionTimer.Timings import java.io.Closeable import java.net.InetSocketAddress -import java.nio.charset.StandardCharsets import java.util.concurrent._ +import com.google.api.client.http.HttpStatusCodes import com.sun.net.httpserver.{HttpExchange, HttpHandler, HttpServer} import org.json4s._ -import org.json4s.jackson.JsonMethods -import org.json4s.jackson.JsonMethods.compact - -case class IRTypePayload(ir: String) -case class LoadReferencesFromDatasetPayload(path: String) - -case class FromFASTAFilePayload( - name: String, - fasta_file: String, - index_file: String, - x_contigs: Array[String], - y_contigs: Array[String], - mt_contigs: Array[String], - par: Array[String], -) - -case class ParseVCFMetadataPayload(path: String) -case class ImportFamPayload(path: String, quant_pheno: Boolean, delimiter: String, missing: String) -case class ExecutePayload(ir: String, stream_codec: String, timed: Boolean) +import org.json4s.jackson.{JsonMethods, Serialization} class BackendServer(backend: Backend) extends Closeable { // 0 => let the OS pick an available port private[this] val httpServer = HttpServer.create(new InetSocketAddress(0), 10) - private[this] val handler = new BackendHttpHandler(backend) private[this] val thread = { // This HTTP server *must not* start non-daemon threads because such threads keep the JVM @@ -59,7 +39,7 @@ class BackendServer(backend: 
Backend) extends Closeable { /* Source: * https://docs.oracle.com/en/java/javase/11/docs/api/jdk.httpserver/com/sun/net/httpserver/HttpServer.html#setExecutor(java.util.concurrent.Executor) */ // - httpServer.createContext("/", handler) + httpServer.createContext("/", Handler) httpServer.setExecutor(null) val t = Executors.defaultThreadFactory().newThread(new Runnable() { def run(): Unit = @@ -69,85 +49,75 @@ class BackendServer(backend: Backend) extends Closeable { t } - def port = httpServer.getAddress.getPort + def port: Int = httpServer.getAddress.getPort def start(): Unit = thread.start() override def close(): Unit = httpServer.stop(10) -} -class BackendHttpHandler(backend: Backend) extends HttpHandler { - def handle(exchange: HttpExchange): Unit = { - implicit val formats: Formats = DefaultFormats - - try { - val body = using(exchange.getRequestBody)(JsonMethods.parse(_)) - if (exchange.getRequestURI.getPath == "/execute") { - val ExecutePayload(irStr, streamCodec, timed) = body.extract[ExecutePayload] - backend.withExecuteContext { ctx => - val (res, timings) = ExecutionTimer.time { timer => - ctx.local(timer = timer) { ctx => - val irData = IRParser.parse_value_ir(ctx, irStr) - backend.execute(ctx, irData) - } - } - - if (timed) { - exchange.getResponseHeaders.add("X-Hail-Timings", compact(timings.toJSON)) - } - - res match { - case Left(_) => exchange.sendResponseHeaders(200, -1L) - case Right((t, off)) => - exchange.sendResponseHeaders(200, 0L) // 0 => an arbitrarily long response body - using(exchange.getResponseBody) { os => - Backend.encodeToOutputStream(ctx, t, off, streamCodec, os) - } - } + private case class Request(exchange: HttpExchange, payload: JValue) + + private[this] object Handler extends HttpHandler with HttpLikeBackendRpc[Request] { + + override def handle(exchange: HttpExchange): Unit = { + val payload = using(exchange.getRequestBody)(JsonMethods.parse(_)) + runRpc(Request(exchange, payload)) + } + + implicit override protected object Ask extends Routing { + + import Routes._ + + override def route(a: Request): Route = + a.exchange.getRequestURI.getPath match { + case "/value/type" => TypeOf(Value) + case "/table/type" => TypeOf(Table) + case "/matrixtable/type" => TypeOf(Matrix) + case "/blockmatrix/type" => TypeOf(BlockMatrix) + case "/execute" => Execute + case "/vcf/metadata/parse" => ParseVcfMetadata + case "/fam/import" => ImportFam + case "/references/load" => LoadReferencesFromDataset + case "/references/from_fasta" => LoadReferencesFromFASTA } - return - } - val response: Array[Byte] = exchange.getRequestURI.getPath match { - case "/value/type" => backend.valueType(body.extract[IRTypePayload].ir) - case "/table/type" => backend.tableType(body.extract[IRTypePayload].ir) - case "/matrixtable/type" => backend.matrixTableType(body.extract[IRTypePayload].ir) - case "/blockmatrix/type" => backend.blockMatrixType(body.extract[IRTypePayload].ir) - case "/references/load" => - backend.loadReferencesFromDataset(body.extract[LoadReferencesFromDatasetPayload].path) - case "/references/from_fasta" => - val config = body.extract[FromFASTAFilePayload] - backend.fromFASTAFile( - config.name, - config.fasta_file, - config.index_file, - config.x_contigs, - config.y_contigs, - config.mt_contigs, - config.par, - ) - case "/vcf/metadata/parse" => - backend.parseVCFMetadata(body.extract[ParseVCFMetadataPayload].path) - case "/fam/import" => - val config = body.extract[ImportFamPayload] - backend.importFam(config.path, config.quant_pheno, config.delimiter, config.missing) + 
override def payload(a: Request): JValue = a.payload + } + + implicit override protected object Write extends Write[Request] with ErrorHandling { + + override def timings(req: Request)(t: Timings): Unit = { + val ts = Serialization.write(Map("timings" -> t)) + req.exchange.getResponseHeaders.add("X-Hail-Timings", ts) } - exchange.sendResponseHeaders(200, response.length) - using(exchange.getResponseBody())(_.write(response)) - } catch { - case NonFatal(t) => - val (shortMessage, expandedMessage, errorId) = handleForPython(t) - val errorJson = JObject( - "short" -> JString(shortMessage), - "expanded" -> JString(expandedMessage), - "error_id" -> JInt(errorId), + override def result(req: Request)(result: Array[Byte]): Unit = + respond(req)(HttpStatusCodes.STATUS_CODE_OK, result) + + override def error(req: Request)(t: Throwable): Unit = + respond(req)( + HttpStatusCodes.STATUS_CODE_SERVER_ERROR, + jsonToBytes { + val (shortMessage, expandedMessage, errorId) = handleForPython(t) + JObject( + "short" -> JString(shortMessage), + "expanded" -> JString(expandedMessage), + "error_id" -> JInt(errorId), + ) + }, ) - val errorBytes = JsonMethods.compact(errorJson).getBytes(StandardCharsets.UTF_8) - exchange.sendResponseHeaders(500, errorBytes.length) - using(exchange.getResponseBody())(_.write(errorBytes)) + + private[this] def respond(req: Request)(code: Int, payload: Array[Byte]): Unit = { + req.exchange.sendResponseHeaders(code, payload.length) + using(req.exchange.getResponseBody)(_.write(payload)) + } + } + + implicit override protected object Context extends Context[Request] { + override def scoped[A](req: Request)(f: ExecuteContext => A): (A, Timings) = + backend.withExecuteContext(f) } } } diff --git a/hail/src/main/scala/is/hail/backend/ExecuteContext.scala b/hail/src/main/scala/is/hail/backend/ExecuteContext.scala index a572bed0190..383b6c296b3 100644 --- a/hail/src/main/scala/is/hail/backend/ExecuteContext.scala +++ b/hail/src/main/scala/is/hail/backend/ExecuteContext.scala @@ -54,18 +54,15 @@ object NonOwningTempFileManager { } object ExecuteContext { - def scoped[T](f: ExecuteContext => T)(implicit E: Enclosing): T = { - val result = HailContext.sparkBackend("ExecuteContext.scoped").withExecuteContext( + def scoped[T](f: ExecuteContext => T)(implicit E: Enclosing): T = + HailContext.sparkBackend.withExecuteContext( selfContainedExecution = false )(f) - result - } def scoped[T]( tmpdir: String, localTmpdir: String, backend: Backend, - references: Map[String, ReferenceGenome], fs: FS, timer: ExecutionTimer, tempFileManager: TempFileManager, @@ -73,6 +70,7 @@ object ExecuteContext { flags: HailFeatureFlags, backendContext: BackendContext, irMetadata: IrMetadata, + references: mutable.Map[String, ReferenceGenome], blockMatrixCache: mutable.Map[String, BlockMatrix], codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]], irCache: mutable.Map[Int, BaseIR], @@ -85,7 +83,6 @@ object ExecuteContext { tmpdir, localTmpdir, backend, - references, fs, region, timer, @@ -94,6 +91,7 @@ object ExecuteContext { flags, backendContext, irMetadata, + references, blockMatrixCache, codeCache, irCache, @@ -117,7 +115,6 @@ class ExecuteContext( val tmpdir: String, val localTmpdir: String, val backend: Backend, - val references: Map[String, ReferenceGenome], val fs: FS, val r: Region, val timer: ExecutionTimer, @@ -126,6 +123,7 @@ class ExecuteContext( val flags: HailFeatureFlags, val backendContext: BackendContext, val irMetadata: IrMetadata, + val References: mutable.Map[String, ReferenceGenome], val 
BlockMatrixCache: mutable.Map[String, BlockMatrix], val CodeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]], val IrCache: mutable.Map[Int, BaseIR], @@ -142,7 +140,8 @@ class ExecuteContext( ) } - val stateManager = HailStateManager(references) + def stateManager: HailStateManager = + HailStateManager(References.toMap) val tempFileManager: TempFileManager = if (_tempFileManager != null) _tempFileManager else new OwningTempFileManager(fs) @@ -168,7 +167,7 @@ class ExecuteContext( def getFlag(name: String): String = flags.get(name) - def getReference(name: String): ReferenceGenome = references(name) + def getReference(name: String): ReferenceGenome = References(name) def shouldWriteIRFiles(): Boolean = getFlag("write_ir_files") != null @@ -188,7 +187,6 @@ class ExecuteContext( tmpdir: String = this.tmpdir, localTmpdir: String = this.localTmpdir, backend: Backend = this.backend, - references: Map[String, ReferenceGenome] = this.references, fs: FS = this.fs, r: Region = this.r, timer: ExecutionTimer = this.timer, @@ -196,6 +194,7 @@ class ExecuteContext( theHailClassLoader: HailClassLoader = this.theHailClassLoader, flags: HailFeatureFlags = this.flags, backendContext: BackendContext = this.backendContext, + references: mutable.Map[String, ReferenceGenome] = this.References, irMetadata: IrMetadata = this.irMetadata, blockMatrixCache: mutable.Map[String, BlockMatrix] = this.BlockMatrixCache, codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]] = this.CodeCache, @@ -207,7 +206,6 @@ class ExecuteContext( tmpdir, localTmpdir, backend, - references, fs, r, timer, @@ -216,6 +214,7 @@ class ExecuteContext( flags, backendContext, irMetadata, + references, blockMatrixCache, codeCache, irCache, diff --git a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala index e6eb9630e32..cdaa214c7ed 100644 --- a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala +++ b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala @@ -16,6 +16,7 @@ import is.hail.types.physical.PTuple import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType import is.hail.types.virtual.TVoid import is.hail.utils._ +import is.hail.utils.ExecutionTimer.Timings import is.hail.variant.ReferenceGenome import scala.collection.mutable @@ -97,14 +98,13 @@ class LocalBackend( // flags can be set after construction from python def fs: FS = RouterFS.buildRoutes(CloudStorageFSConfig.fromFlagsAndEnv(None, flags)) - override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T = - ExecutionTimer.logTime { timer => + override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): (T, Timings) = + ExecutionTimer.time { timer => val fs = this.fs ExecuteContext.scoped( tmpdir, tmpdir, this, - references.toMap, fs, timer, null, @@ -115,6 +115,7 @@ class LocalBackend( ExecutionCache.fromFlags(flags, fs, tmpdir) }, new IrMetadata(), + references, ImmutableMap.empty, codeCache, persistedIR, diff --git a/hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala b/hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala index 9229a03d536..40ba2491f2d 100644 --- a/hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala +++ b/hail/src/main/scala/is/hail/backend/py4j/Py4JBackendExtensions.scala @@ -4,9 +4,8 @@ import is.hail.HailFeatureFlags import is.hail.backend.{Backend, ExecuteContext, NonOwningTempFileManager, TempFileManager} import is.hail.expr.{JSONAnnotationImpex, 
SparkAnnotationImpex} import is.hail.expr.ir.{ - BaseIR, BindingEnv, BlockMatrixIR, EncodedLiteral, GetFieldByIdx, IR, IRParser, Interpret, - MatrixIR, MatrixNativeReader, MatrixRead, Name, NativeReaderOptions, TableIR, TableLiteral, - TableValue, + BaseIR, BlockMatrixIR, EncodedLiteral, GetFieldByIdx, IRParser, Interpret, MatrixIR, + MatrixNativeReader, MatrixRead, NativeReaderOptions, TableLiteral, TableValue, } import is.hail.expr.ir.IRParser.parseType import is.hail.expr.ir.functions.IRFunctionRegistry @@ -14,21 +13,17 @@ import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver} import is.hail.linalg.RowMatrix import is.hail.types.physical.PStruct import is.hail.types.virtual.{TArray, TInterval} -import is.hail.utils.{defaultJSONFormats, log, toRichIterable, FastSeq, HailException, Interval} +import is.hail.utils.{log, toRichIterable, FastSeq, HailException, Interval} import is.hail.variant.ReferenceGenome import scala.collection.mutable -import scala.jdk.CollectionConverters.{ - asScalaBufferConverter, mapAsScalaMapConverter, seqAsJavaListConverter, -} +import scala.jdk.CollectionConverters.{asScalaBufferConverter, seqAsJavaListConverter} -import java.nio.charset.StandardCharsets import java.util import org.apache.spark.sql.DataFrame import org.json4s -import org.json4s.Formats -import org.json4s.jackson.{JsonMethods, Serialization} +import org.json4s.jackson.JsonMethods import sourcecode.Enclosing trait Py4JBackendExtensions { @@ -140,7 +135,7 @@ trait Py4JBackendExtensions { val field = GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0) addJavaIR(ctx, field) } - } + }._1 def pyFromDF(df: DataFrame, jKey: java.util.List[String]): (Int, String) = { val key = jKey.asScala.toArray.toFastSeq @@ -166,7 +161,7 @@ trait Py4JBackendExtensions { backend.withExecuteContext { ctx => val tir = IRParser.parse_table_ir(ctx, s) Interpret(tir, ctx).toDF() - } + }._1 def pyReadMultipleMatrixTables(jsonQuery: String): util.List[MatrixIR] = backend.withExecuteContext { ctx => @@ -192,7 +187,7 @@ trait Py4JBackendExtensions { } log.info("pyReadMultipleMatrixTables: returning N matrix tables") matrixReaders.asJava - } + }._1 def pyAddReference(jsonConfig: String): Unit = addReference(ReferenceGenome.fromJSON(jsonConfig)) @@ -214,37 +209,11 @@ trait Py4JBackendExtensions { private[this] def removeReference(name: String): Unit = references -= name - def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR = - backend.withExecuteContext { ctx => - IRParser.parse_value_ir( - ctx, - s, - BindingEnv.eval(refMap.asScala.toMap.map { case (n, t) => - Name(n) -> IRParser.parseType(t) - }.toSeq: _*), - ) - } - - def parse_table_ir(s: String): TableIR = - withExecuteContext(selfContainedExecution = false)(ctx => IRParser.parse_table_ir(ctx, s)) - - def parse_matrix_ir(s: String): MatrixIR = - withExecuteContext(selfContainedExecution = false)(ctx => IRParser.parse_matrix_ir(ctx, s)) - def parse_blockmatrix_ir(s: String): BlockMatrixIR = withExecuteContext(selfContainedExecution = false) { ctx => IRParser.parse_blockmatrix_ir(ctx, s) } - def loadReferencesFromDataset(path: String): Array[Byte] = - backend.withExecuteContext { ctx => - val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path) - ReferenceGenome.addFatalOnCollision(references, rgs) - - implicit val formats: Formats = defaultJSONFormats - Serialization.write(rgs.map(_.toJSON).toFastSeq).getBytes(StandardCharsets.UTF_8) - } - def withExecuteContext[T]( selfContainedExecution: Boolean = true )( @@ -255,5 +224,5 @@ 
trait Py4JBackendExtensions { val tempFileManager = longLifeTempFileManager if (selfContainedExecution && tempFileManager != null) f(ctx) else ctx.local(tempFileManager = NonOwningTempFileManager(tempFileManager))(f) - } + }._1 } diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala index ff4d54d4b7a..0cc5ebc3f75 100644 --- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala +++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala @@ -4,13 +4,13 @@ import is.hail.{CancellingExecutorService, HailContext, HailFeatureFlags} import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend._ +import is.hail.backend.service.ServiceBackend.MaxAvailableGcsConnections import is.hail.expr.Validate import is.hail.expr.ir.{ - IR, IRParser, IRSize, LoweringAnalyses, MakeTuple, SortField, TableIR, TableReader, TypeCheck, + IR, IRSize, LoweringAnalyses, MakeTuple, SortField, TableIR, TableReader, TypeCheck, } import is.hail.expr.ir.analyses.SemanticHash import is.hail.expr.ir.compile.Compile -import is.hail.expr.ir.functions.IRFunctionRegistry import is.hail.expr.ir.lowering._ import is.hail.io.fs._ import is.hail.io.reference.{IndexedFastaSequenceFile, LiftOver} @@ -19,8 +19,9 @@ import is.hail.services.JobGroupStates.{Cancelled, Failure, Running, Success} import is.hail.types._ import is.hail.types.physical._ import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType -import is.hail.types.virtual._ +import is.hail.types.virtual.{Kinds, TVoid} import is.hail.utils._ +import is.hail.utils.ExecutionTimer.Timings import is.hail.variant.ReferenceGenome import scala.annotation.switch @@ -32,108 +33,36 @@ import java.nio.charset.StandardCharsets import java.nio.file.Path import java.util.concurrent._ -import org.apache.log4j.Logger import org.json4s.{DefaultFormats, Formats} import org.json4s.JsonAST._ -import org.json4s.jackson.{JsonMethods, Serialization} +import org.json4s.jackson.JsonMethods import sourcecode.Enclosing -class ServiceBackendContext( - val billingProject: String, - val remoteTmpDir: String, - val workerCores: String, - val workerMemory: String, - val storageRequirement: String, - val regions: Array[String], - val cloudfuseConfig: Array[CloudfuseConfig], - val profile: Boolean, - val executionCache: ExecutionCache, -) extends BackendContext with Serializable {} +case class ServiceBackendContext( + remoteTmpDir: String, + jobConfig: BatchJobConfig, + override val executionCache: ExecutionCache, +) extends BackendContext with Serializable object ServiceBackend { - - def apply( - jarLocation: String, - name: String, - theHailClassLoader: HailClassLoader, - batchClient: BatchClient, - batchConfig: BatchConfig, - scratchDir: String = sys.env.getOrElse("HAIL_WORKER_SCRATCH_DIR", ""), - rpcConfig: ServiceBackendRPCPayload, - env: Map[String, String], - ): ServiceBackend = { - - val flags = HailFeatureFlags.fromEnv(rpcConfig.flags) - val shouldProfile = flags.get("profile") != null - val fs = RouterFS.buildRoutes( - CloudStorageFSConfig.fromFlagsAndEnv( - Some(Path.of(scratchDir, "secrets/gsa-key/key.json")), - flags, - env, - ) - ) - - val backendContext = new ServiceBackendContext( - rpcConfig.billing_project, - rpcConfig.remote_tmpdir, - rpcConfig.worker_cores, - rpcConfig.worker_memory, - rpcConfig.storage, - rpcConfig.regions, - rpcConfig.cloudfuse_configs, - shouldProfile, - ExecutionCache.fromFlags(flags, fs, rpcConfig.remote_tmpdir), - ) - - val references 
= mutable.Map.empty[String, ReferenceGenome] - references ++= ReferenceGenome.builtinReferences() - ReferenceGenome.addFatalOnCollision( - references, - rpcConfig.custom_references.map(ReferenceGenome.fromJSON), - ) - - rpcConfig.liftovers.foreach { case (sourceGenome, liftoversForSource) => - liftoversForSource.foreach { case (destGenome, chainFile) => - references(sourceGenome).addLiftover(references(destGenome), LiftOver(fs, chainFile)) - } - } - rpcConfig.sequences.foreach { case (rg, seq) => - references(rg).addSequence(IndexedFastaSequenceFile(fs, seq.fasta, seq.index)) - } - - new ServiceBackend( - JarUrl(jarLocation), - name, - theHailClassLoader, - references.toMap, - batchClient, - batchConfig, - flags, - rpcConfig.tmp_dir, - fs, - backendContext, - scratchDir, - ) - } + val MaxAvailableGcsConnections = 1000 } class ServiceBackend( - val jarSpec: JarSpec, - var name: String, - val theHailClassLoader: HailClassLoader, - val references: Map[String, ReferenceGenome], - val batchClient: BatchClient, + val name: String, + batchClient: BatchClient, + jarSpec: JarSpec, + theHailClassLoader: HailClassLoader, val batchConfig: BatchConfig, - val flags: HailFeatureFlags, - val tmpdir: String, + rpcConfig: ServiceBackendRPCPayload, + jobConfig: BatchJobConfig, + flags: HailFeatureFlags, val fs: FS, - val serviceBackendContext: ServiceBackendContext, - val scratchDir: String, + references: mutable.Map[String, ReferenceGenome], ) extends Backend with Logging { private[this] var stageCount = 0 - private[this] val MAX_AVAILABLE_GCS_CONNECTIONS = 1000 - private[this] val executor = Executors.newFixedThreadPool(MAX_AVAILABLE_GCS_CONNECTIONS) + private[this] val executor = Executors.newFixedThreadPool(MaxAvailableGcsConnections) override def shouldCacheQueryInfo: Boolean = false @@ -161,7 +90,6 @@ class ServiceBackend( } private[this] def submitJobGroupAndWait( - backendContext: ServiceBackendContext, collection: IndexedSeq[Array[Byte]], token: String, root: String, @@ -181,13 +109,13 @@ class ServiceBackend( resources = Some( JobResources( preemptible = true, - cpu = Some(backendContext.workerCores).filter(_ != "None"), - memory = Some(backendContext.workerMemory).filter(_ != "None"), - storage = Some(backendContext.storageRequirement).filter(_ != "0Gi"), + cpu = Some(jobConfig.worker_cores).filter(_ != "None"), + memory = Some(jobConfig.worker_memory).filter(_ != "None"), + storage = Some(jobConfig.storage).filter(_ != "0Gi"), ) ), - regions = Some(backendContext.regions).filter(_.nonEmpty), - cloudfuse = Some(backendContext.cloudfuseConfig).filter(_.nonEmpty), + regions = Some(jobConfig.regions).filter(_.nonEmpty), + cloudfuse = Some(jobConfig.cloudfuse_configs).filter(_.nonEmpty), ) val jobs = @@ -281,7 +209,7 @@ class ServiceBackend( uploadFunction.get() uploadContexts.get() - val jobGroup = submitJobGroupAndWait(backendContext, parts, token, root, stageIdentifier) + val jobGroup = submitJobGroupAndWait(parts, token, root, stageIdentifier) log.info(s"parallelizeAndComputeWithIndex: $token: reading results") val startTime = System.nanoTime() @@ -376,40 +304,43 @@ class ServiceBackend( : TableStage = LowerTableIR.applyTable(inputIR, DArrayLowering.All, ctx, analyses) - override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T = - ExecutionTimer.logTime { timer => + override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): (T, Timings) = + ExecutionTimer.time { timer => ExecuteContext.scoped( - tmpdir, - "file:///tmp", + rpcConfig.tmp_dir, + 
rpcConfig.tmp_dir, this, - references, fs, timer, null, theHailClassLoader, flags, - serviceBackendContext, + ServiceBackendContext( + rpcConfig.tmp_dir, + jobConfig, + ExecutionCache.fromFlags(flags, fs, rpcConfig.tmp_dir), + ), new IrMetadata(), + references, ImmutableMap.empty, mutable.Map.empty, ImmutableMap.empty, )(f) } - - override def loadReferencesFromDataset(path: String): Array[Byte] = - withExecuteContext { ctx => - val rgs = ReferenceGenome.fromHailDataset(ctx.fs, path) - ReferenceGenome.addFatalOnCollision(references, rgs) - implicit val formats: Formats = defaultJSONFormats - Serialization.write(rgs.map(_.toJSON).toFastSeq).getBytes(StandardCharsets.UTF_8) - } } class EndOfInputException extends RuntimeException class HailBatchFailure(message: String) extends RuntimeException(message) -object ServiceBackendAPI { - private[this] val log = Logger.getLogger(getClass.getName()) +case class Request( + backend: ServiceBackend, + fs: FS, + outputUrl: String, + action: Int, + payload: JValue, +) + +object ServiceBackendAPI extends HttpLikeBackendRpc[Request] with Logging { def main(argv: Array[String]): Unit = { assert(argv.length == 7, argv.toFastSeq) @@ -423,39 +354,70 @@ object ServiceBackendAPI { val inputURL = argv(5) val outputURL = argv(6) - implicit val formats: Formats = DefaultFormats + implicit val fmts: Formats = DefaultFormats - val fs = RouterFS.buildRoutes( + val deployConfig = DeployConfig.fromConfigFile("/deploy-config/deploy-config.json") + DeployConfig.set(deployConfig) + sys.env.get("HAIL_SSL_CONFIG_DIR").foreach(tls.setSSLConfigFromDir) + + var fs = RouterFS.buildRoutes( CloudStorageFSConfig.fromFlagsAndEnv( Some(Path.of(scratchDir, "secrets/gsa-key/key.json")), HailFeatureFlags.fromEnv(), ) ) - val deployConfig = DeployConfig.fromConfigFile("/deploy-config/deploy-config.json") - DeployConfig.set(deployConfig) - sys.env.get("HAIL_SSL_CONFIG_DIR").foreach(tls.setSSLConfigFromDir) - val batchClient = BatchClient(deployConfig, Path.of(scratchDir, "secrets/gsa-key/key.json")) - log.info("BatchClient allocated.") + val (rpcConfig, jobConfig, action, payload) = + using(fs.openNoCompression(inputURL)) { is => + val input = JsonMethods.parse(is) + ( + (input \ "rpc_config").extract[ServiceBackendRPCPayload], + (input \ "job_config").extract[BatchJobConfig], + (input \ "action").extract[Int], + input \ "payload", + ) + } + + // requester pays config is conveyed in feature flags currently + val featureFlags = HailFeatureFlags.fromEnv(rpcConfig.flags) + fs = RouterFS.buildRoutes( + CloudStorageFSConfig.fromFlagsAndEnv( + Some(Path.of(scratchDir, "secrets/gsa-key/key.json")), + featureFlags, + ) + ) + + val references = mutable.Map[String, ReferenceGenome]() + references ++= ReferenceGenome.builtinReferences() + ReferenceGenome.addFatalOnCollision( + references, + rpcConfig.custom_references.map(ReferenceGenome.fromJSON), + ) - val batchConfig = - BatchConfig.fromConfigFile(Path.of(scratchDir, "batch-config/batch-config.json")) - log.info("BatchConfig parsed.") + rpcConfig.liftovers.foreach { case (sourceGenome, liftoversForSource) => + liftoversForSource.foreach { case (destGenome, chainFile) => + references(sourceGenome).addLiftover(references(destGenome), LiftOver(fs, chainFile)) + } + } - val input = using(fs.openNoCompression(inputURL))(JsonMethods.parse(_)) - val rpcConfig = (input \ "config").extract[ServiceBackendRPCPayload] + rpcConfig.sequences.foreach { case (rg, seq) => + references(rg).addSequence(IndexedFastaSequenceFile(fs, seq.fasta, seq.index)) + } // 
FIXME: when can the classloader be shared? (optimizer benefits!) - val backend = ServiceBackend( - jarLocation, + val backend = new ServiceBackend( name, - new HailClassLoader(getClass().getClassLoader()), - batchClient, - batchConfig, - scratchDir, + BatchClient(deployConfig, Path.of(scratchDir, "secrets/gsa-key/key.json")), + JarUrl(jarLocation), + new HailClassLoader(getClass.getClassLoader), + BatchConfig.fromConfigFile(Path.of(scratchDir, "batch-config/batch-config.json")), rpcConfig, - sys.env, + jobConfig, + featureFlags, + fs, + references, ) + log.info("ServiceBackend allocated.") if (HailContext.isInitialized) { HailContext.get.backend = backend @@ -465,9 +427,76 @@ object ServiceBackendAPI { log.info("HailContexet initialized.") } - val action = (input \ "action").extract[Int] - val payload = (input \ "payload") - new ServiceBackendAPI(backend, fs, outputURL).executeOneCommand(action, payload) + runRpc(Request(backend, fs, outputURL, action, payload)) + } + + implicit override protected object Ask extends Routing { + import Routes._ + + override def route(a: Request): Route = + (a.action: @switch) match { + case 1 => TypeOf(Kinds.Value) + case 2 => TypeOf(Kinds.Table) + case 3 => TypeOf(Kinds.Matrix) + case 4 => TypeOf(Kinds.BlockMatrix) + case 5 => Execute + case 6 => ParseVcfMetadata + case 7 => ImportFam + case 8 => LoadReferencesFromDataset + case 9 => LoadReferencesFromFASTA + } + + override def payload(a: Request): JValue = a.payload + } + + implicit override protected object Write extends Write[Request] { + + // service backend doesn't support sending timings back to the python client + override def timings(env: Request)(t: Timings): Unit = + () + + override def result(env: Request)(result: Array[Byte]): Unit = + retryTransientErrors { + using(env.fs.createNoCompression(env.outputUrl)) { outputStream => + val output = new HailSocketAPIOutputStream(outputStream) + output.writeBool(true) + output.writeBytes(result) + } + } + + override def error(env: Request)(t: Throwable): Unit = + retryTransientErrors { + val (shortMessage, expandedMessage, errorId) = + t match { + case t: HailWorkerException => + log.error( + "A worker failed. The exception was written for Python but we will also throw an exception to fail this driver job.", + t, + ) + (t.shortMessage, t.expandedMessage, t.errorId) + case _ => + log.error( + "An exception occurred in the driver. 
The exception was written for Python but we will re-throw to fail this driver job.", + t, + ) + handleForPython(t) + } + + using(env.fs.createNoCompression(env.outputUrl)) { outputStream => + val output = new HailSocketAPIOutputStream(outputStream) + output.writeBool(false) + output.writeString(shortMessage) + output.writeString(expandedMessage) + output.writeInt(errorId) + } + + throw t + } + } + + implicit override protected object Context extends Context[Request] { + override def scoped[A](env: Request)(f: ExecuteContext => A): (A, Timings) = + env.backend.withExecuteContext(f) } } @@ -508,174 +537,18 @@ case class SequenceConfig(fasta: String, index: String) case class ServiceBackendRPCPayload( tmp_dir: String, - remote_tmpdir: String, - billing_project: String, - worker_cores: String, - worker_memory: String, - storage: String, - cloudfuse_configs: Array[CloudfuseConfig], - regions: Array[String], flags: Map[String, String], custom_references: Array[String], liftovers: Map[String, Map[String, String]], sequences: Map[String, SequenceConfig], ) -case class ServiceBackendExecutePayload( - functions: Array[SerializedIRFunction], - idempotency_token: String, - payload: ExecutePayload, -) - -case class SerializedIRFunction( - name: String, - type_parameters: Array[String], - value_parameter_names: Array[String], - value_parameter_types: Array[String], - return_type: String, - rendered_body: String, +case class BatchJobConfig( + token: String, + billing_project: String, + worker_cores: String, + worker_memory: String, + storage: String, + cloudfuse_configs: Array[CloudfuseConfig], + regions: Array[String], ) - -class ServiceBackendAPI( - private[this] val backend: ServiceBackend, - private[this] val fs: FS, - private[this] val outputURL: String, -) extends Thread { - private[this] val LOAD_REFERENCES_FROM_DATASET = 1 - private[this] val VALUE_TYPE = 2 - private[this] val TABLE_TYPE = 3 - private[this] val MATRIX_TABLE_TYPE = 4 - private[this] val BLOCK_MATRIX_TYPE = 5 - private[this] val EXECUTE = 6 - private[this] val PARSE_VCF_METADATA = 7 - private[this] val IMPORT_FAM = 8 - private[this] val FROM_FASTA_FILE = 9 - - private[this] val log = Logger.getLogger(getClass.getName()) - - private[this] def doAction(action: Int, payload: JValue): Array[Byte] = retryTransientErrors { - implicit val formats: Formats = DefaultFormats - (action: @switch) match { - case LOAD_REFERENCES_FROM_DATASET => - val path = payload.extract[LoadReferencesFromDatasetPayload].path - backend.loadReferencesFromDataset(path) - case VALUE_TYPE => - val ir = payload.extract[IRTypePayload].ir - backend.valueType(ir) - case TABLE_TYPE => - val ir = payload.extract[IRTypePayload].ir - backend.tableType(ir) - case MATRIX_TABLE_TYPE => - val ir = payload.extract[IRTypePayload].ir - backend.matrixTableType(ir) - case BLOCK_MATRIX_TYPE => - val ir = payload.extract[IRTypePayload].ir - backend.blockMatrixType(ir) - case EXECUTE => - val qobExecutePayload = payload.extract[ServiceBackendExecutePayload] - val bufferSpecString = qobExecutePayload.payload.stream_codec - val code = qobExecutePayload.payload.ir - backend.withExecuteContext { ctx => - withIRFunctionsReadFromInput(qobExecutePayload.functions, ctx) { () => - val ir = IRParser.parse_value_ir(ctx, code) - backend.execute(ctx, ir) match { - case Left(()) => - Array() - case Right((pt, off)) => - using(new ByteArrayOutputStream()) { os => - Backend.encodeToOutputStream(ctx, pt, off, bufferSpecString, os) - os.toByteArray - } - } - } - } - case PARSE_VCF_METADATA => - val 
path = payload.extract[ParseVCFMetadataPayload].path - backend.parseVCFMetadata(path) - case IMPORT_FAM => - val famPayload = payload.extract[ImportFamPayload] - val path = famPayload.path - val quantPheno = famPayload.quant_pheno - val delimiter = famPayload.delimiter - val missing = famPayload.missing - backend.importFam(path, quantPheno, delimiter, missing) - case FROM_FASTA_FILE => - val fastaPayload = payload.extract[FromFASTAFilePayload] - backend.fromFASTAFile( - fastaPayload.name, - fastaPayload.fasta_file, - fastaPayload.index_file, - fastaPayload.x_contigs, - fastaPayload.y_contigs, - fastaPayload.mt_contigs, - fastaPayload.par, - ) - } - } - - private[this] def withIRFunctionsReadFromInput( - serializedFunctions: Array[SerializedIRFunction], - ctx: ExecuteContext, - )( - body: () => Array[Byte] - ): Array[Byte] = { - try { - serializedFunctions.foreach { func => - IRFunctionRegistry.registerIR( - ctx, - func.name, - func.type_parameters, - func.value_parameter_names, - func.value_parameter_types, - func.return_type, - func.rendered_body, - ) - } - body() - } finally - IRFunctionRegistry.clearUserFunctions() - } - - def executeOneCommand(action: Int, payload: JValue): Unit = { - try { - val result = doAction(action, payload) - retryTransientErrors { - using(fs.createNoCompression(outputURL)) { outputStream => - val output = new HailSocketAPIOutputStream(outputStream) - output.writeBool(true) - output.writeBytes(result) - } - } - } catch { - case exc: HailWorkerException => - retryTransientErrors { - using(fs.createNoCompression(outputURL)) { outputStream => - val output = new HailSocketAPIOutputStream(outputStream) - output.writeBool(false) - output.writeString(exc.shortMessage) - output.writeString(exc.expandedMessage) - output.writeInt(exc.errorId) - } - } - log.error( - "A worker failed. The exception was written for Python but we will also throw an exception to fail this driver job." - ) - throw exc - case t: Throwable => - val (shortMessage, expandedMessage, errorId) = handleForPython(t) - retryTransientErrors { - using(fs.createNoCompression(outputURL)) { outputStream => - val output = new HailSocketAPIOutputStream(outputStream) - output.writeBool(false) - output.writeString(shortMessage) - output.writeString(expandedMessage) - output.writeInt(errorId) - } - } - log.error( - "An exception occurred in the driver. The exception was written for Python but we will re-throw to fail this driver job." 
- ) - throw t - } - } -} diff --git a/hail/src/main/scala/is/hail/backend/service/Worker.scala b/hail/src/main/scala/is/hail/backend/service/Worker.scala index 4f596de3b4a..9c4a459591a 100644 --- a/hail/src/main/scala/is/hail/backend/service/Worker.scala +++ b/hail/src/main/scala/is/hail/backend/service/Worker.scala @@ -166,38 +166,23 @@ object Worker { timer.end("readInputs") timer.start("executeFunction") - if (HailContext.isInitialized) { - HailContext.get.backend = new ServiceBackend( - null, - null, - new HailClassLoader(getClass().getClassLoader()), - null, - null, - null, - null, - null, - null, - null, - scratchDir, - ) - } else { - HailContext( - // FIXME: workers should not have backends, but some things do need hail contexts - new ServiceBackend( - null, - null, - new HailClassLoader(getClass().getClassLoader()), - null, - null, - null, - null, - null, - null, - null, - scratchDir, - ) - ) - } + + // FIXME: workers should not have backends, but some things do need hail contexts + val backend = new ServiceBackend( + null, + null, + null, + new HailClassLoader(getClass().getClassLoader()), + null, + null, + null, + null, + null, + null, + ) + + if (HailContext.isInitialized) HailContext.get.backend = backend + else HailContext(backend) val result = using(new ServiceTaskContext(i)) { htc => try diff --git a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala index a50d0e90009..fe1b73aad7a 100644 --- a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala +++ b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala @@ -19,6 +19,7 @@ import is.hail.types.physical.{PStruct, PTuple} import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType import is.hail.types.virtual._ import is.hail.utils._ +import is.hail.utils.ExecutionTimer.Timings import is.hail.variant.ReferenceGenome import scala.collection.mutable @@ -89,7 +90,7 @@ object SparkBackend { private var theSparkBackend: SparkBackend = _ - def sparkContext(op: String): SparkContext = HailContext.sparkBackend(op).sc + def sparkContext(implicit E: Enclosing): SparkContext = HailContext.sparkBackend.sc def checkSparkCompatibility(jarVersion: String, sparkVersion: String): Unit = { def majorMinor(version: String): String = version.split("\\.", 3).take(2).mkString(".") @@ -359,17 +360,15 @@ class SparkBackend( def createExecuteContextForTests( timer: ExecutionTimer, region: Region, - selfContainedExecution: Boolean = true, ): ExecuteContext = new ExecuteContext( tmpdir, localTmpdir, this, - references.toMap, fs, region, timer, - if (selfContainedExecution) null else NonOwningTempFileManager(longLifeTempFileManager), + null, theHailClassLoader, flags, new BackendContext { @@ -377,18 +376,18 @@ class SparkBackend( ExecutionCache.forTesting }, new IrMetadata(), + references, ImmutableMap.empty, mutable.Map.empty, ImmutableMap.empty, ) - override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T = - ExecutionTimer.logTime { timer => + override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): (T, Timings) = + ExecutionTimer.time { timer => ExecuteContext.scoped( tmpdir, localTmpdir, this, - references.toMap, fs, timer, null, @@ -399,6 +398,7 @@ class SparkBackend( ExecutionCache.fromFlags(flags, fs, tmpdir) }, new IrMetadata(), + references, bmCache, codeCache, persistedIr, @@ -463,7 +463,7 @@ class SparkBackend( def defaultParallelism: Int = sc.defaultParallelism - override def asSpark(op: String): 
SparkBackend = this + override def asSpark(implicit E: Enclosing): SparkBackend = this def close(): Unit = { bmCache.close() diff --git a/hail/src/main/scala/is/hail/expr/ir/GenericLines.scala b/hail/src/main/scala/is/hail/expr/ir/GenericLines.scala index 8adb4fe75eb..000348f272f 100644 --- a/hail/src/main/scala/is/hail/expr/ir/GenericLines.scala +++ b/hail/src/main/scala/is/hail/expr/ir/GenericLines.scala @@ -449,7 +449,7 @@ class GenericLinesRDDPartition(val index: Int, val context: Any) extends Partiti class GenericLinesRDD( @(transient @param) contexts: IndexedSeq[Any], body: (Any) => CloseableIterator[GenericLine], -) extends RDD[GenericLine](SparkBackend.sparkContext("GenericLinesRDD"), Seq()) { +) extends RDD[GenericLine](SparkBackend.sparkContext, Seq()) { protected def getPartitions: Array[Partition] = contexts.iterator.zipWithIndex.map { case (c, i) => diff --git a/hail/src/main/scala/is/hail/expr/ir/TableIR.scala b/hail/src/main/scala/is/hail/expr/ir/TableIR.scala index 4c528d79a98..f50b07d597a 100644 --- a/hail/src/main/scala/is/hail/expr/ir/TableIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/TableIR.scala @@ -728,7 +728,7 @@ case class PartitionRVDReader(rvd: RVD, uidFieldName: String) extends PartitionR val upcastF = mb.genFieldThisRef[AsmFunction2RegionLongLong]("rvdreader_upcast") val broadcastRVD = - mb.getObject[BroadcastRVD](new BroadcastRVD(ctx.backend.asSpark("RVDReader"), rvd)) + mb.getObject[BroadcastRVD](new BroadcastRVD(ctx.backend.asSpark, rvd)) val producer = new StreamProducer { override def method: EmitMethodBuilder[_] = mb @@ -3213,7 +3213,7 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR { fileStack += filesToMerge filesToMerge = - ContextRDD.weaken(SparkBackend.sparkContext("TableMapRows.execute").parallelize( + ContextRDD.weaken(SparkBackend.sparkContext.parallelize( 0 until nToMerge, nToMerge, )) diff --git a/hail/src/main/scala/is/hail/expr/ir/TableValue.scala b/hail/src/main/scala/is/hail/expr/ir/TableValue.scala index 8b5441db48a..a1dd38c72d1 100644 --- a/hail/src/main/scala/is/hail/expr/ir/TableValue.scala +++ b/hail/src/main/scala/is/hail/expr/ir/TableValue.scala @@ -188,7 +188,7 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow } def toDF(): DataFrame = - HailContext.sparkBackend("toDF").sparkSession.createDataFrame( + HailContext.sparkBackend.sparkSession.createDataFrame( rvd.toRows, typ.rowType.schema.asInstanceOf[StructType], ) diff --git a/hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala b/hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala index be5f0e3e3b2..b81fe405175 100644 --- a/hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala +++ b/hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala @@ -21,7 +21,9 @@ import scala.reflect._ import org.apache.spark.sql.Row object IRFunctionRegistry { - private val userAddedFunctions: mutable.Set[(String, (Type, Seq[Type], Seq[Type]))] = + type UserDefinedFnKey = (String, (Type, Seq[Type], Seq[Type])) + + private[this] val userAddedFunctions: mutable.Set[UserDefinedFnKey] = mutable.HashSet.empty def clearUserFunctions(): Unit = { @@ -69,25 +71,41 @@ object IRFunctionRegistry { typeParamStrs: Array[String], argNameStrs: Array[String], argTypeStrs: Array[String], - returnType: String, + returnTypeStr: String, bodyStr: String, - ): Unit = { + ): UserDefinedFnKey = { requireJavaIdentifier(name) - val argNames = argNameStrs.map(Name) val typeParameters = typeParamStrs.map(IRParser.parseType).toFastSeq val 
valueParameterTypes = argTypeStrs.map(IRParser.parseType).toFastSeq
-    val refMap = BindingEnv.eval(argNames.zip(valueParameterTypes): _*)
-    val body = IRParser.parse_value_ir(ctx, bodyStr, refMap)
+    val argNames = argNameStrs.map(Name)
+
+    val body =
+      IRParser.parse_value_ir(ctx, bodyStr, BindingEnv.eval(argNames.zip(valueParameterTypes): _*))
+    val returnType = IRParser.parseType(returnTypeStr)
+    assert(body.typ == returnType)
 
-    userAddedFunctions += ((name, (body.typ, typeParameters, valueParameterTypes)))
+    val key: UserDefinedFnKey = (name, (returnType, typeParameters, valueParameterTypes))
+    userAddedFunctions += key
     addIR(
       name,
       typeParameters,
       valueParameterTypes,
-      IRParser.parseType(returnType),
+      returnType,
       false,
       (_, args, _) => Subst(body, BindingEnv.eval(argNames.zip(args): _*)),
     )
+    key
+  }
+
+  def unregisterIr(key: UserDefinedFnKey): Unit = {
+    val (name, (returnType, typeParameterTypes, valueParameterTypes)) = key
+    if (userAddedFunctions.remove(key))
+      removeIRFunction(name, returnType, typeParameterTypes, valueParameterTypes)
+    else {
+      throw new NoSuchElementException(
+        s"No user defined function registered matching: ${prettyFunctionSignature(name, returnType, typeParameterTypes, valueParameterTypes)}"
+      )
+    }
   }
 
   def removeIRFunction(
@@ -112,7 +130,9 @@ object IRFunctionRegistry {
       case Seq() => None
       case Seq(f) => Some(f)
       case _ =>
-        fatal(s"Multiple functions found that satisfy $name(${valueParameterTypes.mkString(",")}).")
+        fatal(
+          s"Multiple functions found that satisfy ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)}."
+        )
     }
 
   def lookupFunctionOrFail(
@@ -124,28 +144,34 @@ object IRFunctionRegistry {
     jvmRegistry.lift(name) match {
       case None =>
         fatal(
-          s"no functions found with the signature $name(${valueParameterTypes.mkString(", ")}): $returnType"
+          s"no functions found with the signature ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)}."
         )
       case Some(functions) =>
         functions.filter(t =>
           t.unify(typeParameters, valueParameterTypes, returnType)
         ).toSeq match {
           case Seq() =>
-            val prettyFunctionSignature =
-              s"$name[${typeParameters.mkString(", ")}](${valueParameterTypes.mkString(", ")}): $returnType"
             val prettyMismatchedFunctionSignatures =
               functions.map(x => s"  $x").mkString("\n")
             fatal(
-              s"No function found with the signature $prettyFunctionSignature.\n" +
+              s"No function found with the signature ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)}.\n" +
                 s"However, there are other functions with that name:\n$prettyMismatchedFunctionSignatures"
             )
           case Seq(f) => f
           case _ =>
             fatal(
-              s"Multiple functions found that satisfy $name(${valueParameterTypes.mkString(", ")})."
+              s"Multiple functions found that satisfy ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)}."
             )
         }
     }
   }
 
+  private[this] def prettyFunctionSignature(
+      name: String,
+      returnType: Type,
+      typeParameterTypes: Seq[Type],
+      valueParameterTypes: Seq[Type],
+  ): String =
+    s"$name[${typeParameterTypes.mkString(", ")}](${valueParameterTypes.mkString(", ")}): $returnType"
+
   def lookupIR(
     name: String,
     returnType: Type,
@@ -165,7 +191,9 @@
     case Seq() => None
     case Seq(kv) => Some(kv)
     case _ =>
-      fatal(s"Multiple functions found that satisfy $name(${valueParameterTypes.mkString(",")}).")
+      fatal(
+        s"Multiple functions found that satisfy ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)}."
+ ) } } diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/RVDToTableStage.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/RVDToTableStage.scala index e0af240efc5..3230023b7da 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/RVDToTableStage.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/RVDToTableStage.scala @@ -123,9 +123,7 @@ object TableStageToRVD { partition = { ctx: Ref => _ts.partition(GetField(ctx, "context")) }, ) - val sparkContext = ctx.backend - .asSpark("TableStageToRVD") - .sc + val sparkContext = ctx.backend.asSpark.sc val globalsAndBroadcastVals = Let( diff --git a/hail/src/main/scala/is/hail/io/plink/LoadPlink.scala b/hail/src/main/scala/is/hail/io/plink/LoadPlink.scala index 9c4cd0d5f43..a4cf8a5fa43 100644 --- a/hail/src/main/scala/is/hail/io/plink/LoadPlink.scala +++ b/hail/src/main/scala/is/hail/io/plink/LoadPlink.scala @@ -19,7 +19,6 @@ import is.hail.variant._ import org.apache.spark.TaskContext import org.apache.spark.sql.Row import org.json4s.{DefaultFormats, Formats, JValue} -import org.json4s.jackson.JsonMethods case class FamFileConfig( isQuantPheno: Boolean = false, @@ -82,14 +81,13 @@ object LoadPlink { isQuantPheno: Boolean, delimiter: String, missingValue: String, - ): String = { + ): JValue = { val ffConfig = FamFileConfig(isQuantPheno, delimiter, missingValue) val (data, ptyp) = LoadPlink.parseFam(fs, path, ffConfig) - val jv = JSONAnnotationImpex.exportAnnotation( + JSONAnnotationImpex.exportAnnotation( Row(ptyp.virtualType.toString, data), TStruct("type" -> TString, "data" -> TArray(ptyp.virtualType)), ) - JsonMethods.compact(jv) } def parseFam(fs: FS, filename: String, ffConfig: FamFileConfig) diff --git a/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala b/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala index fcf137cd96e..2be62c0d407 100644 --- a/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala +++ b/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala @@ -1685,7 +1685,7 @@ class PartitionedVCFRDD( file: String, @(transient @param) reverseContigMapping: Map[String, String], @(transient @param) _partitions: Array[Partition], -) extends RDD[WithContext[String]](SparkBackend.sparkContext("PartitionedVCFRDD"), Seq()) { +) extends RDD[WithContext[String]](SparkBackend.sparkContext, Seq()) { val contigRemappingBc = if (reverseContigMapping.size != 0) sparkContext.broadcast(reverseContigMapping) else null diff --git a/hail/src/main/scala/is/hail/linalg/BlockMatrix.scala b/hail/src/main/scala/is/hail/linalg/BlockMatrix.scala index 7cda148db27..f5a6eb1e1b4 100644 --- a/hail/src/main/scala/is/hail/linalg/BlockMatrix.scala +++ b/hail/src/main/scala/is/hail/linalg/BlockMatrix.scala @@ -44,7 +44,7 @@ case class CollectMatricesRDDPartition( } class CollectMatricesRDD(@transient var bms: IndexedSeq[BlockMatrix]) - extends RDD[BDM[Double]](SparkBackend.sparkContext("CollectMatricesRDD"), Nil) { + extends RDD[BDM[Double]](SparkBackend.sparkContext, Nil) { private val nBlocks = bms.map(_.blocks.getNumPartitions) private val firstPartition = nBlocks.scan(0)(_ + _).init @@ -114,7 +114,7 @@ object BlockMatrix { def apply(gp: GridPartitioner, piBlock: (GridPartitioner, Int) => ((Int, Int), BDM[Double])) : BlockMatrix = new BlockMatrix( - new RDD[((Int, Int), BDM[Double])](SparkBackend.sparkContext("BlockMatrix.apply"), Nil) { + new RDD[((Int, Int), BDM[Double])](SparkBackend.sparkContext, Nil) { override val partitioner = Some(gp) protected def getPartitions: Array[Partition] = Array.tabulate(gp.numPartitions)(pi => @@ -2199,7 +2199,7 @@ class 
WriteBlocksRDD( parentPartStarts: Array[Long], entryField: String, gp: GridPartitioner, -) extends RDD[(Int, String)](SparkBackend.sparkContext("WriteBlocksRDD"), Nil) { +) extends RDD[(Int, String)](SparkBackend.sparkContext, Nil) { require(gp.nRows == parentPartStarts.last) @@ -2369,7 +2369,7 @@ class BlockMatrixReadRowBlockedRDD( metadata: BlockMatrixMetadata, maybeMaximumCacheMemoryInBytes: Option[Int], ) extends RDD[RVDContext => Iterator[Long]]( - SparkBackend.sparkContext("BlockMatrixReadRowBlockedRDD"), + SparkBackend.sparkContext, Nil, ) { import BlockMatrixReadRowBlockedRDD._ diff --git a/hail/src/main/scala/is/hail/linalg/RowMatrix.scala b/hail/src/main/scala/is/hail/linalg/RowMatrix.scala index 6adbaca171a..969ef2d4916 100644 --- a/hail/src/main/scala/is/hail/linalg/RowMatrix.scala +++ b/hail/src/main/scala/is/hail/linalg/RowMatrix.scala @@ -286,7 +286,7 @@ class ReadBlocksAsRowsRDD( partFiles: IndexedSeq[String], partitionCounts: Array[Long], gp: GridPartitioner, -) extends RDD[(Long, Array[Double])](SparkBackend.sparkContext("ReadBlocksAsRowsRDD"), Nil) { +) extends RDD[(Long, Array[Double])](SparkBackend.sparkContext, Nil) { private val partitionStarts = partitionCounts.scanLeft(0L)(_ + _) diff --git a/hail/src/main/scala/is/hail/methods/MatrixExportEntriesByCol.scala b/hail/src/main/scala/is/hail/methods/MatrixExportEntriesByCol.scala index 43ade7d1418..5db999bbbba 100644 --- a/hail/src/main/scala/is/hail/methods/MatrixExportEntriesByCol.scala +++ b/hail/src/main/scala/is/hail/methods/MatrixExportEntriesByCol.scala @@ -197,7 +197,7 @@ case class MatrixExportEntriesByCol( // clean up temporary files val temps = tempFolders.result() val fsBc = fs.broadcast - SparkBackend.sparkContext("MatrixExportEntriesByCol.execute").parallelize( + SparkBackend.sparkContext.parallelize( temps, (temps.length / 32).max(1), ).foreach(path => fsBc.value.delete(path, recursive = true)) diff --git a/hail/src/main/scala/is/hail/rvd/RVD.scala b/hail/src/main/scala/is/hail/rvd/RVD.scala index 9c463915722..2cd97c3df03 100644 --- a/hail/src/main/scala/is/hail/rvd/RVD.scala +++ b/hail/src/main/scala/is/hail/rvd/RVD.scala @@ -1377,7 +1377,7 @@ object RVD { ) } - val sc = SparkBackend.sparkContext("writeRowsSplitFiles") + val sc = SparkBackend.sparkContext val localTmpdir = execCtx.localTmpdir val fs = execCtx.fs val fsBc = fs.broadcast diff --git a/hail/src/main/scala/is/hail/sparkextras/ContextRDD.scala b/hail/src/main/scala/is/hail/sparkextras/ContextRDD.scala index fe9c4d4e4ac..9ecefc16e13 100644 --- a/hail/src/main/scala/is/hail/sparkextras/ContextRDD.scala +++ b/hail/src/main/scala/is/hail/sparkextras/ContextRDD.scala @@ -86,7 +86,7 @@ object ContextRDD { def empty[T: ClassTag](): ContextRDD[T] = new ContextRDD( - SparkBackend.sparkContext("ContextRDD.empty").emptyRDD[RVDContext => Iterator[T]] + SparkBackend.sparkContext.emptyRDD[RVDContext => Iterator[T]] ) def union[T: ClassTag]( @@ -117,7 +117,7 @@ object ContextRDD { filterAndReplace: TextInputFilterAndReplace, ): ContextRDD[WithContext[String]] = ContextRDD.weaken( - SparkBackend.sparkContext("ContxtRDD.textFilesLines").textFilesLines( + SparkBackend.sparkContext.textFilesLines( files, nPartitions, ) @@ -129,12 +129,12 @@ object ContextRDD { weaken(sc.parallelize(data, nPartitions.getOrElse(sc.defaultMinPartitions))).map(x => x) def parallelize[T: ClassTag](data: Seq[T], numSlices: Int): ContextRDD[T] = - weaken(SparkBackend.sparkContext("ContextRDD.parallelize").parallelize(data, numSlices)).map { + 
weaken(SparkBackend.sparkContext.parallelize(data, numSlices)).map { x => x } def parallelize[T: ClassTag](data: Seq[T]): ContextRDD[T] = - weaken(SparkBackend.sparkContext("ContextRDD.parallelize").parallelize(data)).map(x => x) + weaken(SparkBackend.sparkContext.parallelize(data)).map(x => x) type ElementType[T] = RVDContext => Iterator[T] diff --git a/hail/src/main/scala/is/hail/sparkextras/IndexReadRDD.scala b/hail/src/main/scala/is/hail/sparkextras/IndexReadRDD.scala index d7316f3d736..94e00942a2f 100644 --- a/hail/src/main/scala/is/hail/sparkextras/IndexReadRDD.scala +++ b/hail/src/main/scala/is/hail/sparkextras/IndexReadRDD.scala @@ -15,7 +15,7 @@ class IndexReadRDD[T: ClassTag]( @transient val partFiles: Array[String], @transient val intervalBounds: Option[Array[Interval]], f: (IndexedFilePartition, TaskContext) => T, -) extends RDD[T](SparkBackend.sparkContext("IndexReadRDD"), Nil) { +) extends RDD[T](SparkBackend.sparkContext, Nil) { def getPartitions: Array[Partition] = Array.tabulate(partFiles.length) { i => IndexedFilePartition(i, partFiles(i), intervalBounds.map(_(i))) diff --git a/hail/src/main/scala/is/hail/types/virtual/package.scala b/hail/src/main/scala/is/hail/types/virtual/package.scala new file mode 100644 index 00000000000..032c862bafb --- /dev/null +++ b/hail/src/main/scala/is/hail/types/virtual/package.scala @@ -0,0 +1,12 @@ +package is.hail.types + +package virtual { + sealed abstract class Kind[T <: VType] extends Product with Serializable + + object Kinds { + case object Value extends Kind[Type] + case object Table extends Kind[TableType] + case object Matrix extends Kind[MatrixType] + case object BlockMatrix extends Kind[BlockMatrixType] + } +} diff --git a/hail/src/main/scala/is/hail/utils/SpillingCollectIterator.scala b/hail/src/main/scala/is/hail/utils/SpillingCollectIterator.scala index a8bb6d8babb..cbf9833ed15 100644 --- a/hail/src/main/scala/is/hail/utils/SpillingCollectIterator.scala +++ b/hail/src/main/scala/is/hail/utils/SpillingCollectIterator.scala @@ -16,7 +16,7 @@ object SpillingCollectIterator { val nPartitions = rdd.partitions.length val x = new SpillingCollectIterator(localTmpdir, fs, nPartitions, sizeLimit) val ctc = classTag[T] - SparkBackend.sparkContext("SpillingCollectIterator.apply").runJob( + SparkBackend.sparkContext.runJob( rdd, (_, it: Iterator[T]) => it.toArray(ctc), 0 until nPartitions, diff --git a/hail/src/test/scala/is/hail/HailSuite.scala b/hail/src/test/scala/is/hail/HailSuite.scala index aa4d217575e..5fd30c3a1a4 100644 --- a/hail/src/test/scala/is/hail/HailSuite.scala +++ b/hail/src/test/scala/is/hail/HailSuite.scala @@ -43,7 +43,7 @@ object HailSuite { lazy val hc: HailContext = { val hc = withSparkBackend() - hc.sparkBackend("HailSuite.hc").flags.set("lower", "1") + hc.backend.asSpark.flags.set("lower", "1") hc.checkRVDKeys = true hc } @@ -56,7 +56,7 @@ class HailSuite extends TestNGSuite { @BeforeClass def ensureHailContextInitialized(): Unit = hc - def backend: SparkBackend = hc.sparkBackend("HailSuite.backend") + def backend: SparkBackend = hc.backend.asSpark def sc: SparkContext = backend.sc diff --git a/hail/src/test/scala/is/hail/TestUtils.scala b/hail/src/test/scala/is/hail/TestUtils.scala index ff7a4ed4868..71ab9445a4d 100644 --- a/hail/src/test/scala/is/hail/TestUtils.scala +++ b/hail/src/test/scala/is/hail/TestUtils.scala @@ -119,9 +119,9 @@ object TestUtils { if (agg.isDefined || !env.isEmpty || !args.isEmpty) throw new LowererUnsupportedOperation("can't test with aggs or user defined args/env") - 
HailContext.sparkBackend("TestUtils.loweredExecute") - .jvmLowerAndExecute(ctx, x, optimize = false, lowerTable = true, lowerBM = true, - print = bytecodePrinter) + HailContext.sparkBackend.jvmLowerAndExecute(ctx, x, optimize = false, lowerTable = true, + lowerBM = true, + print = bytecodePrinter) } def eval(x: IR): Any = ExecuteContext.scoped { ctx => diff --git a/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala b/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala index bf6e15f09e9..1c1f0dea0a5 100644 --- a/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala +++ b/hail/src/test/scala/is/hail/backend/ServiceBackendSuite.scala @@ -2,12 +2,15 @@ package is.hail.backend import is.hail.HailFeatureFlags import is.hail.asm4s.HailClassLoader -import is.hail.backend.service.{ServiceBackend, ServiceBackendContext, ServiceBackendRPCPayload} +import is.hail.backend.service.{ + BatchJobConfig, ServiceBackend, ServiceBackendContext, ServiceBackendRPCPayload, +} import is.hail.io.fs.{CloudStorageFSConfig, RouterFS} import is.hail.services._ import is.hail.services.JobGroupStates.Success import is.hail.utils.{tokenUrlSafe, using} +import scala.collection.mutable import scala.reflect.io.{Directory, Path} import scala.util.Random @@ -24,9 +27,9 @@ import org.testng.annotations.Test class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito with OptionValues { @Test def testCreateJobPayload(): Unit = - withMockDriverContext { rpcConfig => + withMockDriverContext { case (rpcConfig, jobConfig) => val batchClient = mock[BatchClient] - using(ServiceBackend(batchClient, rpcConfig)) { backend => + using(ServiceBackend(batchClient, rpcConfig, jobConfig)) { backend => val contexts = Array.tabulate(1)(_.toString.getBytes) // verify that @@ -41,12 +44,12 @@ class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito with OptionV val jobs = jobGroup.jobs jobs.length shouldEqual contexts.length jobs.foreach { payload => - payload.regions.value shouldBe rpcConfig.regions + payload.regions.value shouldBe jobConfig.regions payload.resources.value shouldBe JobResources( preemptible = true, - cpu = Some(rpcConfig.worker_cores), - memory = Some(rpcConfig.worker_memory), - storage = Some(rpcConfig.storage), + cpu = Some(jobConfig.worker_cores), + memory = Some(jobConfig.worker_memory), + storage = Some(jobConfig.storage), ) } @@ -61,7 +64,7 @@ class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito with OptionV jobGroupId shouldEqual backend.batchConfig.jobGroupId + 1 val resultsDir = - Path(backend.serviceBackendContext.remoteTmpDir) / + Path(rpcConfig.remote_tmpdir) / "parallelizeAndComputeWithIndex" / tokenUrlSafe @@ -82,7 +85,11 @@ class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito with OptionV val (failure, _) = backend.parallelizeAndComputeWithIndex( - backend.serviceBackendContext, + ServiceBackendContext( + remoteTmpDir = rpcConfig.remote_tmpdir, + jobConfig = jobConfig, + executionCache = ExecutionCache.noCache, + ), backend.fs, contexts, "stage1", @@ -95,58 +102,53 @@ class ServiceBackendSuite extends TestNGSuite with IdiomaticMockito with OptionV } } - def ServiceBackend(client: BatchClient, rpcConfig: ServiceBackendRPCPayload): ServiceBackend = { + def ServiceBackend( + client: BatchClient, + rpcConfig: ServiceBackendRPCPayload, + jobConfig: BatchJobConfig, + ): ServiceBackend = { val flags = HailFeatureFlags.fromEnv() val fs = RouterFS.buildRoutes(CloudStorageFSConfig()) new ServiceBackend( - jarSpec = GitRevision("123"), name = 
"name", - theHailClassLoader = new HailClassLoader(getClass.getClassLoader), - references = Map.empty, batchClient = client, + jarSpec = GitRevision("123"), + theHailClassLoader = new HailClassLoader(getClass.getClassLoader), batchConfig = BatchConfig(batchId = Random.nextInt(), jobGroupId = Random.nextInt()), + rpcConfig = rpcConfig, + jobConfig = jobConfig, flags = flags, - tmpdir = rpcConfig.tmp_dir, fs = fs, - serviceBackendContext = - new ServiceBackendContext( - rpcConfig.billing_project, - rpcConfig.remote_tmpdir, - rpcConfig.worker_cores, - rpcConfig.worker_memory, - rpcConfig.storage, - rpcConfig.regions, - rpcConfig.cloudfuse_configs, - profile = false, - ExecutionCache.fromFlags(flags, fs, rpcConfig.remote_tmpdir), - ), - scratchDir = rpcConfig.remote_tmpdir, + references = mutable.Map.empty, ) } - def withMockDriverContext(test: ServiceBackendRPCPayload => Any): Any = + def withMockDriverContext(test: (ServiceBackendRPCPayload, BatchJobConfig) => Any): Any = using(LocalTmpFolder) { tmp => withObjectSpied[is.hail.utils.UtilsType] { // not obvious how to pull out `tokenUrlSafe` and inject this directory // using a spy is a hack and i don't particularly like it. when(is.hail.utils.tokenUrlSafe) thenAnswer "TOKEN" - test { + test( ServiceBackendRPCPayload( tmp_dir = tmp.path, remote_tmpdir = tmp.path, + flags = Map(), + custom_references = Array(), + liftovers = Map(), + sequences = Map(), + ), + BatchJobConfig( + token = tokenUrlSafe, billing_project = "fancy", worker_cores = "128", worker_memory = "a lot.", storage = "a big ssd?", cloudfuse_configs = Array(), regions = Array("lunar1"), - flags = Map(), - custom_references = Array(), - liftovers = Map(), - sequences = Map(), - ) - } + ), + ) } } diff --git a/hail/src/test/scala/is/hail/expr/ir/StagedMinHeapSuite.scala b/hail/src/test/scala/is/hail/expr/ir/StagedMinHeapSuite.scala index 81145a20bca..52a901176cd 100644 --- a/hail/src/test/scala/is/hail/expr/ir/StagedMinHeapSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/StagedMinHeapSuite.scala @@ -15,6 +15,8 @@ import is.hail.types.physical.stypes.primitives.{SInt32, SInt32Value} import is.hail.utils.{using, FastSeq} import is.hail.variant.{Locus, ReferenceGenome} +import scala.collection.mutable + import org.scalatest.matchers.should.Matchers.{be, convertToAnyShouldWrapper} import org.testng.annotations.Test @@ -60,7 +62,7 @@ class StagedMinHeapSuite extends HailSuite { @Test def testLocus(): Unit = forAll(loci) { case (rg: ReferenceGenome, loci: IndexedSeq[Locus]) => - ctx.local(references = Map(rg.name -> rg)) { ctx => + ctx.local(references = mutable.Map(rg.name -> rg)) { ctx => implicit val coercions: StagedCoercions[Locus] = stagedLocusCoercions(rg) diff --git a/hail/src/test/scala/is/hail/variant/ReferenceGenomeSuite.scala b/hail/src/test/scala/is/hail/variant/ReferenceGenomeSuite.scala index d8975340f3a..46cfa4547c5 100644 --- a/hail/src/test/scala/is/hail/variant/ReferenceGenomeSuite.scala +++ b/hail/src/test/scala/is/hail/variant/ReferenceGenomeSuite.scala @@ -1,7 +1,7 @@ package is.hail.variant import is.hail.{HailSuite, TestUtils} -import is.hail.backend.{ExecuteContext, HailStateManager} +import is.hail.backend.HailStateManager import is.hail.check.Prop._ import is.hail.check.Properties import is.hail.expr.ir.EmitFunctionBuilder @@ -9,18 +9,18 @@ import is.hail.io.reference.{FASTAReader, FASTAReaderConfig, LiftOver} import is.hail.types.virtual.TLocus import is.hail.utils._ +import scala.collection.mutable + import 
htsjdk.samtools.reference.ReferenceSequenceFileFactory import org.testng.annotations.Test class ReferenceGenomeSuite extends HailSuite { - def hasReference(name: String) = ctx.stateManager.referenceGenomes.contains(name) - - def getReference(name: String) = ctx.getReference(name) + def hasReference(name: String) = ctx.References.contains(name) @Test def testGRCh37(): Unit = { assert(hasReference(ReferenceGenome.GRCh37)) - val grch37 = getReference(ReferenceGenome.GRCh37) + val grch37 = ctx.References(ReferenceGenome.GRCh37) assert(grch37.inX("X") && grch37.inY("Y") && grch37.isMitochondrial("MT")) assert(grch37.contigLength("1") == 249250621) @@ -36,7 +36,7 @@ class ReferenceGenomeSuite extends HailSuite { @Test def testGRCh38(): Unit = { assert(hasReference(ReferenceGenome.GRCh38)) - val grch38 = getReference(ReferenceGenome.GRCh38) + val grch38 = ctx.References(ReferenceGenome.GRCh38) assert(grch38.inX("chrX") && grch38.inY("chrY") && grch38.isMitochondrial("chrM")) assert(grch38.contigLength("chr1") == 248956422) @@ -104,12 +104,12 @@ class ReferenceGenomeSuite extends HailSuite { @Test def testContigRemap(): Unit = { val mapping = Map("23" -> "foo") TestUtils.interceptFatal("have remapped contigs in reference genome")( - getReference(ReferenceGenome.GRCh37).validateContigRemap(mapping) + ctx.References(ReferenceGenome.GRCh37).validateContigRemap(mapping) ) } @Test def testComparisonOps(): Unit = { - val rg = getReference(ReferenceGenome.GRCh37) + val rg = ctx.References(ReferenceGenome.GRCh37) // Test contigs assert(rg.compare("3", "18") < 0) @@ -128,7 +128,7 @@ class ReferenceGenomeSuite extends HailSuite { @Test def testWriteToFile(): Unit = { val tmpFile = ctx.createTmpPath("grWrite", "json") - val rg = getReference(ReferenceGenome.GRCh37) + val rg = ctx.References(ReferenceGenome.GRCh37) rg.copy(name = "GRCh37_2").write(fs, tmpFile) val gr2 = ReferenceGenome.fromFile(fs, tmpFile) @@ -222,23 +222,22 @@ class ReferenceGenomeSuite extends HailSuite { } @Test def testSerializeOnFB(): Unit = { - ExecuteContext.scoped { ctx => - val grch38 = ctx.getReference(ReferenceGenome.GRCh38) - val fb = EmitFunctionBuilder[String, Boolean](ctx, "serialize_rg") - val rgfield = fb.getReferenceGenome(grch38.name) - fb.emit(rgfield.invoke[String, Boolean]("isValidContig", fb.getCodeParam[String](1))) - - val f = fb.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, ctx.r) + val grch38 = ctx.References(ReferenceGenome.GRCh38) + val fb = EmitFunctionBuilder[String, Boolean](ctx, "serialize_rg") + val rgfield = fb.getReferenceGenome(grch38.name) + fb.emit(rgfield.invoke[String, Boolean]("isValidContig", fb.getCodeParam[String](1))) + ctx.scopedExecution { (cl, fs, tc, r) => + val f = fb.resultWithIndex()(cl, fs, tc, r) assert(f("X") == grch38.isValidContig("X")) } } - @Test def testSerializeWithLiftoverOnFB(): Unit = { - ExecuteContext.scoped { ctx => - val grch37 = ctx.getReference(ReferenceGenome.GRCh37) + @Test def testSerializeWithLiftoverOnFB(): Unit = + ctx.local(references = mutable.Map(ctx.References.mapValues(_.copy()).toSeq: _*)) { ctx => + val grch37 = ctx.References(ReferenceGenome.GRCh37) val liftoverFile = "src/test/resources/grch37_to_grch38_chr20.over.chain.gz" - grch37.addLiftover(ctx.references("GRCh38"), LiftOver(ctx.fs, liftoverFile)) + grch37.addLiftover(ctx.References("GRCh38"), LiftOver(ctx.fs, liftoverFile)) val fb = EmitFunctionBuilder[String, Locus, Double, (Locus, Boolean)](ctx, "serialize_with_liftover") @@ -250,13 +249,13 @@ class ReferenceGenomeSuite extends 
HailSuite { fb.getCodeParam[Double](3), )) - val f = fb.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, ctx.r) - assert(f("GRCh38", Locus("20", 60001), 0.95) == grch37.liftoverLocus( - "GRCh38", - Locus("20", 60001), - 0.95, - )) - grch37.removeLiftover("GRCh38") + ctx.scopedExecution { (cl, fs, tc, r) => + val f = fb.resultWithIndex()(cl, fs, tc, r) + assert(f("GRCh38", Locus("20", 60001), 0.95) == grch37.liftoverLocus( + "GRCh38", + Locus("20", 60001), + 0.95, + )) + } } - } }