diff --git a/hail/src/main/scala/is/hail/backend/service/Main.scala b/hail/src/main/scala/is/hail/backend/service/Main.scala index 88d5b2253d6d..f35b683988d9 100644 --- a/hail/src/main/scala/is/hail/backend/service/Main.scala +++ b/hail/src/main/scala/is/hail/backend/service/Main.scala @@ -19,6 +19,10 @@ object Main { argv(3) match { case WORKER => Worker.main(argv) case DRIVER => ServiceBackendAPI.main(argv) + + // Batch's "JvmJob" is a special kind of job that can only call `Main.main`. + // TEST is used for integration testing the `BatchClient` to verify that we + // can create JvmJobs without having to mock the payload to a `Worker` job. case TEST => () case kind => throw new RuntimeException(s"unknown kind: $kind") } diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala index 14970d452a1c..4908f2429ae8 100644 --- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala +++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala @@ -242,6 +242,7 @@ class ServiceBackend( val token = tokenUrlSafe val root = s"${backendContext.remoteTmpDir}/parallelizeAndComputeWithIndex/$token" + log.info(s"parallelizeAndComputeWithIndex: token='$token', nPartitions=${contexts.length}") val uploadFunction = executor.submit[Unit](() => retryTransientErrors { diff --git a/hail/src/main/scala/is/hail/services/BatchClient.scala b/hail/src/main/scala/is/hail/services/BatchClient.scala index 571c03279425..914de7a0a2d3 100644 --- a/hail/src/main/scala/is/hail/services/BatchClient.scala +++ b/hail/src/main/scala/is/hail/services/BatchClient.scala @@ -1,7 +1,10 @@ package is.hail.services import is.hail.expr.ir.ByteArrayBuilder -import is.hail.services.BatchClient.BunchMaxSizeBytes +import is.hail.services.BatchClient.{ + BunchMaxSizeBytes, JarSpecSerializer, JobGroupResponseDeserializer, JobGroupStateDeserializer, + JobProcessRequestSerializer, +} import 
is.hail.services.oauth2.CloudCredentials import is.hail.services.requests.Requester import is.hail.utils._ @@ -86,7 +89,16 @@ object JobGroupStates { object BatchClient { - private[this] def BatchServiceScopes(env: Map[String, String]): Array[String] = + val BunchMaxSizeBytes: Int = 1024 * 1024 + + def apply(deployConfig: DeployConfig, credentialsFile: Path, env: Map[String, String] = sys.env) + : BatchClient = + new BatchClient(Requester( + new URL(deployConfig.baseUrl("batch")), + CloudCredentials(credentialsFile, BatchServiceScopes(env), env), + )) + + def BatchServiceScopes(env: Map[String, String]): Array[String] = env.get("HAIL_CLOUD") match { case Some("gcp") => Array( @@ -102,14 +114,70 @@ object BatchClient { throw new IllegalArgumentException(s"HAIL_CLOUD must be set.") } - def apply(deployConfig: DeployConfig, credentialsFile: Path, env: Map[String, String] = sys.env) - : BatchClient = - new BatchClient(Requester( - new URL(deployConfig.baseUrl("batch")), - CloudCredentials(credentialsFile, BatchServiceScopes(env), env), - )) + object JobProcessRequestSerializer extends CustomSerializer[JobProcess](implicit fmts => + ( + PartialFunction.empty, + { + case BashJob(image, command) => + JObject( + "type" -> JString("docker"), + "image" -> JString(image), + "command" -> JArray(command.map(JString).toList), + ) + case JvmJob(command, jarSpec, profile) => + JObject( + "type" -> JString("jvm"), + "command" -> JArray(command.map(JString).toList), + "jar_spec" -> Extraction.decompose(jarSpec), + "profile" -> JBool(profile), + ) + }, + ) + ) - private val BunchMaxSizeBytes: Int = 1024 * 1024 + object JobGroupStateDeserializer extends CustomSerializer[JobGroupState](_ => + ( + { + case JString("failure") => JobGroupStates.Failure + case JString("cancelled") => JobGroupStates.Cancelled + case JString("success") => JobGroupStates.Success + case JString("running") => JobGroupStates.Running + }, + PartialFunction.empty, + ) + ) + + object JobGroupResponseDeserializer 
extends CustomSerializer[JobGroupResponse](implicit fmts => + ( + { + case o: JObject => + JobGroupResponse( + batch_id = (o \ "batch_id").extract[Int], + job_group_id = (o \ "job_group_id").extract[Int], + state = (o \ "state").extract[JobGroupState], + complete = (o \ "complete").extract[Boolean], + n_jobs = (o \ "n_jobs").extract[Int], + n_completed = (o \ "n_completed").extract[Int], + n_succeeded = (o \ "n_succeeded").extract[Int], + n_failed = (o \ "n_failed").extract[Int], + n_cancelled = (o \ "n_cancelled").extract[Int], + ) + }, + PartialFunction.empty, + ) + ) + + object JarSpecSerializer extends CustomSerializer[JarSpec](_ => + ( + PartialFunction.empty, + { + case JarUrl(url) => + JObject("type" -> JString("jar_url"), "value" -> JString(url)) + case GitRevision(sha) => + JObject("type" -> JString("git_revision"), "value" -> JString(sha)) + }, + ) + ) } case class BatchClient private (req: Requester) extends Logging with AutoCloseable { @@ -220,6 +288,11 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab .merge( JObject( "job_id" -> JInt(jobIdx + 1), + // Batch allows clients to create multiple job groups in an update. + // For each table stage, we create and update with one job group; all jobs in + // that update belong to that one job group. This allows us to abstract updates + // from the case class used by the ServiceBackend but that information needs to + // get added back here. 
"in_update_job_group_id" -> JInt(1), ) ) @@ -256,73 +329,4 @@ case class BatchClient private (req: Requester) extends Logging with AutoCloseab ) )), ) - - private[this] object JobProcessRequestSerializer - extends CustomSerializer[JobProcess](_ => - ( - PartialFunction.empty, - { - case BashJob(image, command) => - JObject( - "type" -> JString("docker"), - "image" -> JString(image), - "command" -> JArray(command.map(JString).toList), - ) - case JvmJob(command, jarSpec, profile) => - JObject( - "type" -> JString("jvm"), - "command" -> JArray(command.map(JString).toList), - "jar_spec" -> Extraction.decompose(jarSpec), - "profile" -> JBool(profile), - ) - }, - ) - ) - - private[this] object JobGroupStateDeserializer - extends CustomSerializer[JobGroupState](_ => - ( - { - case JString("failure") => JobGroupStates.Failure - case JString("cancelled") => JobGroupStates.Cancelled - case JString("success") => JobGroupStates.Success - case JString("running") => JobGroupStates.Running - }, - PartialFunction.empty, - ) - ) - - private[this] object JobGroupResponseDeserializer - extends CustomSerializer[JobGroupResponse](implicit fmts => - ( - { - case o: JObject => - JobGroupResponse( - batch_id = (o \ "batch_id").extract[Int], - job_group_id = (o \ "job_group_id").extract[Int], - state = (o \ "state").extract[JobGroupState], - complete = (o \ "complete").extract[Boolean], - n_jobs = (o \ "n_jobs").extract[Int], - n_completed = (o \ "n_completed").extract[Int], - n_succeeded = (o \ "n_succeeded").extract[Int], - n_failed = (o \ "n_failed").extract[Int], - n_cancelled = (o \ "n_failed").extract[Int], - ) - }, - PartialFunction.empty, - ) - ) - - private[this] object JarSpecSerializer - extends CustomSerializer[JarSpec](_ => - ( - PartialFunction.empty, - { - case JarUrl(url) => - JObject("type" -> JString("jar_url"), "value" -> JString(url)) - case GitRevision(sha) => - JObject("type" -> JString("git_revision"), "value" -> JString(sha)) - }, - ) - ) } diff --git 
a/hail/src/main/scala/is/hail/services/oauth2.scala b/hail/src/main/scala/is/hail/services/oauth2.scala index 5063a51bb8ee..5dcbd85684c7 100644 --- a/hail/src/main/scala/is/hail/services/oauth2.scala +++ b/hail/src/main/scala/is/hail/services/oauth2.scala @@ -1,5 +1,6 @@ package is.hail.services +import is.hail.services.oauth2.AzureCloudCredentials.AzureTokenRefreshMinutes import is.hail.services.oauth2.AzureCloudCredentials.EnvVars.AzureApplicationCredentials import is.hail.services.oauth2.GoogleCloudCredentials.EnvVars.GoogleApplicationCredentials import is.hail.shadedazure.com.azure.core.credential.{ @@ -88,7 +89,9 @@ object oauth2 { } private[this] def isExpired: Boolean = - token == null || OffsetDateTime.now.plusMinutes(5).isBefore(token.getExpiresAt) + token == null || token.getExpiresAt.isBefore( + OffsetDateTime.now.plusMinutes(AzureTokenRefreshMinutes) + ) } object AzureCloudCredentials { @@ -96,6 +99,8 @@ object oauth2 { val AzureApplicationCredentials = "AZURE_APPLICATION_CREDENTIALS" } + private[AzureCloudCredentials] val AzureTokenRefreshMinutes = 5 + def apply(keyPath: Option[Path], scopes: IndexedSeq[String], env: Map[String, String] = sys.env) : AzureCloudCredentials = keyPath.orElse(env.get(AzureApplicationCredentials).map(Path.of(_))) match { diff --git a/hail/src/main/scala/is/hail/services/requests.scala b/hail/src/main/scala/is/hail/services/requests.scala index b6ec90a08b3c..5362a289048e 100644 --- a/hail/src/main/scala/is/hail/services/requests.scala +++ b/hail/src/main/scala/is/hail/services/requests.scala @@ -27,30 +27,32 @@ object requests { def patch(route: String): JValue } - private[this] val TIMEOUT_MS = 5 * 1000 + private[this] val TimeoutMs = 5 * 1000 + private[this] val MaxNumConnectionPerRoute = 20 + private[this] val MaxNumConnections = 100 def Requester(baseUrl: URL, cred: CloudCredentials): Requester = { val httpClient: CloseableHttpClient = { log.info("creating HttpClient") val requestConfig = RequestConfig.custom() - 
.setConnectTimeout(TIMEOUT_MS) - .setConnectionRequestTimeout(TIMEOUT_MS) - .setSocketTimeout(TIMEOUT_MS) + .setConnectTimeout(TimeoutMs) + .setConnectionRequestTimeout(TimeoutMs) + .setSocketTimeout(TimeoutMs) .build() try { HttpClients.custom() .setSSLContext(tls.getSSLContext) - .setMaxConnPerRoute(20) - .setMaxConnTotal(100) + .setMaxConnPerRoute(MaxNumConnectionPerRoute) + .setMaxConnTotal(MaxNumConnections) .setDefaultRequestConfig(requestConfig) .build() } catch { case _: NoSSLConfigFound => log.info("creating HttpClient with no SSL Context") HttpClients.custom() - .setMaxConnPerRoute(20) - .setMaxConnTotal(100) + .setMaxConnPerRoute(MaxNumConnectionPerRoute) + .setMaxConnTotal(MaxNumConnections) .setDefaultRequestConfig(requestConfig) .build() }