Skip to content

Commit

Permalink
make backend closable + fix mocked test
Browse files Browse the repository at this point in the history
  • Loading branch information
ehigham committed Oct 16, 2024
1 parent 307d00d commit 7d2c099
Show file tree
Hide file tree
Showing 15 changed files with 162 additions and 151 deletions.
2 changes: 1 addition & 1 deletion hail/python/hail/backend/py4j_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def _to_java_blockmatrix_ir(self, ir):
return self._parse_blockmatrix_ir(self._render_ir(ir))

def stop(self):
self._backend_server.stop()
self._backend_server.close()
self._jhc.stop()
self._jhc = None
self._registered_ir_function_names = set()
Expand Down
2 changes: 1 addition & 1 deletion hail/python/hail/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def default_reference(self, value):

def stop(self):
assert self._backend
self._backend.stop()
self._backend.close()
self._backend = None
Env._hc = None
Env._dummy_table = None
Expand Down
5 changes: 2 additions & 3 deletions hail/python/hailtop/config/user_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,5 @@ def get_remote_tmpdir(
raise ValueError(
f'remote_tmpdir must be a storage uri path like gs://bucket/folder. Received: {remote_tmpdir}. Possible schemes include gs for GCP and https for Azure'
)
if remote_tmpdir[-1] != '/':
remote_tmpdir += '/'
return remote_tmpdir

return remote_tmpdir[:-1] if remote_tmpdir[-1] == '/' else remote_tmpdir
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/HailContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ object HailContext {

def stop(): Unit = synchronized {
IRFunctionRegistry.clearUserFunctions()
backend.stop()
backend.close()

theContext = null
}
Expand Down
4 changes: 2 additions & 2 deletions hail/src/main/scala/is/hail/backend/Backend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ trait BackendContext {
def executionCache: ExecutionCache
}

abstract class Backend {
abstract class Backend extends Closeable {
// From https://github.com/hail-is/hail/issues/14580 :
// IR can get quite big, especially as it can contain an arbitrary
// amount of encoded literals from the user's python session. This
Expand Down Expand Up @@ -123,7 +123,7 @@ abstract class Backend {
f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte]
): (Option[Throwable], IndexedSeq[(Array[Byte], Int)])

def stop(): Unit
def close(): Unit

def asSpark(op: String): SparkBackend =
fatal(s"${getClass.getSimpleName}: $op requires SparkBackend")
Expand Down
12 changes: 4 additions & 8 deletions hail/src/main/scala/is/hail/backend/BackendServer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@ import is.hail.expr.ir.{IRParser, IRParserEnvironment}
import is.hail.utils._

import scala.util.control.NonFatal

import java.net.InetSocketAddress
import java.nio.charset.StandardCharsets
import java.util.concurrent._

import com.sun.net.httpserver.{HttpExchange, HttpHandler, HttpServer}
import org.json4s._
import org.json4s.jackson.JsonMethods
import org.json4s.jackson.JsonMethods.compact

import java.io.Closeable

case class IRTypePayload(ir: String)
case class LoadReferencesFromDatasetPayload(path: String)

Expand All @@ -31,11 +31,7 @@ case class ParseVCFMetadataPayload(path: String)
case class ImportFamPayload(path: String, quant_pheno: Boolean, delimiter: String, missing: String)
case class ExecutePayload(ir: String, stream_codec: String, timed: Boolean)

object BackendServer {
def apply(backend: Backend) = new BackendServer(backend)
}

class BackendServer(backend: Backend) {
class BackendServer(backend: Backend) extends Closeable {
// 0 => let the OS pick an available port
private[this] val httpServer = HttpServer.create(new InetSocketAddress(0), 10)
private[this] val handler = new BackendHttpHandler(backend)
Expand Down Expand Up @@ -77,7 +73,7 @@ class BackendServer(backend: Backend) {
def start(): Unit =
thread.start()

def stop(): Unit =
override def close(): Unit =
httpServer.stop(10)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ class LocalBackend(val tmpdir: String) extends Backend with BackendWithCodeCache

def defaultParallelism: Int = 1

def stop(): Unit = LocalBackend.stop()
def close(): Unit = LocalBackend.stop()

private[this] def _jvmLowerAndExecute(
ctx: ExecuteContext,
Expand Down
10 changes: 4 additions & 6 deletions hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class ServiceBackend(
val tmpdir: String,
val fs: FS,
val serviceBackendContext: ServiceBackendContext,
val scratchDir: String = sys.env.get("HAIL_WORKER_SCRATCH_DIR").getOrElse(""),
val scratchDir: String,
) extends Backend with BackendWithNoCodeCache {
import ServiceBackend.log

Expand Down Expand Up @@ -165,7 +165,7 @@ class ServiceBackend(
val backendContext = _backendContext.asInstanceOf[ServiceBackendContext]
val n = collection.length
val token = tokenUrlSafe
val root = s"${backendContext.remoteTmpDir}parallelizeAndComputeWithIndex/$token"
val root = s"${backendContext.remoteTmpDir}/parallelizeAndComputeWithIndex/$token"

log.info(s"parallelizeAndComputeWithIndex: $token: nPartitions $n")
log.info(s"parallelizeAndComputeWithIndex: $token: writing f and contexts")
Expand Down Expand Up @@ -303,7 +303,7 @@ class ServiceBackend(
r
}

def stop(): Unit = {
override def close(): Unit = {
executor.shutdownNow()
batchClient.close()
}
Expand Down Expand Up @@ -421,9 +421,7 @@ object ServiceBackendAPI {
HailFeatureFlags.fromEnv(),
)
)
val deployConfig = DeployConfig.fromConfigFile(
s"$scratchDir/secrets/deploy-config/deploy-config.json"
)
val deployConfig = DeployConfig.fromConfigFile("/deploy-config/deploy-config.json")
DeployConfig.set(deployConfig)
sys.env.get("HAIL_SSL_CONFIG_DIR").foreach(tls.setSSLConfigFromDir)

Expand Down
22 changes: 10 additions & 12 deletions hail/src/main/scala/is/hail/backend/service/Worker.scala
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
package is.hail.backend.service

import is.hail.{HAIL_REVISION, HailContext, HailFeatureFlags}
import is.hail.asm4s._
import is.hail.backend.HailTaskContext
import is.hail.io.fs._
import is.hail.services._
import is.hail.utils._
import is.hail.{HAIL_REVISION, HailContext, HailFeatureFlags}
import org.apache.log4j.Logger

import scala.collection.mutable
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration
import scala.util.control.NonFatal
import java.io._
import java.nio.charset._
import java.nio.file.Path
import java.util
import java.util.{concurrent => javaConcurrent}
import org.apache.log4j.Logger

import java.nio.file.Path
import scala.collection.mutable
import scala.concurrent.duration.Duration
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.util.control.NonFatal

class ServiceTaskContext(val partitionId: Int) extends HailTaskContext {
override def stageId(): Int = 0
Expand Down Expand Up @@ -113,9 +112,7 @@ object Worker {
val n = argv(6).toInt
val timer = new WorkerTimer()

val deployConfig = DeployConfig.fromConfigFile(
s"$scratchDir/secrets/deploy-config/deploy-config.json"
)
val deployConfig = DeployConfig.fromConfigFile("/deploy-config/deploy-config.json")
DeployConfig.set(deployConfig)
sys.env.get("HAIL_SSL_CONFIG_DIR").foreach(tls.setSSLConfigFromDir)

Expand Down Expand Up @@ -167,7 +164,6 @@ object Worker {

timer.end("readInputs")
timer.start("executeFunction")

if (HailContext.isInitialized) {
HailContext.get.backend = new ServiceBackend(
null,
Expand All @@ -180,6 +176,7 @@ object Worker {
null,
null,
null,
scratchDir,
)
} else {
HailContext(
Expand All @@ -195,6 +192,7 @@ object Worker {
null,
null,
null,
scratchDir,
)
)
}
Expand Down
8 changes: 4 additions & 4 deletions hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,10 @@ class SparkBackend(

override def asSpark(op: String): SparkBackend = this

def stop(): Unit = SparkBackend.stop()
def close(): Unit = {
SparkBackend.stop()
longLifeTempFileManager.close()
}

def startProgressBar(): Unit =
ProgressBarBuilder.build(sc)
Expand Down Expand Up @@ -761,9 +764,6 @@ class SparkBackend(
RVDTableReader(RVD.unkeyed(rowPType, orderedCRDD), globalsLit, rt)
}

def close(): Unit =
longLifeTempFileManager.close()

def tableToTableStage(ctx: ExecuteContext, inputIR: TableIR, analyses: LoweringAnalyses)
: TableStage = {
CanLowerEfficiently(ctx, inputIR) match {
Expand Down
5 changes: 3 additions & 2 deletions hail/src/main/scala/is/hail/services/BatchClient.scala
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,9 @@ object JobGroupStates {
}

object BatchClient {
def apply(deployConfig: DeployConfig, credentialsFile: Path): BatchClient =
new BatchClient(BatchServiceRequester(deployConfig, credentialsFile))
def apply(deployConfig: DeployConfig, credentialsFile: Path, env: Map[String, String] = sys.env)
: BatchClient =
new BatchClient(BatchServiceRequester(deployConfig, credentialsFile, env))
}

case class BatchClient private (req: Requester) extends Logging with AutoCloseable {
Expand Down
Loading

0 comments on commit 7d2c099

Please sign in to comment.