diff --git a/sdk/cosmos/azure-cosmos-encryption/src/main/java/com/azure/cosmos/encryption/CosmosEncryptionAsyncContainer.java b/sdk/cosmos/azure-cosmos-encryption/src/main/java/com/azure/cosmos/encryption/CosmosEncryptionAsyncContainer.java index 22f39098700ec..89bb0569d279a 100644 --- a/sdk/cosmos/azure-cosmos-encryption/src/main/java/com/azure/cosmos/encryption/CosmosEncryptionAsyncContainer.java +++ b/sdk/cosmos/azure-cosmos-encryption/src/main/java/com/azure/cosmos/encryption/CosmosEncryptionAsyncContainer.java @@ -16,6 +16,7 @@ import com.azure.cosmos.encryption.implementation.EncryptionUtils; import com.azure.cosmos.encryption.implementation.mdesrc.cryptography.MicrosoftDataEncryptionException; import com.azure.cosmos.encryption.models.SqlQuerySpecWithEncryption; +import com.azure.cosmos.implementation.CosmosBulkExecutionOptionsImpl; import com.azure.cosmos.implementation.CosmosPagedFluxOptions; import com.azure.cosmos.implementation.HttpConstants; import com.azure.cosmos.implementation.ImplementationBridgeHelpers; @@ -1683,8 +1684,9 @@ private void setRequestHeaders(CosmosBatchRequestOptions requestOptions) { } private void setRequestHeaders(CosmosBulkExecutionOptions requestOptions) { - cosmosBulkExecutionOptionsAccessor.setHeader(requestOptions, Constants.IS_CLIENT_ENCRYPTED_HEADER, "true"); - cosmosBulkExecutionOptionsAccessor.setHeader(requestOptions, Constants.INTENDED_COLLECTION_RID_HEADER, this.encryptionProcessor.getContainerRid()); + CosmosBulkExecutionOptionsImpl requestOptionsImpl = cosmosBulkExecutionOptionsAccessor.getImpl(requestOptions); + requestOptionsImpl.setHeader(Constants.IS_CLIENT_ENCRYPTED_HEADER, "true"); + requestOptionsImpl.setHeader(Constants.INTENDED_COLLECTION_RID_HEADER, this.encryptionProcessor.getContainerRid()); } boolean isIncorrectContainerRid(CosmosException cosmosException) { diff --git a/sdk/cosmos/azure-cosmos-kafka-connect/src/main/java/com/azure/cosmos/kafka/connect/implementation/sink/CosmosBulkWriter.java b/sdk/cosmos/azure-cosmos-kafka-connect/src/main/java/com/azure/cosmos/kafka/connect/implementation/sink/CosmosBulkWriter.java index e46aaee3587ad..c736c2385d079 100644 --- a/sdk/cosmos/azure-cosmos-kafka-connect/src/main/java/com/azure/cosmos/kafka/connect/implementation/sink/CosmosBulkWriter.java +++ b/sdk/cosmos/azure-cosmos-kafka-connect/src/main/java/com/azure/cosmos/kafka/connect/implementation/sink/CosmosBulkWriter.java @@ -134,7 +134,8 @@ private CosmosBulkExecutionOptions getBulkExecutionOperations() { ImplementationBridgeHelpers .CosmosBulkExecutionOptionsHelper .getCosmosBulkExecutionOptionsAccessor() - .setMaxConcurrentCosmosPartitions(bulkExecutionOptions, this.writeConfig.getBulkMaxConcurrentCosmosPartitions()); + .getImpl(bulkExecutionOptions) + .setMaxConcurrentCosmosPartitions(this.writeConfig.getBulkMaxConcurrentCosmosPartitions()); } CosmosThroughputControlHelper.tryPopulateThroughputControlGroupName(bulkExecutionOptions, this.throughputControlConfig); diff --git a/sdk/cosmos/azure-cosmos-spark_3-1_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-1_2-12/CHANGELOG.md index 128664504dd77..29a9ebbb3564f 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-1_2-12/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-1_2-12/CHANGELOG.md @@ -10,6 +10,7 @@ * Fixed an issue to avoid transient `IllegalArgumentException` due to duplicate json properties for the `uniqueKeyPolicy` property. 
- See [PR 41608](https://github.com/Azure/azure-sdk-for-java/pull/41608) #### Other Changes +* Added retries on a new `BulkWriter` instance when first attempt to commit times out for bulk write jobs. - See [PR 41553](https://github.com/Azure/azure-sdk-for-java/pull/41553) ### 4.33.0 (2024-06-22) diff --git a/sdk/cosmos/azure-cosmos-spark_3-1_2-12/src/test/resources/META-INF/services/com.azure.cosmos.spark.CosmosClientInterceptor b/sdk/cosmos/azure-cosmos-spark_3-1_2-12/src/test/resources/META-INF/services/com.azure.cosmos.spark.CosmosClientInterceptor new file mode 100644 index 0000000000000..e2239720776d6 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3-1_2-12/src/test/resources/META-INF/services/com.azure.cosmos.spark.CosmosClientInterceptor @@ -0,0 +1 @@ +com.azure.cosmos.spark.TestFaultInjectionClientInterceptor \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos-spark_3-1_2-12/src/test/resources/META-INF/services/com.azure.cosmos.spark.WriteOnRetryCommitInterceptor b/sdk/cosmos/azure-cosmos-spark_3-1_2-12/src/test/resources/META-INF/services/com.azure.cosmos.spark.WriteOnRetryCommitInterceptor new file mode 100644 index 0000000000000..c60cbf2f14e41 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3-1_2-12/src/test/resources/META-INF/services/com.azure.cosmos.spark.WriteOnRetryCommitInterceptor @@ -0,0 +1 @@ +com.azure.cosmos.spark.TestWriteOnRetryCommitInterceptor \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos-spark_3-2_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-2_2-12/CHANGELOG.md index f670d8362b032..f691b2f3ed0c1 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-2_2-12/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-2_2-12/CHANGELOG.md @@ -10,6 +10,7 @@ * Fixed an issue to avoid transient `IllegalArgumentException` due to duplicate json properties for the `uniqueKeyPolicy` property. - See [PR 41608](https://github.com/Azure/azure-sdk-for-java/pull/41608) #### Other Changes +* Added retries on a new `BulkWriter` instance when first attempt to commit times out for bulk write jobs. 
- See [PR 41553](https://github.com/Azure/azure-sdk-for-java/pull/41553) ### 4.33.0 (2024-06-22) diff --git a/sdk/cosmos/azure-cosmos-spark_3-2_2-12/src/test/resources/META-INF/services/com.azure.cosmos.spark.CosmosClientInterceptor b/sdk/cosmos/azure-cosmos-spark_3-2_2-12/src/test/resources/META-INF/services/com.azure.cosmos.spark.CosmosClientInterceptor new file mode 100644 index 0000000000000..e2239720776d6 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3-2_2-12/src/test/resources/META-INF/services/com.azure.cosmos.spark.CosmosClientInterceptor @@ -0,0 +1 @@ +com.azure.cosmos.spark.TestFaultInjectionClientInterceptor \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos-spark_3-2_2-12/src/test/resources/META-INF/services/com.azure.cosmos.spark.WriteOnRetryCommitInterceptor b/sdk/cosmos/azure-cosmos-spark_3-2_2-12/src/test/resources/META-INF/services/com.azure.cosmos.spark.WriteOnRetryCommitInterceptor new file mode 100644 index 0000000000000..c60cbf2f14e41 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3-2_2-12/src/test/resources/META-INF/services/com.azure.cosmos.spark.WriteOnRetryCommitInterceptor @@ -0,0 +1 @@ +com.azure.cosmos.spark.TestWriteOnRetryCommitInterceptor \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md index 2c4e54485ef27..d8f725a03de2a 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md @@ -10,6 +10,7 @@ * Fixed an issue to avoid transient `IllegalArgumentException` due to duplicate json properties for the `uniqueKeyPolicy` property. - See [PR 41608](https://github.com/Azure/azure-sdk-for-java/pull/41608) #### Other Changes +* Added retries on a new `BulkWriter` instance when first attempt to commit times out for bulk write jobs. - See [PR 41553](https://github.com/Azure/azure-sdk-for-java/pull/41553) ### 4.33.0 (2024-06-22) diff --git a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/src/test/resources/META-INF/services/com.azure.cosmos.spark.CosmosClientInterceptor b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/src/test/resources/META-INF/services/com.azure.cosmos.spark.CosmosClientInterceptor new file mode 100644 index 0000000000000..e2239720776d6 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/src/test/resources/META-INF/services/com.azure.cosmos.spark.CosmosClientInterceptor @@ -0,0 +1 @@ +com.azure.cosmos.spark.TestFaultInjectionClientInterceptor \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos-spark_3-3_2-12/src/test/resources/META-INF/services/com.azure.cosmos.spark.WriteOnRetryCommitInterceptor b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/src/test/resources/META-INF/services/com.azure.cosmos.spark.WriteOnRetryCommitInterceptor new file mode 100644 index 0000000000000..c60cbf2f14e41 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3-3_2-12/src/test/resources/META-INF/services/com.azure.cosmos.spark.WriteOnRetryCommitInterceptor @@ -0,0 +1 @@ +com.azure.cosmos.spark.TestWriteOnRetryCommitInterceptor \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md index ebae366c42970..1b329102fab79 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md @@ -10,6 +10,7 @@ * Fixed an issue to avoid transient `IllegalArgumentException` due to duplicate json properties for the `uniqueKeyPolicy` property. 
- See [PR 41608](https://github.com/Azure/azure-sdk-for-java/pull/41608) #### Other Changes +* Added retries on a new `BulkWriter` instance when first attempt to commit times out for bulk write jobs. - See [PR 41553](https://github.com/Azure/azure-sdk-for-java/pull/41553) ### 4.33.0 (2024-06-22) diff --git a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/src/test/resources/META-INF/services/com.azure.cosmos.spark.CosmosClientInterceptor b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/src/test/resources/META-INF/services/com.azure.cosmos.spark.CosmosClientInterceptor new file mode 100644 index 0000000000000..e2239720776d6 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/src/test/resources/META-INF/services/com.azure.cosmos.spark.CosmosClientInterceptor @@ -0,0 +1 @@ +com.azure.cosmos.spark.TestFaultInjectionClientInterceptor \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos-spark_3-4_2-12/src/test/resources/META-INF/services/com.azure.cosmos.spark.WriteOnRetryCommitInterceptor b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/src/test/resources/META-INF/services/com.azure.cosmos.spark.WriteOnRetryCommitInterceptor new file mode 100644 index 0000000000000..c60cbf2f14e41 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3-4_2-12/src/test/resources/META-INF/services/com.azure.cosmos.spark.WriteOnRetryCommitInterceptor @@ -0,0 +1 @@ +com.azure.cosmos.spark.TestWriteOnRetryCommitInterceptor \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md index 63dfa85dce7db..84bc14c7024d0 100644 --- a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md @@ -10,6 +10,7 @@ * Fixed an issue to avoid transient `IllegalArgumentException` due to duplicate json properties for the `uniqueKeyPolicy` property. - See [PR 41608](https://github.com/Azure/azure-sdk-for-java/pull/41608) #### Other Changes +* Added retries on a new `BulkWriter` instance when first attempt to commit times out for bulk write jobs. 
- See [PR 41553](https://github.com/Azure/azure-sdk-for-java/pull/41553) ### 4.33.0 (2024-06-22) diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/src/test/resources/META-INF/services/com.azure.cosmos.spark.CosmosClientInterceptor b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/src/test/resources/META-INF/services/com.azure.cosmos.spark.CosmosClientInterceptor new file mode 100644 index 0000000000000..e2239720776d6 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/src/test/resources/META-INF/services/com.azure.cosmos.spark.CosmosClientInterceptor @@ -0,0 +1 @@ +com.azure.cosmos.spark.TestFaultInjectionClientInterceptor \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos-spark_3-5_2-12/src/test/resources/META-INF/services/com.azure.cosmos.spark.WriteOnRetryCommitInterceptor b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/src/test/resources/META-INF/services/com.azure.cosmos.spark.WriteOnRetryCommitInterceptor new file mode 100644 index 0000000000000..c60cbf2f14e41 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3-5_2-12/src/test/resources/META-INF/services/com.azure.cosmos.spark.WriteOnRetryCommitInterceptor @@ -0,0 +1 @@ +com.azure.cosmos.spark.TestWriteOnRetryCommitInterceptor \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/AsyncItemWriter.scala b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/AsyncItemWriter.scala index 20d2a4fb77c7b..e35b43692de1e 100644 --- a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/AsyncItemWriter.scala +++ b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/AsyncItemWriter.scala @@ -24,7 +24,7 @@ private trait AsyncItemWriter { * Don't wait for any remaining work but signal to the writer the ungraceful close * Should not throw any exceptions */ - def abort(): Unit + def abort(shouldThrow: Boolean): Unit private[spark] def getETag(objectNode: ObjectNode) = { val eTagField = objectNode.get(CosmosConstants.Properties.ETag) diff --git a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/BulkWriter.scala b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/BulkWriter.scala index 99d0b6da81a1c..c9b07b21a4048 100644 --- a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/BulkWriter.scala +++ b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/BulkWriter.scala @@ -4,17 +4,18 @@ package com.azure.cosmos.spark // scalastyle:off underscore.import import com.azure.cosmos.implementation.CosmosDaemonThreadFactory -import com.azure.cosmos.{BridgeInternal, CosmosAsyncContainer, CosmosDiagnosticsContext, CosmosException} +import com.azure.cosmos.{BridgeInternal, CosmosAsyncContainer, CosmosDiagnosticsContext, CosmosEndToEndOperationLatencyPolicyConfigBuilder, CosmosException} import com.azure.cosmos.implementation.apachecommons.lang.StringUtils import com.azure.cosmos.implementation.batch.{BatchRequestResponseConstants, BulkExecutorDiagnosticsTracker, ItemBulkOperation} import com.azure.cosmos.models._ -import com.azure.cosmos.spark.BulkWriter.{BulkOperationFailedException, bulkWriterRequestsBoundedElastic, bulkWriterResponsesBoundedElastic, getThreadInfo, readManyBoundedElastic} +import com.azure.cosmos.spark.BulkWriter.{BulkOperationFailedException, bulkWriterInputBoundedElastic, bulkWriterRequestsBoundedElastic, bulkWriterResponsesBoundedElastic, getThreadInfo, readManyBoundedElastic} import 
com.azure.cosmos.spark.diagnostics.DefaultDiagnostics import reactor.core.Scannable import reactor.core.publisher.Mono import reactor.core.scheduler.Scheduler import java.util +import java.util.Objects import java.util.concurrent.{ScheduledFuture, ScheduledThreadPoolExecutor} import scala.collection.concurrent.TrieMap import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable` @@ -47,12 +48,15 @@ import scala.collection.JavaConverters._ //scalastyle:off null //scalastyle:off multiple.string.literals //scalastyle:off file.size.limit -private class BulkWriter(container: CosmosAsyncContainer, - partitionKeyDefinition: PartitionKeyDefinition, - writeConfig: CosmosWriteConfig, - diagnosticsConfig: DiagnosticsConfig, - outputMetricsPublisher: OutputMetricsPublisherTrait) - extends AsyncItemWriter { +private class BulkWriter +( + container: CosmosAsyncContainer, + partitionKeyDefinition: PartitionKeyDefinition, + writeConfig: CosmosWriteConfig, + diagnosticsConfig: DiagnosticsConfig, + outputMetricsPublisher: OutputMetricsPublisherTrait, + commitAttempt: Long = 1 +) extends AsyncItemWriter { private val log = LoggerHelper.getLogger(diagnosticsConfig, this.getClass) @@ -100,23 +104,27 @@ private class BulkWriter(container: CosmosAsyncContainer, private val errorCaptureFirstException = new AtomicReference[Throwable]() private val bulkInputEmitter: Sinks.Many[CosmosItemOperation] = Sinks.many().unicast().onBackpressureBuffer() - private val activeBulkWriteOperations = java.util.concurrent.ConcurrentHashMap.newKeySet[CosmosItemOperation]().asScala + private val activeBulkWriteOperations =java.util.concurrent.ConcurrentHashMap.newKeySet[CosmosItemOperation]().asScala private val activeReadManyOperations = java.util.concurrent.ConcurrentHashMap.newKeySet[ReadManyOperation]().asScala private val semaphore = new Semaphore(maxPendingOperations) private val totalScheduledMetrics = new AtomicLong(0) private val totalSuccessfulIngestionMetrics = new AtomicLong(0) - private val cosmosBulkExecutionOptions = ImplementationBridgeHelpers.CosmosBulkExecutionOptionsHelper + private val maxOperationTimeout = java.time.Duration.ofSeconds(CosmosConstants.batchOperationEndToEndTimeoutInSeconds) + private val endToEndTimeoutPolicy = new CosmosEndToEndOperationLatencyPolicyConfigBuilder(maxOperationTimeout) + .enable(true) + .build + private val cosmosBulkExecutionOptions = new CosmosBulkExecutionOptions(BulkWriter.bulkProcessingThresholds) + + private val cosmosBulkExecutionOptionsImpl = ImplementationBridgeHelpers.CosmosBulkExecutionOptionsHelper .getCosmosBulkExecutionOptionsAccessor - .setSchedulerOverride( - ImplementationBridgeHelpers.CosmosBulkExecutionOptionsHelper - .getCosmosBulkExecutionOptionsAccessor - .setMaxConcurrentCosmosPartitions( - new CosmosBulkExecutionOptions(BulkWriter.bulkProcessingThresholds), - maxConcurrentPartitions - ), - bulkWriterRequestsBoundedElastic) + .getImpl(cosmosBulkExecutionOptions) + private val monotonicOperationCounter = new AtomicLong(0) + + cosmosBulkExecutionOptionsImpl.setSchedulerOverride(bulkWriterRequestsBoundedElastic) + cosmosBulkExecutionOptionsImpl.setMaxConcurrentCosmosPartitions(maxConcurrentPartitions) + cosmosBulkExecutionOptionsImpl.setCosmosEndToEndLatencyPolicyConfig(endToEndTimeoutPolicy) private class ForwardingMetricTracker(val verboseLoggingEnabled: AtomicBoolean) extends BulkExecutorDiagnosticsTracker { override def trackDiagnostics(ctx: CosmosDiagnosticsContext): Unit = { @@ -132,10 +140,7 @@ private class BulkWriter(container: 
CosmosAsyncContainer, } } - ImplementationBridgeHelpers.CosmosBulkExecutionOptionsHelper - .getCosmosBulkExecutionOptionsAccessor - .setDiagnosticsTracker( - cosmosBulkExecutionOptions, + cosmosBulkExecutionOptionsImpl.setDiagnosticsTracker( new ForwardingMetricTracker(verboseLoggingAfterReEnqueueingRetriesEnabled) ) @@ -143,12 +148,8 @@ private class BulkWriter(container: CosmosAsyncContainer, writeConfig.maxMicroBatchPayloadSizeInBytes match { case Some(customMaxMicroBatchPayloadSizeInBytes) => - ImplementationBridgeHelpers.CosmosBulkExecutionOptionsHelper - .getCosmosBulkExecutionOptionsAccessor - .setMaxMicroBatchPayloadSizeInBytes( - cosmosBulkExecutionOptions, - customMaxMicroBatchPayloadSizeInBytes - ) + cosmosBulkExecutionOptionsImpl + .setMaxMicroBatchPayloadSizeInBytes(customMaxMicroBatchPayloadSizeInBytes) case None => } @@ -182,10 +183,8 @@ private class BulkWriter(container: CosmosAsyncContainer, } } - private val batchIntervalInMs = ImplementationBridgeHelpers - .CosmosBulkExecutionOptionsHelper - .getCosmosBulkExecutionOptionsAccessor - .getMaxMicroBatchInterval(cosmosBulkExecutionOptions) + private val batchIntervalInMs = cosmosBulkExecutionOptionsImpl + .getMaxMicroBatchInterval .toMillis private[this] val flushExecutorHolder: Option[(ScheduledThreadPoolExecutor, ScheduledFuture[_])] = { @@ -226,9 +225,8 @@ private class BulkWriter(container: CosmosAsyncContainer, DiagnosticsLoader.getDiagnosticsProvider(diagnosticsConfig).getLogger(this.getClass) val operationContextAndListenerTuple = new OperationContextAndListenerTuple(taskDiagnosticsContext, listener) - ImplementationBridgeHelpers.CosmosBulkExecutionOptionsHelper - .getCosmosBulkExecutionOptionsAccessor - .setOperationContext(cosmosBulkExecutionOptions, operationContextAndListenerTuple) + cosmosBulkExecutionOptionsImpl + .setOperationContextAndListenerTuple(operationContextAndListenerTuple) taskDiagnosticsContext } else{ @@ -276,10 +274,7 @@ private class BulkWriter(container: CosmosAsyncContainer, case None => BatchRequestResponseConstants.MAX_OPERATIONS_IN_DIRECT_MODE_BATCH_REQUEST } - val batchConcurrency = ImplementationBridgeHelpers - .CosmosBulkExecutionOptionsHelper - .getCosmosBulkExecutionOptionsAccessor - .getMaxMicroBatchConcurrency(cosmosBulkExecutionOptions) + val batchConcurrency = cosmosBulkExecutionOptionsImpl.getMaxMicroBatchConcurrency val firstRecordTimeStamp = new AtomicLong(-1) val currentMicroBatchSize = new AtomicLong(0) @@ -397,21 +392,23 @@ private class BulkWriter(container: CosmosAsyncContainer, rootNode, readManyOperation.cosmosItemIdentity.getPartitionKey, new CosmosBulkItemRequestOptions().setIfMatchETag(etag.asText()), - OperationContext( + new OperationContext( readManyOperation.operationContext.itemId, readManyOperation.operationContext.partitionKeyValue, Some(etag.asText()), readManyOperation.operationContext.attemptNumber, + monotonicOperationCounter.incrementAndGet(), Some(readManyOperation.objectNode) )) case None => CosmosBulkOperations.getCreateItemOperation( rootNode, readManyOperation.cosmosItemIdentity.getPartitionKey, - OperationContext( + new OperationContext( readManyOperation.operationContext.itemId, readManyOperation.operationContext.partitionKeyValue, - eTag = None, + eTagInput = None, readManyOperation.operationContext.attemptNumber, + monotonicOperationCounter.incrementAndGet(), Some(readManyOperation.objectNode) )) } @@ -519,11 +516,12 @@ private class BulkWriter(container: CosmosAsyncContainer, scheduleWriteInternal( partitionKey, objectNode, - OperationContext( + new 
OperationContext( operationContext.itemId, operationContext.partitionKeyValue, operationContext.eTag, - operationContext.attemptNumber + 1)) + operationContext.attemptNumber + 1, + operationContext.sequenceNumber)) if (clearPendingRetryAction()) { this.pendingRetries.decrementAndGet() } @@ -552,13 +550,24 @@ private class BulkWriter(container: CosmosAsyncContainer, private val subscriptionDisposable: Disposable = { log.logTrace(s"subscriptionDisposable, Context: ${operationContext.toString} $getThreadInfo") + val inputFlux = bulkInputEmitter + .asFlux() + .onBackpressureBuffer() + .publishOn(bulkWriterInputBoundedElastic) + .doOnError(t => { + log.logError(s"Input publishing flux failed, Context: ${operationContext.toString} $getThreadInfo", t) + }) + val bulkOperationResponseFlux: SFlux[CosmosBulkOperationResponse[Object]] = container .executeBulkOperations[Object]( - bulkInputEmitter.asFlux().publishOn(bulkWriterRequestsBoundedElastic), + inputFlux, cosmosBulkExecutionOptions) .onBackpressureBuffer() .publishOn(bulkWriterResponsesBoundedElastic) + .doOnError(t => { + log.logError(s"Bulk execution flux failed, Context: ${operationContext.toString} $getThreadInfo", t) + }) .asScala bulkOperationResponseFlux.subscribe( @@ -641,15 +650,23 @@ private class BulkWriter(container: CosmosAsyncContainer, throwIfCapturedExceptionExists() val activeTasksSemaphoreTimeout = 10 - val operationContext = OperationContext(getId(objectNode), partitionKeyValue, getETag(objectNode), 1) + val operationContext = new OperationContext( + getId(objectNode), + partitionKeyValue, + getETag(objectNode), + 1, + monotonicOperationCounter.incrementAndGet()) val numberOfIntervalsWithIdenticalActiveOperationSnapshots = new AtomicLong(0) // Don't clone the activeOperations for the first iteration // to reduce perf impact before the Semaphore has been acquired // this means if the semaphore can't be acquired within 10 minutes // the first attempt will always assume it wasn't stale - so effectively we // allow staleness for ten additional minutes - which is perfectly fine - var activeOperationsSnapshot = mutable.Set.empty[CosmosItemOperation] + var activeBulkWriteOperationsSnapshot = mutable.Set.empty[CosmosItemOperation] + var pendingBulkWriteRetriesSnapshot = mutable.Set.empty[CosmosItemOperation] var activeReadManyOperationsSnapshot = mutable.Set.empty[ReadManyOperation] + var pendingReadManyRetriesSnapshot = mutable.Set.empty[ReadManyOperation] + log.logTrace( s"Before TryAcquire ${totalScheduledMetrics.get}, Context: ${operationContext.toString} $getThreadInfo") while (!semaphore.tryAcquire(activeTasksSemaphoreTimeout, TimeUnit.MINUTES)) { @@ -662,12 +679,17 @@ private class BulkWriter(container: CosmosAsyncContainer, throwIfProgressStaled( "Semaphore acquisition", - activeOperationsSnapshot, + activeBulkWriteOperationsSnapshot, + pendingBulkWriteRetriesSnapshot, activeReadManyOperationsSnapshot, - numberOfIntervalsWithIdenticalActiveOperationSnapshots) + pendingReadManyRetriesSnapshot, + numberOfIntervalsWithIdenticalActiveOperationSnapshots, + allowRetryOnNewBulkWriterInstance = false) - activeOperationsSnapshot = activeBulkWriteOperations.clone() + activeBulkWriteOperationsSnapshot = activeBulkWriteOperations.clone() + pendingBulkWriteRetriesSnapshot = pendingBulkWriteRetries.clone() activeReadManyOperationsSnapshot = activeReadManyOperations.clone() + pendingReadManyRetriesSnapshot = pendingReadManyRetries.clone() } val cnt = totalScheduledMetrics.getAndIncrement() @@ -781,16 +803,6 @@ private class 
BulkWriter(container: CosmosAsyncContainer, responseException: Option[CosmosException] ) : Unit = { - val cosmosDiagnosticsContext: Option[CosmosDiagnosticsContext] = { - responseException match { - case Some(e) => Option.apply(e.getDiagnostics) match { - case Some(diagnostics) => Option.apply(diagnostics.getDiagnosticsContext) - case None => None - } - case None => None - } - } - val exceptionMessage = responseException match { case Some(e) => e.getMessage case None => "" @@ -916,18 +928,60 @@ private class BulkWriter(container: CosmosAsyncContainer, sb.toString() } + private[this] def sameBulkWriteOperations + ( + snapshot: mutable.Set[CosmosItemOperation], + current: mutable.Set[CosmosItemOperation] + ): Boolean = { + + if (snapshot.size != current.size) { + false + } else { + snapshot.forall(snapshotOperation => { + current.exists( + currentOperation => snapshotOperation.getOperationType == currentOperation.getOperationType + && snapshotOperation.getPartitionKeyValue == currentOperation.getPartitionKeyValue + && Objects.equals(snapshotOperation.getId, currentOperation.getId) + && Objects.equals(snapshotOperation.getItem[ObjectNode], currentOperation.getItem[ObjectNode]) + ) + }) + } + } + + private[this] def sameReadManyOperations + ( + snapshot: mutable.Set[ReadManyOperation], + current: mutable.Set[ReadManyOperation] + ): Boolean = { + + if (snapshot.size != current.size) { + false + } else { + snapshot.forall(snapshotOperation => { + current.exists( + currentOperation => snapshotOperation.cosmosItemIdentity == currentOperation.cosmosItemIdentity + && Objects.equals(snapshotOperation.objectNode, currentOperation.objectNode) + ) + }) + } + } + private[this] def throwIfProgressStaled ( operationName: String, activeOperationsSnapshot: mutable.Set[CosmosItemOperation], + pendingRetriesSnapshot: mutable.Set[CosmosItemOperation], activeReadManyOperationsSnapshot: mutable.Set[ReadManyOperation], - numberOfIntervalsWithIdenticalActiveOperationSnapshots: AtomicLong + pendingReadManyOperationsSnapshot: mutable.Set[ReadManyOperation], + numberOfIntervalsWithIdenticalActiveOperationSnapshots: AtomicLong, + allowRetryOnNewBulkWriterInstance: Boolean ): Unit = { val operationsLog = getActiveOperationsLog(activeOperationsSnapshot, activeReadManyOperationsSnapshot) - if (activeOperationsSnapshot.equals(activeBulkWriteOperations) - && activeReadManyOperationsSnapshot.equals(activeReadManyOperations)) { + if (sameBulkWriteOperations(pendingRetriesSnapshot ++ activeOperationsSnapshot , activeBulkWriteOperations ++ pendingBulkWriteRetries) + && sameReadManyOperations(pendingReadManyOperationsSnapshot ++ activeReadManyOperationsSnapshot , activeReadManyOperations ++ pendingReadManyRetries)) { + numberOfIntervalsWithIdenticalActiveOperationSnapshots.incrementAndGet() log.logWarning( s"$operationName has been waiting $numberOfIntervalsWithIdenticalActiveOperationSnapshots " + @@ -942,25 +996,52 @@ private class BulkWriter(container: CosmosAsyncContainer, ) } - if (numberOfIntervalsWithIdenticalActiveOperationSnapshots.get >= BulkWriter.maxAllowedMinutesWithoutAnyProgress) { + val secondsWithoutProgress = numberOfIntervalsWithIdenticalActiveOperationSnapshots.get * + writeConfig.flushCloseIntervalInSeconds + val maxAllowedIntervalWithoutAnyProgressExceeded = + secondsWithoutProgress >= writeConfig.maxRetryNoProgressIntervalInSeconds || + (commitAttempt == 1 + && allowRetryOnNewBulkWriterInstance + && this.activeReadManyOperations.isEmpty + && this.pendingReadManyRetries.isEmpty + && secondsWithoutProgress 
>= writeConfig.maxNoProgressIntervalInSeconds) + + if (maxAllowedIntervalWithoutAnyProgressExceeded) { + + val exception = if (activeReadManyOperationsSnapshot.isEmpty) { + val retriableRemainingOperations = if (allowRetryOnNewBulkWriterInstance) { + Some( + (pendingRetriesSnapshot ++ activeOperationsSnapshot) + .toList + .sortBy(op => op.getContext[OperationContext].sequenceNumber) + ) + } else { + None + } - captureIfFirstFailure( - new IllegalStateException( + new BulkWriterNoProgressException( s"Stale bulk ingestion identified in $operationName - the following active operations have not been " + s"completed (first ${BulkWriter.maxItemOperationsToShowInErrorMessage} shown) or progressed after " + - s"${BulkWriter.maxAllowedMinutesWithoutAnyProgress} minutes: $operationsLog" - )) + s"${writeConfig.maxNoProgressIntervalInSeconds} seconds: $operationsLog", + commitAttempt, + retriableRemainingOperations) + } else { + new BulkWriterNoProgressException( + s"Stale bulk ingestion as well as readMany operations identified in $operationName - the following active operations have not been " + + s"completed (first ${BulkWriter.maxItemOperationsToShowInErrorMessage} shown) or progressed after " + + s"${writeConfig.maxRetryNoProgressIntervalInSeconds} : $operationsLog", + commitAttempt, + None) + } + + captureIfFirstFailure(exception) + cancelWork() } throwIfCapturedExceptionExists() } - def getFlushAndCloseIntervalInSeconds(): Int = { - val key = "COSMOS.FLUSH_CLOSE_INTERVAL_SEC" - sys.props.get(key).getOrElse(sys.env.get(key).getOrElse("60")).toInt - } - // the caller has to ensure that after invoking this method scheduleWrite doesn't get invoked // scalastyle:off method.length // scalastyle:off cyclomatic.complexity @@ -989,13 +1070,18 @@ private class BulkWriter(container: CosmosAsyncContainer, s"$pendingRetriesSnapshot, Context: ${operationContext.toString} $getThreadInfo") val activeOperationsSnapshot = activeBulkWriteOperations.clone() val activeReadManyOperationsSnapshot = activeReadManyOperations.clone() - val awaitCompleted = pendingTasksCompleted.await(getFlushAndCloseIntervalInSeconds, TimeUnit.SECONDS) + val pendingOperationsSnapshot = pendingBulkWriteRetries.clone() + val pendingReadManyOperationsSnapshot = pendingReadManyRetries.clone() + val awaitCompleted = pendingTasksCompleted.await(writeConfig.flushCloseIntervalInSeconds, TimeUnit.SECONDS) if (!awaitCompleted) { throwIfProgressStaled( "FlushAndClose", activeOperationsSnapshot, + pendingOperationsSnapshot, activeReadManyOperationsSnapshot, - numberOfIntervalsWithIdenticalActiveOperationSnapshots + pendingReadManyOperationsSnapshot, + numberOfIntervalsWithIdenticalActiveOperationSnapshots, + allowRetryOnNewBulkWriterInstance = true ) if (numberOfIntervalsWithIdenticalActiveOperationSnapshots.get > 0L) { @@ -1083,6 +1169,7 @@ private class BulkWriter(container: CosmosAsyncContainer, throwIfCapturedExceptionExists() assume(activeTasks.get() <= 0) + assume(activeBulkWriteOperations.isEmpty) assume(activeReadManyOperations.isEmpty) assume(semaphore.availablePermits() >= maxPendingOperations) @@ -1229,20 +1316,6 @@ private class BulkWriter(container: CosmosAsyncContainer, Exceptions.isPreconditionFailedException(statusCode) } - private case class OperationContext - ( - itemId: String, - partitionKeyValue: PartitionKey, - eTag: Option[String], - attemptNumber: Int /** starts from 1 * */, - sourceItem: Option[ObjectNode] = None) // for patchBulkUpdate: source item refers to the original objectNode from which SDK constructs the final bulk 
item operation - - - private case class ReadManyOperation( - cosmosItemIdentity: CosmosItemIdentity, - objectNode: ObjectNode, - operationContext: OperationContext) - private def getId(objectNode: ObjectNode) = { val idField = objectNode.get(CosmosConstants.Properties.Id) assume(idField != null && idField.isTextual) @@ -1253,37 +1326,76 @@ private class BulkWriter(container: CosmosAsyncContainer, * Don't wait for any remaining work but signal to the writer the ungraceful close * Should not throw any exceptions */ - override def abort(): Unit = { - log.logError(s"Abort, Context: ${operationContext.toString} $getThreadInfo") - // signal an exception that will be thrown for any pending work/flushAndClose if no other exception has - // been registered - captureIfFirstFailure( - new IllegalStateException(s"The Spark task was aborted, Context: ${operationContext.toString}")) + override def abort(shouldThrow: Boolean): Unit = { + if (shouldThrow) { + log.logError(s"Abort, Context: ${operationContext.toString} $getThreadInfo") + // signal an exception that will be thrown for any pending work/flushAndClose if no other exception has + // been registered + captureIfFirstFailure( + new IllegalStateException(s"The Spark task was aborted, Context: ${operationContext.toString}")) + } else { + log.logWarning(s"BulkWriter aborted and commit retried, Context: ${operationContext.toString} $getThreadInfo") + } cancelWork() } + + private class OperationContext + ( + itemIdInput: String, + partitionKeyValueInput: PartitionKey, + eTagInput: Option[String], + val attemptNumber: Int, + val sequenceNumber: Long, + /** starts from 1 * */ + sourceItemInput: Option[ObjectNode] = None) // for patchBulkUpdate: source item refers to the original objectNode from which SDK constructs the final bulk item operation + { + private val ctxCore: OperationContextCore = OperationContextCore(itemIdInput, partitionKeyValueInput, eTagInput, sourceItemInput) + + override def equals(obj: Any): Boolean = ctxCore.equals(obj) + + override def hashCode(): Int = ctxCore.hashCode() + + override def toString: String = { + ctxCore.toString + s", attemptNumber = $attemptNumber" + } + + def itemId: String = ctxCore.itemId + + def partitionKeyValue: PartitionKey = ctxCore.partitionKeyValue + + def eTag: Option[String] = ctxCore.eTag + + def sourceItem: Option[ObjectNode] = ctxCore.sourceItem + } + + private case class OperationContextCore + ( + itemId: String, + partitionKeyValue: PartitionKey, + eTag: Option[String], + sourceItem: Option[ObjectNode] = None) // for patchBulkUpdate: source item refers to the original objectNode from which SDK constructs the final bulk item operation + { + override def productPrefix: String = "OperationContext" + } + + + private case class ReadManyOperation( + cosmosItemIdentity: CosmosItemIdentity, + objectNode: ObjectNode, + operationContext: OperationContext) } private object BulkWriter { private val log = new DefaultDiagnostics().getLogger(this.getClass) //scalastyle:off magic.number - private val maxDelayOn408RequestTimeoutInMs = 10000 - private val minDelayOn408RequestTimeoutInMs = 1000 + private val maxDelayOn408RequestTimeoutInMs = 3000 + private val minDelayOn408RequestTimeoutInMs = 500 private val maxItemOperationsToShowInErrorMessage = 10 private val BULK_WRITER_REQUESTS_BOUNDED_ELASTIC_THREAD_NAME = "bulk-writer-requests-bounded-elastic" + private val BULK_WRITER_INPUT_BOUNDED_ELASTIC_THREAD_NAME = "bulk-writer-input-bounded-elastic" private val BULK_WRITER_RESPONSES_BOUNDED_ELASTIC_THREAD_NAME = 
"bulk-writer-responses-bounded-elastic" private val READ_MANY_BOUNDED_ELASTIC_THREAD_NAME = "read-many-bounded-elastic" private val TTL_FOR_SCHEDULER_WORKER_IN_SECONDS = 60 // same as BoundedElasticScheduler.DEFAULT_TTL_SECONDS - - // we used to use 15 minutes here - extending it because of several incidents where - // backend returned 420/3088 (ThrottleDueToSplit) for >= 30 minutes - // UPDATE - reverting back to 15 minutes - causing an unreasonably large delay/hang - // due to a backend issue doesn't sound right for most customers (helpful during my own - // long stress runs - but for customers 15 minutes is more reasonable) - // UPDATE - TODO @fabianm - with 15 minutes the end-to-end sample fails too often - because the extensive 429/3088 - // intervals are around 2 hours. So I need to increase this threshold for now again - will move it - // to 45 minutes - and when I am back from vacation will drive an investigation to improve the - // end-to-end behavior on 429/3088 with the backend and monitoring teams. - private val maxAllowedMinutesWithoutAnyProgress = 45 //scalastyle:on magic.number // let's say the spark executor VM has 16 CPU cores. @@ -1326,7 +1438,7 @@ private object BulkWriter { private val bulkProcessingThresholds = new CosmosBulkExecutionThresholdsState() - val maxPendingOperationsPerJVM = DefaultMaxPendingOperationPerCore * SparkUtils.getNumberOfHostCPUCores + private val maxPendingOperationsPerJVM: Int = DefaultMaxPendingOperationPerCore * SparkUtils.getNumberOfHostCPUCores // Custom bounded elastic scheduler to consume input flux val bulkWriterRequestsBoundedElastic: Scheduler = Schedulers.newBoundedElastic( @@ -1335,6 +1447,13 @@ private object BulkWriter { BULK_WRITER_REQUESTS_BOUNDED_ELASTIC_THREAD_NAME, TTL_FOR_SCHEDULER_WORKER_IN_SECONDS, true) + // Custom bounded elastic scheduler to consume input flux + val bulkWriterInputBoundedElastic: Scheduler = Schedulers.newBoundedElastic( + Schedulers.DEFAULT_BOUNDED_ELASTIC_SIZE, + Schedulers.DEFAULT_BOUNDED_ELASTIC_QUEUESIZE + 2 * maxPendingOperationsPerJVM, + BULK_WRITER_INPUT_BOUNDED_ELASTIC_THREAD_NAME, + TTL_FOR_SCHEDULER_WORKER_IN_SECONDS, true) + // Custom bounded elastic scheduler to switch off IO thread to process response. val bulkWriterResponsesBoundedElastic: Scheduler = Schedulers.newBoundedElastic( Schedulers.DEFAULT_BOUNDED_ELASTIC_SIZE, diff --git a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/BulkWriterNoProgressException.scala b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/BulkWriterNoProgressException.scala new file mode 100644 index 0000000000000..19b71a58cf4cc --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/BulkWriterNoProgressException.scala @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+package com.azure.cosmos.spark + +import com.azure.cosmos.models.CosmosItemOperation + +private[spark] class BulkWriterNoProgressException +( + val message: String, + val commitAttempt: Long, + val activeBulkWriteOperations: Option[List[CosmosItemOperation]]) extends RuntimeException(message) { +} diff --git a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosClientBuilderInterceptor.scala b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosClientBuilderInterceptor.scala index 77e53f3504f4d..bc716f9871eef 100644 --- a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosClientBuilderInterceptor.scala +++ b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosClientBuilderInterceptor.scala @@ -7,8 +7,27 @@ import com.azure.cosmos.CosmosClientBuilder /** * The CosmosClientBuilderInterceptor trait is used to allow spark environments to provide customizations of the - * Cosmos client builder configuration + * Cosmos client builder configuration - for example to enable diagnostics */ trait CosmosClientBuilderInterceptor { - def process(cosmosClientBuilder : CosmosClientBuilder): CosmosClientBuilder + + /** + * This method will be invoked by the Cosmos DB Spark connector when instantiating new CosmosClients. If the + * returned function is defined, it will be invoked to allow making modifications on the client builder - for + * example to configure diagnostics - before the builder is used to instantiate a new client instance. + * NOTE: It is important that implementations of this trait return singleton functions in + * getClientBuilderInterceptor when applicable based on the configs passed in. Each new function instance + * will result in a new CosmosClient being created in the cache - intentionally because the + * CosmosClientBuilderInterceptor implementation might choose completely different interceptions based + * on the config. The best pattern to achieve this would be to map the configs Map to a case class + * containing the config values relevant to your CosmosClientBuilderInterceptor implementation - then you can + * use a TrieMap with the config case class as key and the function implementation as value + * A sample implementing this pattern is under azure-cosmos-spark-account-data-resolver-sample + * in this repo - see the 'LoggingClientBuilderInterceptor' class. 
+ * + * @param configs the user configuration originally provided + * @return A function that is used to allow changing client builder before instantiating a new client instance + * to be added to the cache + */ + def getClientBuilderInterceptor(configs : Map[String, String]): Option[CosmosClientBuilder => CosmosClientBuilder] } diff --git a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala index 15bbef77252ae..b985e61f44f4a 100644 --- a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala +++ b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosClientCache.scala @@ -21,7 +21,7 @@ import reactor.core.scheduler.{Scheduler, Schedulers} import java.io.ByteArrayInputStream import java.time.{Duration, Instant} -import java.util.{Base64, ConcurrentModificationException, ServiceLoader} +import java.util.{Base64, ConcurrentModificationException} import java.util.concurrent.atomic.AtomicLong import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit} import scala.collection.concurrent.TrieMap @@ -109,6 +109,7 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait { } } + def purge(cosmosClientConfiguration: CosmosClientConfiguration): Unit = { purgeImpl(ClientConfigurationWrapper(cosmosClientConfiguration), forceClosure = false) } @@ -335,7 +336,7 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait { builder = builder.directMode(directConfig) if (cosmosClientConfiguration.proactiveConnectionInitialization.isDefined && - !cosmosClientConfiguration.proactiveConnectionInitialization.get.isEmpty) { + cosmosClientConfiguration.proactiveConnectionInitialization.get.nonEmpty) { val containerIdentities = CosmosAccountConfig.parseProactiveConnectionInitConfigs( cosmosClientConfiguration.proactiveConnectionInitialization.get) @@ -394,32 +395,22 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait { } if (cosmosClientConfiguration.clientBuilderInterceptors.isDefined) { - logInfo(s"CosmosClientBuilder interceptors specified: ${cosmosClientConfiguration.clientBuilderInterceptors.get}") - val interceptorsBuilder = scala.collection.immutable.HashMap.newBuilder[String, CosmosClientBuilderInterceptor] - val serviceLoader = ServiceLoader.load(classOf[CosmosClientBuilderInterceptor]) - val services = serviceLoader.iterator() - while (services.hasNext) { - val interceptor = services.next() - interceptorsBuilder += (interceptor.getClass.getName.toLowerCase() -> interceptor) - } - val interceptorsFromClassPath = interceptorsBuilder.result() - - val requestedInterceptors = cosmosClientConfiguration.clientBuilderInterceptors.get.split(',') - for (requestedInterceptorName <- requestedInterceptors) { - val interceptorFromClassPathOpt = interceptorsFromClassPath.get(requestedInterceptorName.toLowerCase()) - if (interceptorFromClassPathOpt.isDefined) { - logInfo(s"Applying CosmosClientBuilderInterceptor `${requestedInterceptorName}`.") - builder = interceptorFromClassPathOpt.get.process(builder) - } else { - throw new IllegalStateException( - s"The requested `CosmosClientBuilderInterceptor` `$requestedInterceptorName` is not available on the classpath." 
- + s"Interceptors available on the class path are: ${interceptorsFromClassPath.keySet.mkString(",")}" - ) - } + logInfo(s"Applying CosmosClientBuilder interceptors") + for (interceptorFunction <- cosmosClientConfiguration.clientBuilderInterceptors.get) { + builder = interceptorFunction.apply(builder) } } - builder.buildAsyncClient() + var client = builder.buildAsyncClient() + + if (cosmosClientConfiguration.clientInterceptors.isDefined) { + logInfo(s"Applying CosmosClient interceptors") + for (interceptorFunction <- cosmosClientConfiguration.clientInterceptors.get) { + client = interceptorFunction.apply(client) + } + } + + client } // scalastyle:on method.length // scalastyle:on cyclomatic.complexity @@ -466,7 +457,7 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait { } private[this] def createTokenCredential(authConfig: CosmosManagedIdentityAuthConfig): CosmosAccessTokenCredential = { - val tokenProvider: (List[String] => CosmosAccessToken) = { + val tokenProvider: List[String] => CosmosAccessToken = { val tokenCredentialBuilder = new ManagedIdentityCredentialBuilder() if (authConfig.clientId.isDefined) { tokenCredentialBuilder.clientId(authConfig.clientId.get) @@ -603,7 +594,8 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait { httpConnectionPoolSize: Int, useEventualConsistency: Boolean, preferredRegionsList: String, - clientBuilderInterceptor: Option[String]) + clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]], + clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]]) private[this] object ClientConfigurationWrapper { def apply(clientConfig: CosmosClientConfiguration): ClientConfigurationWrapper = { @@ -618,7 +610,8 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait { case Some(regionListArray) => s"[${regionListArray.mkString(", ")}]" case None => "" }, - clientConfig.clientBuilderInterceptors + clientConfig.clientBuilderInterceptors, + clientConfig.clientInterceptors ) } } @@ -662,7 +655,7 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait { extends SparkListener with BasicLoggingTrait { - override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd) { + override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { monitoredSparkApplications.remove(ctx) match { case Some(_) => logInfo( @@ -671,11 +664,10 @@ private[spark] object CosmosClientCache extends BasicLoggingTrait { case None => logWarning(s"ApplicationEndListener:onApplicationEnd (${ctx.hashCode}) - not monitored anymore") } - } } - private[this] class CosmosAccessTokenCredential(val tokenProvider: (List[String]) =>CosmosAccessToken) extends TokenCredential { + private[this] class CosmosAccessTokenCredential(val tokenProvider: List[String] =>CosmosAccessToken) extends TokenCredential { override def getToken(tokenRequestContext: TokenRequestContext): Mono[AccessToken] = { val returnValue: Mono[AccessToken] = Mono.fromCallable(() => { val token = tokenProvider diff --git a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala index 5057a49f8b5ea..f04bb79c473ea 100644 --- a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala +++ b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosClientConfiguration.scala @@ -2,6 +2,7 @@ // 
Licensed under the MIT License. package com.azure.cosmos.spark +import com.azure.cosmos.{CosmosAsyncClient, CosmosClientBuilder} import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession @@ -28,7 +29,8 @@ private[spark] case class CosmosClientConfiguration ( resourceGroupName: Option[String], azureEnvironmentEndpoints: java.util.Map[String, String], sparkEnvironmentInfo: String, - clientBuilderInterceptors: Option[String]) + clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]], + clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]]) private[spark] object CosmosClientConfiguration { def apply( @@ -50,7 +52,7 @@ private[spark] object CosmosClientConfiguration { var applicationName = CosmosConstants.userAgentSuffix - if (!sparkEnvironmentInfo.isEmpty) { + if (sparkEnvironmentInfo.nonEmpty) { applicationName = s"$applicationName|$sparkEnvironmentInfo" } @@ -85,7 +87,8 @@ private[spark] object CosmosClientConfiguration { cosmosAccountConfig.resourceGroupName, cosmosAccountConfig.azureEnvironmentEndpoints, sparkEnvironmentInfo, - cosmosAccountConfig.clientBuilderInterceptors) + cosmosAccountConfig.clientBuilderInterceptors, + cosmosAccountConfig.clientInterceptors) } private[spark] def getSparkEnvironmentInfo(sessionOption: Option[SparkSession]): String = { diff --git a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosClientInterceptor.scala b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosClientInterceptor.scala new file mode 100644 index 0000000000000..b5a8ae1dc9b03 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosClientInterceptor.scala @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.cosmos.spark + +import com.azure.cosmos.CosmosAsyncClient + +/** + * The CosmosClientInterceptor trait is used to allow spark environments to provide customizations of the + * Cosmos client configuration - for example to inject faults + */ +trait CosmosClientInterceptor { + + /** + * This method will be invoked by the Cosmos DB Spark connector when instantiating new CosmosClients. If the + * returned function is defined, it will be invoked to allow making modifications on the client - for example to + * inject faults. + * NOTE: It is important that implementations of this trait return singleton functions in + * getClientInterceptor when applicable based on the configs passed in. Each new function instance + * will result in a new CosmosClient being created in the cache - intentionally because the + * CosmosClientInterceptor implementation might choose completely different interceptions based + * on the config. The best pattern to achieve this would be to map the configs Map to a case class + * containing the config values relevant to your CosmosClientInterceptor implementation - then you can + * use a TrieMap with the config case class as key and the function implementation as value + * A sample implementing this pattern is under azure-cosmos-spark-account-data-resolver-sample + * in this repo - see the 'FaultInjectingClientInterceptor' class. 
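Along the same lines, a minimal sketch of a CosmosClientInterceptor implementation, again with a hypothetical class name and config key and not part of this PR. It deliberately returns the client unchanged - the point is the cached singleton function per distinct config, as the scaladoc above asks for:

package com.azure.cosmos.spark.samples

import com.azure.cosmos.CosmosAsyncClient
import com.azure.cosmos.spark.CosmosClientInterceptor

import scala.collection.concurrent.TrieMap

// Hypothetical implementation; like the builder interceptor it is loaded via ServiceLoader
// (META-INF/services/com.azure.cosmos.spark.CosmosClientInterceptor).
class SampleClientInterceptor extends CosmosClientInterceptor {

  private case class RelevantConfig(enabled: Boolean)

  private val interceptors = new TrieMap[RelevantConfig, CosmosAsyncClient => CosmosAsyncClient]()

  override def getClientInterceptor(configs: Map[String, String]): Option[CosmosAsyncClient => CosmosAsyncClient] = {
    // Hypothetical config key, only used by this sample.
    val cfg = RelevantConfig(configs.getOrElse("spark.cosmos.sample.clientInterceptorEnabled", "false").toBoolean)
    if (!cfg.enabled) {
      None
    } else {
      Some(interceptors.getOrElseUpdate(cfg, client => {
        // Fault-injection rules or other client-level hooks would be registered here;
        // this sketch simply returns the client unchanged.
        client
      }))
    }
  }
}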
+ * + * @param configs the user configuration originally provided + * @return A function that is used to allow changing the new client instance to be added to the cache + */ + def getClientInterceptor(configs : Map[String, String]): Option[CosmosAsyncClient => CosmosAsyncClient] +} diff --git a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala index 9a92b8d6ce72a..f5aeca7eaa1f0 100644 --- a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala +++ b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala @@ -4,6 +4,7 @@ package com.azure.cosmos.spark import com.azure.core.management.AzureEnvironment +import com.azure.cosmos.{CosmosAsyncClient, CosmosClientBuilder, spark} import com.azure.cosmos.implementation.batch.BatchRequestResponseConstants import com.azure.cosmos.implementation.routing.LocationHelper import com.azure.cosmos.implementation.{Configs, SparkBridgeImplementationInternal, Strings} @@ -11,6 +12,7 @@ import com.azure.cosmos.models.{CosmosChangeFeedRequestOptions, CosmosContainerI import com.azure.cosmos.spark.ChangeFeedModes.ChangeFeedMode import com.azure.cosmos.spark.ChangeFeedStartFromModes.{ChangeFeedStartFromMode, PointInTime} import com.azure.cosmos.spark.CosmosAuthType.CosmosAuthType +import com.azure.cosmos.spark.CosmosConfig.{getClientBuilderInterceptor, getClientInterceptor, getRetryCommitInterceptor} import com.azure.cosmos.spark.CosmosPatchOperationTypes.CosmosPatchOperationTypes import com.azure.cosmos.spark.CosmosPredicates.{assertNotNullOrEmpty, requireNotNullOrEmpty} import com.azure.cosmos.spark.ItemWriteStrategy.{ItemWriteStrategy, values} @@ -34,10 +36,6 @@ import scala.collection.concurrent.TrieMap import scala.collection.immutable.{HashSet, List, Map} import scala.collection.mutable -// scalastyle:off underscore.import -import scala.collection.JavaConverters._ -// scalastyle:on underscore.import - // scalastyle:off multiple.string.literals // scalastyle:off file.size.limit // scalastyle:off number.of.types @@ -101,6 +99,9 @@ private[spark] object CosmosConfigNames { val WriteBulkUpdateColumnConfigs = "spark.cosmos.write.bulkUpdate.columnConfigs" val WriteStrategy = "spark.cosmos.write.strategy" val WriteMaxRetryCount = "spark.cosmos.write.maxRetryCount" + val WriteFlushCloseIntervalInSeconds = "spark.cosmos.write.flush.intervalInSeconds" + val WriteMaxNoProgressIntervalInSeconds = "spark.cosmos.write.flush.noProgress.maxIntervalInSeconds" + val WriteMaxRetryNoProgressIntervalInSeconds = "spark.cosmos.write.flush.noProgress.maxRetryIntervalInSeconds" val ChangeFeedStartFrom = "spark.cosmos.changeFeed.startFrom" val ChangeFeedMode = "spark.cosmos.changeFeed.mode" val ChangeFeedItemCountPerTriggerHint = "spark.cosmos.changeFeed.itemCountPerTriggerHint" @@ -133,6 +134,8 @@ private[spark] object CosmosConfigNames { val MetricsIntervalInSeconds = "spark.cosmos.metrics.intervalInSeconds" val MetricsAzureMonitorConnectionString = "spark.cosmos.metrics.azureMonitor.connectionString" val ClientBuilderInterceptors = "spark.cosmos.account.clientBuilderInterceptors" + val ClientInterceptors = "spark.cosmos.account.clientInterceptors" + val WriteOnRetryCommitInterceptor = "spark.cosmos.write.onRetryCommitInterceptor" // Only meant to be used when throughput control is configured without using dedicated containers // Then in this case, we are going to 
allocate the throughput budget equally across all executors @@ -225,7 +228,12 @@ private[spark] object CosmosConfigNames { MetricsEnabledForSlf4j, MetricsIntervalInSeconds, MetricsAzureMonitorConnectionString, - ClientBuilderInterceptors + ClientBuilderInterceptors, + ClientInterceptors, + WriteOnRetryCommitInterceptor, + WriteFlushCloseIntervalInSeconds, + WriteMaxNoProgressIntervalInSeconds, + WriteMaxRetryNoProgressIntervalInSeconds ) def validateConfigName(name: String): Unit = { @@ -233,7 +241,7 @@ private[spark] object CosmosConfigNames { name.length > cosmosPrefix.length && cosmosPrefix.equalsIgnoreCase(name.substring(0, cosmosPrefix.length))) { - if (validConfigNames.find(n => name.equalsIgnoreCase(n)).isEmpty) { + if (!validConfigNames.exists(n => name.equalsIgnoreCase(n))) { throw new IllegalArgumentException( s"The config property '$name' is invalid. No config setting with this name exists.") } @@ -243,9 +251,18 @@ private[spark] object CosmosConfigNames { private object CosmosConfig extends BasicLoggingTrait { - val accountDataResolvers: TrieMap[Option[String], Option[AccountDataResolver]] = + private val accountDataResolvers: TrieMap[Option[String], Option[AccountDataResolver]] = new TrieMap[Option[String], Option[AccountDataResolver]]() + private val retryCommitInterceptors: TrieMap[String, Option[WriteOnRetryCommitInterceptor]] = + new TrieMap[String, Option[WriteOnRetryCommitInterceptor]]() + + private val clientBuilderInterceptors: TrieMap[Option[String], Option[CosmosClientBuilderInterceptor]] = + new TrieMap[Option[String], Option[CosmosClientBuilderInterceptor]]() + + private val clientInterceptors: TrieMap[Option[String], Option[CosmosClientInterceptor]] = + new TrieMap[Option[String], Option[CosmosClientInterceptor]]() + def getAccountDataResolver(accountDataResolverServiceName : Option[String]): Option[AccountDataResolver] = { accountDataResolvers.getOrElseUpdate( accountDataResolverServiceName, @@ -257,7 +274,7 @@ private object CosmosConfig extends BasicLoggingTrait { var accountDataResolverCls = None: Option[AccountDataResolver] val serviceLoader = ServiceLoader.load(classOf[AccountDataResolver]) val iterator = serviceLoader.iterator() - while (!accountDataResolverCls.isDefined && iterator.hasNext()) { + while (accountDataResolverCls.isEmpty && iterator.hasNext()) { val resolver = iterator.next() if (accountDataResolverServiceName.isEmpty || accountDataResolverServiceName.get.equalsIgnoreCase(resolver.getClass.getName)) { @@ -273,6 +290,81 @@ private object CosmosConfig extends BasicLoggingTrait { accountDataResolverCls } + def getClientBuilderInterceptor(serviceName: Option[String]): Option[CosmosClientBuilderInterceptor] = { + clientBuilderInterceptors.getOrElseUpdate(serviceName, getClientBuilderInterceptorImpl(serviceName)) + } + + private def getClientBuilderInterceptorImpl(serviceName: Option[String]): Option[CosmosClientBuilderInterceptor] = { + logInfo(s"Checking for client builder interceptors - requested service name '${serviceName.getOrElse("n/a")}'") + var cls = None: Option[CosmosClientBuilderInterceptor] + val serviceLoader = ServiceLoader.load(classOf[CosmosClientBuilderInterceptor]) + val iterator = serviceLoader.iterator() + while (cls.isEmpty && iterator.hasNext()) { + val resolver = iterator.next() + if (serviceName.isEmpty || serviceName.get.equalsIgnoreCase(resolver.getClass.getName)) { + logInfo(s"Found client builder interceptor ${resolver.getClass.getName}") + cls = Some(resolver) + } else { + logInfo( + s"Ignoring client builder 
interceptor ${resolver.getClass.getName} because name is different " + + s"than requested ${serviceName.get}") + } + } + + cls + } + + def getClientInterceptor(serviceName: Option[String]): Option[CosmosClientInterceptor] = { + clientInterceptors.getOrElseUpdate(serviceName, getClientInterceptorImpl(serviceName)) + } + + private def getClientInterceptorImpl(serviceName: Option[String]): Option[CosmosClientInterceptor] = { + logInfo(s"Checking for client interceptors - requested service name '${serviceName.getOrElse("n/a")}'") + var cls = None: Option[CosmosClientInterceptor] + val serviceLoader = ServiceLoader.load(classOf[CosmosClientInterceptor]) + val iterator = serviceLoader.iterator() + while (cls.isEmpty && iterator.hasNext()) { + val resolver = iterator.next() + if (serviceName.isEmpty || serviceName.get.equalsIgnoreCase(resolver.getClass.getName)) { + logInfo(s"Found client interceptor ${resolver.getClass.getName}") + cls = Some(resolver) + } else { + logInfo( + s"Ignoring client interceptor ${resolver.getClass.getName} because name is different " + + s"than requested ${serviceName.get}") + } + } + + cls + } + + def getRetryCommitInterceptor(serviceName: String): Option[WriteOnRetryCommitInterceptor] = { + retryCommitInterceptors.getOrElseUpdate( + serviceName, + getRetryCommitInterceptorImpl(serviceName)) + } + + private def getRetryCommitInterceptorImpl(serviceName: String): Option[WriteOnRetryCommitInterceptor] = { + logInfo( + s"Checking for WriteOnRetryCommitInterceptor - requested service name '$serviceName'") + var cls = None: Option[WriteOnRetryCommitInterceptor] + val serviceLoader = ServiceLoader.load(classOf[WriteOnRetryCommitInterceptor]) + val iterator = serviceLoader.iterator() + while (cls.isEmpty && iterator.hasNext()) { + val resolver = iterator.next() + if (serviceName.equalsIgnoreCase(resolver.getClass.getName)) { + logInfo(s"Found WriteOnRetryCommitInterceptor ${resolver.getClass.getName}") + cls = Some(resolver) + } else { + logInfo( + s"Ignoring WriteOnRetryCommitInterceptor ${resolver.getClass.getName} because name is different " + + s"than requested $serviceName") + } + } + + cls + } + def getEffectiveConfig ( databaseName: Option[String], @@ -284,10 +376,9 @@ private object CosmosConfig extends BasicLoggingTrait { ) : Map[String, String] = { var effectiveUserConfig = CaseInsensitiveMap(userProvidedOptions) val mergedConfig = sparkConf match { - case Some(sparkConfig) => { + case Some(sparkConfig) => val conf = sparkConfig.clone() conf.setAll(effectiveUserConfig.toMap).getAll.toMap - } case None => effectiveUserConfig.toMap } @@ -311,14 +402,13 @@ private object CosmosConfig extends BasicLoggingTrait { } val returnValue = sparkConf match { - case Some(sparkConfig) => { + case Some(sparkConfig) => val conf = sparkConfig.clone() conf.setAll(effectiveUserConfig.toMap).getAll.toMap - } case None => effectiveUserConfig.toMap } - returnValue.foreach((configProperty) => CosmosConfigNames.validateConfigName(configProperty._1)) + returnValue.foreach(configProperty => CosmosConfigNames.validateConfigName(configProperty._1)) returnValue } @@ -344,7 +434,7 @@ private object CosmosConfig extends BasicLoggingTrait { Some(executorCount)) // user provided config } - def getExecutorCount(sparkSession: SparkSession): Int = { + private def getExecutorCount(sparkSession: SparkSession): Int = { val sparkContext = sparkSession.sparkContext // The getExecutorInfos will return information for both the driver and executors // We only want the total executor count @@ -384,9 +474,11 
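Editorial note: the ServiceLoader-based lookups added above resolve interceptor implementations by fully qualified class name. A minimal sketch of a custom CosmosClientInterceptor that such a lookup could discover - the package, class name, and logging body are illustrative and not part of the connector:

    package com.example.cosmos

    import com.azure.cosmos.CosmosAsyncClient
    import com.azure.cosmos.spark.CosmosClientInterceptor

    // Illustrative interceptor: inspects the effective user config and returns the client unchanged.
    class LoggingCosmosClientInterceptor extends CosmosClientInterceptor {
      override def getClientInterceptor(configs: Map[String, String]): Option[CosmosAsyncClient => CosmosAsyncClient] = {
        Some((client: CosmosAsyncClient) => {
          // hook point to attach fault injection, warm-up calls, etc. before the client is cached
          println(s"Cosmos client created with ${configs.size} effective config entries")
          client
        })
      }
    }

Because discovery goes through ServiceLoader, the fully qualified class name would also need to be listed in a META-INF/services/com.azure.cosmos.spark.CosmosClientInterceptor resource on the class path and referenced through spark.cosmos.account.clientInterceptors.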
@@ private case class CosmosAccountConfig(endpoint: String, tenantId: Option[String], resourceGroupName: Option[String], azureEnvironmentEndpoints: java.util.Map[String, String], - clientBuilderInterceptors: Option[String]) + clientBuilderInterceptors: Option[List[CosmosClientBuilder => CosmosClientBuilder]], + clientInterceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]], + ) -private object CosmosAccountConfig { +private object CosmosAccountConfig extends BasicLoggingTrait { private val DefaultAzureEnvironmentEndpoints = AzureEnvironmentType.Azure private val CosmosAccountEndpointUri = CosmosConfigEntry[String](key = CosmosConfigNames.AccountEndpoint, @@ -435,7 +527,7 @@ private object CosmosAccountConfig { .toStream .map(preferredRegion => preferredRegion.toLowerCase(Locale.ROOT).trim) .map(preferredRegion => { - if (!PreferredRegionRegex.findFirstIn(preferredRegion).isDefined) { + if (PreferredRegionRegex.findFirstIn(preferredRegion).isEmpty) { throw new IllegalArgumentException(s"$preferredRegionsListAsString is invalid") } preferredRegion @@ -527,7 +619,7 @@ private object CosmosAccountConfig { case AzureEnvironmentType.AzureChina => AzureEnvironment.AZURE_CHINA.getEndpoints case AzureEnvironmentType.AzureGermany => AzureEnvironment.AZURE_GERMANY.getEndpoints case AzureEnvironmentType.AzureUsGovernment => AzureEnvironment.AZURE_US_GOVERNMENT.getEndpoints - case _ => throw new IllegalArgumentException(s"Azure environment type ${azureEnvironmentType} is not supported") + case _ => throw new IllegalArgumentException(s"Azure environment type $azureEnvironmentType is not supported") } }, helpMessage = "The azure environment of the CosmosDB account: `Azure`, `AzureChina`, `AzureUsGovernment`, `AzureGermany`.") @@ -537,6 +629,11 @@ private object CosmosAccountConfig { parseFromStringFunction = clientBuilderInterceptorFQDN => clientBuilderInterceptorFQDN, helpMessage = "CosmosClientBuilder interceptors (comma separated) - FQDNs of the service implementing the 'CosmosClientBuilderInterceptor' trait.") + private val ClientInterceptors = CosmosConfigEntry[String](key = CosmosConfigNames.ClientInterceptors, + mandatory = false, + parseFromStringFunction = clientInterceptorFQDN => clientInterceptorFQDN, + helpMessage = "CosmosAsyncClient interceptors (comma separated) - FQDNs of the service implementing the 'CosmosClientInterceptor' trait.") + private[spark] def parseProactiveConnectionInitConfigs(config: String): java.util.List[CosmosContainerIdentity] = { val result = new java.util.ArrayList[CosmosContainerIdentity] try { @@ -551,7 +648,7 @@ private object CosmosAccountConfig { catch { case e: Exception => throw new IllegalArgumentException( s"Invalid proactive connection initialization config $config. 
The string must be a list of containers to " - + "be warmed-up in the format of `DBName1/ContainerName1;DBName2/ContainerName2;DBName1/ContainerName3`") + + "be warmed-up in the format of `DBName1/ContainerName1;DBName2/ContainerName2;DBName1/ContainerName3`", e) } } @@ -570,6 +667,7 @@ private object CosmosAccountConfig { val tenantIdOpt = CosmosConfigEntry.parse(cfg, TenantId) val azureEnvironmentOpt = CosmosConfigEntry.parse(cfg, AzureEnvironmentTypeEnum) val clientBuilderInterceptors = CosmosConfigEntry.parse(cfg, ClientBuilderInterceptors) + val clientInterceptors = CosmosConfigEntry.parse(cfg, ClientInterceptors) val disableTcpConnectionEndpointRediscovery = CosmosConfigEntry.parse(cfg, DisableTcpConnectionEndpointRediscovery) val preferredRegionsListOpt = CosmosConfigEntry.parse(cfg, PreferredRegionsList) @@ -610,11 +708,41 @@ private object CosmosAccountConfig { // validates each preferred region LocationHelper.getLocationEndpoint(uri, preferredRegion) } catch { - case e: Exception => throw new IllegalArgumentException(s"Invalid preferred region $preferredRegion") + case e: Exception => throw new IllegalArgumentException(s"Invalid preferred region $preferredRegion", e) } }) } + val clientBuilderInterceptorsList = mutable.ListBuffer[CosmosClientBuilder => CosmosClientBuilder]() + if (clientBuilderInterceptors.isDefined) { + logInfo(s"CosmosClientBuilder interceptors specified: ${clientBuilderInterceptors.get}") + val requestedInterceptors = clientBuilderInterceptors.get.split(',') + for (requestedInterceptorName <- requestedInterceptors) { + val foundInterceptorCandidate = getClientBuilderInterceptor(Some(requestedInterceptorName)) + if (foundInterceptorCandidate.isDefined) { + foundInterceptorCandidate.get.getClientBuilderInterceptor(cfg) match { + case Some(interceptor) => clientBuilderInterceptorsList += interceptor + case None => + } + } + } + } + + val clientInterceptorsList = mutable.ListBuffer[CosmosAsyncClient => CosmosAsyncClient]() + if (clientInterceptors.isDefined) { + logInfo(s"CosmosAsyncClient interceptors specified: ${clientInterceptors.get}") + val requestedInterceptors = clientInterceptors.get.split(',') + for (requestedInterceptorName <- requestedInterceptors) { + val foundInterceptorCandidate = getClientInterceptor(Some(requestedInterceptorName)) + if (foundInterceptorCandidate.isDefined) { + foundInterceptorCandidate.get.getClientInterceptor(cfg) match { + case Some(interceptor) => clientInterceptorsList += interceptor + case None => + } + } + } + } + CosmosAccountConfig( endpointOpt.get, authConfig, @@ -631,7 +759,8 @@ private object CosmosAccountConfig { tenantIdOpt, resourceGroupNameOpt, azureEnvironmentOpt.get, - clientBuilderInterceptors) + if (clientBuilderInterceptorsList.nonEmpty) { Some(clientBuilderInterceptorsList.toList) } else { None }, + if (clientInterceptorsList.nonEmpty) { Some(clientInterceptorsList.toList) } else { None }) } } @@ -755,14 +884,14 @@ private object CosmosAuthConfig { assert(tenantId.isDefined, s"Parameter '${CosmosConfigNames.TenantId}' is missing.") val accountDataResolverServiceName : Option[String] = CaseInsensitiveMap(cfg).get(CosmosConfigNames.AccountDataResolverServiceName) val accountDataResolver = CosmosConfig.getAccountDataResolver(accountDataResolverServiceName) - if (!accountDataResolver.isDefined) { + if (accountDataResolver.isEmpty) { throw new IllegalArgumentException( s"For auth type '${authType.get}' you have to provide an implementation of the " + "'com.azure.cosmos.spark.AccountDataResolver' trait on the class 
path.") } val accessTokenProvider = accountDataResolver.get.getAccessTokenProvider(cfg) - if (!accessTokenProvider.isDefined) { + if (accessTokenProvider.isEmpty) { throw new IllegalArgumentException( s"For auth type '${authType.get}' you have to provide an implementation of the " + "'com.azure.cosmos.spark.AccountDataResolver' trait on the class path, which " + @@ -910,8 +1039,8 @@ private object CosmosReadConfig { private case class CosmosViewRepositoryConfig(metaDataPath: Option[String]) private object CosmosViewRepositoryConfig { - val MetaDataPathKeyName = CosmosConfigNames.ViewsRepositoryPath - val IsCosmosViewKeyName = "isCosmosView" + val MetaDataPathKeyName: String = CosmosConfigNames.ViewsRepositoryPath + private val IsCosmosViewKeyName = "isCosmosView" private val MetaDataPath = CosmosConfigEntry[String](key = MetaDataPathKeyName, mandatory = false, defaultValue = None, @@ -1006,12 +1135,12 @@ private object ItemWriteStrategy extends Enumeration { private object CosmosPatchOperationTypes extends Enumeration { type CosmosPatchOperationTypes = Value - val None = Value("none") - val Add = Value("add") - val Set = Value("set") - val Replace = Value("replace") - val Remove = Value("remove") - val Increment = Value("increment") + val None: spark.CosmosPatchOperationTypes.Value = Value("none") + val Add: spark.CosmosPatchOperationTypes.Value = Value("add") + val Set: spark.CosmosPatchOperationTypes.Value = Value("set") + val Replace: spark.CosmosPatchOperationTypes.Value = Value("replace") + val Remove: spark.CosmosPatchOperationTypes.Value = Value("remove") + val Increment: spark.CosmosPatchOperationTypes.Value = Value("increment") } private case class CosmosPatchColumnConfig(columnName: String, @@ -1032,7 +1161,11 @@ private case class CosmosWriteConfig(itemWriteStrategy: ItemWriteStrategy, throughputControlConfig: Option[CosmosThroughputControlConfig] = None, maxMicroBatchPayloadSizeInBytes: Option[Int] = None, initialMicroBatchSize: Option[Int] = None, - maxMicroBatchSize: Option[Int] = None) + maxMicroBatchSize: Option[Int] = None, + flushCloseIntervalInSeconds: Int = 60, + maxNoProgressIntervalInSeconds: Int = 180, + maxRetryNoProgressIntervalInSeconds: Int = 45 * 60, + retryCommitInterceptor: Option[WriteOnRetryCommitInterceptor] = None) private object CosmosWriteConfig { private val DefaultMaxRetryCount = 10 @@ -1154,7 +1287,32 @@ private object CosmosWriteConfig { "2. col(column).path(patchInCosmosdb).rawJson - allows you to configure different mapping path in cosmosdb, and indicates the value of the column is in raw json format" + "3. 
col(column).rawJson - indicates the value of the column is in raw json format") - def parseUserDefinedPatchColumnConfigs(patchColumnConfigsString: String): TrieMap[String, CosmosPatchColumnConfig] = { + private val writeOnRetryCommitInterceptor = CosmosConfigEntry[Option[WriteOnRetryCommitInterceptor]](key = CosmosConfigNames.WriteOnRetryCommitInterceptor, + mandatory = false, + parseFromStringFunction = serviceName => getRetryCommitInterceptor(serviceName), + helpMessage = "Name of the service to be invoked when retrying write commits (currently only implemented for bulk).") + + val key = "COSMOS.FLUSH_CLOSE_INTERVAL_SEC" + + private val flushCloseIntervalInSeconds = CosmosConfigEntry[Int](key = CosmosConfigNames.WriteFlushCloseIntervalInSeconds, + defaultValue = Some(sys.props.get(key).getOrElse(sys.env.getOrElse(key, "60")).toInt), + mandatory = false, + parseFromStringFunction = intAsString => intAsString.toInt, + helpMessage = s"Interval of checks whether any progress has been made when flushing write operations.") + + private val maxNoProgressIntervalInSeconds = CosmosConfigEntry[Int](key = CosmosConfigNames.WriteMaxNoProgressIntervalInSeconds, + defaultValue = Some(45 * 60), + mandatory = false, + parseFromStringFunction = intAsString => intAsString.toInt, + helpMessage = s"Interval after which a writer fails when no progress has been made when flushing operations.") + + private val maxRetryNoProgressIntervalInSeconds = CosmosConfigEntry[Int](key = CosmosConfigNames.WriteMaxRetryNoProgressIntervalInSeconds, + defaultValue = Some(3 * 60), + mandatory = false, + parseFromStringFunction = intAsString => intAsString.toInt, + helpMessage = s"Interval after which a writer fails when no progress has been made when flushing operations in the second commit.") + + private def parseUserDefinedPatchColumnConfigs(patchColumnConfigsString: String): TrieMap[String, CosmosPatchColumnConfig] = { val columnConfigMap = new TrieMap[String, CosmosPatchColumnConfig] if (patchColumnConfigsString.isEmpty) { @@ -1172,7 +1330,7 @@ private object CosmosWriteConfig { .foreach(item => { val columnConfigString = item.trim - if (!columnConfigString.isEmpty) { + if (columnConfigString.nonEmpty) { // Currently there are two patterns which are valid // 1. col(column).op(operationType) // 2. col(column).path(mappedPath).op(operationType) @@ -1195,7 +1353,7 @@ private object CosmosWriteConfig { mappingPath = s"/$columnName" } - val isRawJson = !rawJsonSuffix.isEmpty + val isRawJson = rawJsonSuffix.nonEmpty val columnConfig = CosmosPatchColumnConfig( columnName = columnName, @@ -1217,7 +1375,7 @@ private object CosmosWriteConfig { } } - def parsePatchBulkUpdateColumnConfigs(patchBulkUpdateColumnConfigsString: String): TrieMap[String, CosmosPatchColumnConfig] = { + private def parsePatchBulkUpdateColumnConfigs(patchBulkUpdateColumnConfigsString: String): TrieMap[String, CosmosPatchColumnConfig] = { val columnConfigMap = new TrieMap[String, CosmosPatchColumnConfig] if (patchBulkUpdateColumnConfigsString.isEmpty) { @@ -1234,7 +1392,7 @@ private object CosmosWriteConfig { trimmedInput.split(",") .foreach(item => { val columnConfigString = item.trim - if (!columnConfigString.isEmpty) { + if (columnConfigString.nonEmpty) { // Currently there are three patterns which are valid // 1. col(column).path(mappedPath) // 2. 
col(column).path(mappingPath).rawJson @@ -1256,7 +1414,7 @@ private object CosmosWriteConfig { mappingPath = s"/$columnName" } - val isRawJson = !rawJsonSuffix.isEmpty + val isRawJson = rawJsonSuffix.nonEmpty val columnConfig = CosmosPatchColumnConfig( columnName = columnName, @@ -1287,6 +1445,8 @@ private object CosmosWriteConfig { val microBatchPayloadSizeInBytesOpt = CosmosConfigEntry.parse(cfg, microBatchPayloadSizeInBytes) val initialBatchSizeOpt = CosmosConfigEntry.parse(cfg, initialMicroBatchSize) val maxBatchSizeOpt = CosmosConfigEntry.parse(cfg, maxMicroBatchSize) + val writeRetryCommitInterceptor = CosmosConfigEntry + .parse(cfg, writeOnRetryCommitInterceptor).flatten assert(bulkEnabledOpt.isDefined, s"Parameter '${CosmosConfigNames.WriteBulkEnabled}' is missing.") @@ -1316,10 +1476,14 @@ private object CosmosWriteConfig { throughputControlConfig = throughputControlConfigOpt, maxMicroBatchPayloadSizeInBytes = microBatchPayloadSizeInBytesOpt, initialMicroBatchSize = initialBatchSizeOpt, - maxMicroBatchSize = maxBatchSizeOpt) + maxMicroBatchSize = maxBatchSizeOpt, + flushCloseIntervalInSeconds = CosmosConfigEntry.parse(cfg, flushCloseIntervalInSeconds).get, + maxNoProgressIntervalInSeconds = CosmosConfigEntry.parse(cfg, maxNoProgressIntervalInSeconds).get, + maxRetryNoProgressIntervalInSeconds = CosmosConfigEntry.parse(cfg, maxRetryNoProgressIntervalInSeconds).get, + retryCommitInterceptor = writeRetryCommitInterceptor) } - def parsePatchColumnConfigs(cfg: Map[String, String], inputSchema: StructType): TrieMap[String, CosmosPatchColumnConfig] = { + private def parsePatchColumnConfigs(cfg: Map[String, String], inputSchema: StructType): TrieMap[String, CosmosPatchColumnConfig] = { val defaultPatchOperationType = CosmosConfigEntry.parse(cfg, patchDefaultOperationType) // Parse customer specified column configs, which will override the default config @@ -1335,7 +1499,7 @@ private object CosmosWriteConfig { userDefinedPatchColumnConfigMap.remove(schemaField.name) case None => // There is no customer specified column config, create one based on the default config - val newColumnConfig = CosmosPatchColumnConfig(schemaField.name, defaultPatchOperationType.get, s"/${schemaField.name}", false) + val newColumnConfig = CosmosPatchColumnConfig(schemaField.name, defaultPatchOperationType.get, s"/${schemaField.name}", isRawJson = false) aggregatedPatchColumnConfigMap += schemaField.name -> validatePatchColumnConfig(newColumnConfig, schemaField.dataType) } }) @@ -1355,7 +1519,7 @@ private object CosmosWriteConfig { aggregatedPatchColumnConfigMap } - def validatePatchColumnConfig(cosmosPatchColumnConfig: CosmosPatchColumnConfig, dataType: DataType): CosmosPatchColumnConfig = { + private def validatePatchColumnConfig(cosmosPatchColumnConfig: CosmosPatchColumnConfig, dataType: DataType): CosmosPatchColumnConfig = { cosmosPatchColumnConfig.operationType match { case CosmosPatchOperationTypes.Increment => dataType match { @@ -1449,7 +1613,7 @@ private object CosmosContainerConfig { parseFromStringFunction = container => container, helpMessage = "Cosmos DB container name") - val optionalContainerNameSupplier = CosmosConfigEntry[String](key = CONTAINER_NAME_KEY, + val optionalContainerNameSupplier: CosmosConfigEntry[String] = CosmosConfigEntry[String](key = CONTAINER_NAME_KEY, mandatory = false, parseFromStringFunction = container => container, helpMessage = "Cosmos DB container name") @@ -1621,7 +1785,7 @@ private object CosmosPartitioningConfig { val fragments = filter.split(",") for (fragment <- 
fragments) { val minAndMax = fragment.trim.split("-") - epkRanges += (NormalizedRange(minAndMax(0), minAndMax(1))) + epkRanges += NormalizedRange(minAndMax(0), minAndMax(1)) } epkRanges.toArray @@ -1960,7 +2124,7 @@ private object CosmosThroughputControlConfig { } } - def parseThroughputControlAccountConfig(cfg: Map[String, String]): CosmosAccountConfig = { + private def parseThroughputControlAccountConfig(cfg: Map[String, String]): CosmosAccountConfig = { val throughputControlAccountEndpoint = CosmosConfigEntry.parse(cfg, throughputControlAccountEndpointUriSupplier) val throughputControlAccountKey = CosmosConfigEntry.parse(cfg, throughputControlAccountKeySupplier) assert( @@ -2009,7 +2173,7 @@ private object CosmosThroughputControlConfig { CosmosAccountConfig.parseCosmosAccountConfig(throughputControlAccountConfigMap.toMap) } - def addNonNullConfig( + private def addNonNullConfig( originalLowercaseCfg: Map[String, String], newCfg: mutable.Map[String, String], originalConfigName: String, @@ -2019,8 +2183,8 @@ private object CosmosThroughputControlConfig { val originalLowercaseCfgName = originalConfigName.toLowerCase(Locale.ROOT) val newLowercaseCfgName = newConfigName.toLowerCase(Locale.ROOT) - if (originalLowercaseCfg.get(originalLowercaseCfgName).isDefined) { - newCfg += (newLowercaseCfgName -> originalLowercaseCfg.get(originalLowercaseCfgName).get) + if (originalLowercaseCfg.contains(originalLowercaseCfgName)) { + newCfg += (newLowercaseCfgName -> originalLowercaseCfg(originalLowercaseCfgName)) } } } @@ -2046,20 +2210,15 @@ private case class CosmosConfigEntry[T](key: String, } } -// TODO: moderakh how to merge user config with SparkConf application config? private object CosmosConfigEntry { def parseEnumeration[T <: Enumeration](enumValueAsString: String, enumeration: T): T#Value = { require(enumValueAsString != null) enumeration.values.find(_.toString.toLowerCase == enumValueAsString.toLowerCase()).getOrElse( - throw new IllegalArgumentException(s"$enumValueAsString valid value, valid values are ${values}")) + throw new IllegalArgumentException(s"$enumValueAsString valid value, valid values are $values")) } private val configEntriesDefinitions = new java.util.HashMap[String, CosmosConfigEntry[_]]() - def allConfigNames(): Seq[String] = { - configEntriesDefinitions.keySet().asScala.toSeq - } - def parse[T](configuration: Map[String, String], configEntry: CosmosConfigEntry[T]): Option[T] = { // we are doing this here per config parsing for now val loweredCaseConfiguration = configuration diff --git a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosConstants.scala b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosConstants.scala index 83542abe42863..45db4f2172404 100644 --- a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosConstants.scala +++ b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosConstants.scala @@ -26,6 +26,7 @@ private[cosmos] object CosmosConstants { val defaultMetricsIntervalInSeconds = 60 val defaultSlf4jMetricReporterEnabled = false val readOperationEndToEndTimeoutInSeconds = 65 + val batchOperationEndToEndTimeoutInSeconds = 65 object Names { val ItemsDataSourceShortName = "cosmos.oltp" diff --git a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosWriterBase.scala b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosWriterBase.scala index bb92090d11b44..0472ba68d450e 100644 --- 
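Editorial note: the writer changes that follow consume the three new flush settings parsed earlier. A hedged sketch of a bulk write job that tunes them - the account coordinates are placeholders, and the interval values simply mirror the ones used by the new integration test rather than recommended settings:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().appName("cosmos-bulk-write-sketch").getOrCreate()
    import spark.implicits._

    // 100 simple documents with just an "id" column, mirroring the shape used by the new integration test
    val df = (1 to 100).map(i => s"record_$i").toDF("id")

    val writeCfg = Map(
      // account coordinates are placeholders
      "spark.cosmos.accountEndpoint" -> "https://<account>.documents.azure.com:443/",
      "spark.cosmos.accountKey" -> "<account-key>",
      "spark.cosmos.database" -> "myDatabase",
      "spark.cosmos.container" -> "myContainer",
      // check every 10 seconds whether the flush during commit is still making progress
      "spark.cosmos.write.flush.intervalInSeconds" -> "10",
      // tolerate at most 30 seconds without progress on the first commit attempt
      "spark.cosmos.write.flush.noProgress.maxIntervalInSeconds" -> "30",
      // tolerate at most 300 seconds without progress once the commit is retried on a new BulkWriter
      "spark.cosmos.write.flush.noProgress.maxRetryIntervalInSeconds" -> "300"
    )

    df.write.format("cosmos.oltp").mode("Append").options(writeCfg).save()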
a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosWriterBase.scala +++ b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/CosmosWriterBase.scala @@ -5,13 +5,14 @@ package com.azure.cosmos.spark import com.azure.cosmos.SparkBridgeInternal import com.azure.cosmos.spark.diagnostics.LoggerHelper +import com.fasterxml.jackson.databind.node.ObjectNode import org.apache.spark.TaskContext import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage} import org.apache.spark.sql.types.StructType -import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.atomic.{AtomicInteger, AtomicReference} private abstract class CosmosWriterBase( userConfig: Map[String, String], @@ -57,23 +58,26 @@ private abstract class CosmosWriterBase( private val containerDefinition = SparkBridgeInternal .getContainerPropertiesFromCollectionCache(container) private val partitionKeyDefinition = containerDefinition.getPartitionKeyDefinition - - private val writer = if (cosmosWriteConfig.bulkEnabled) { - new BulkWriter( - container, - partitionKeyDefinition, - cosmosWriteConfig, - diagnosticsConfig, - getOutputMetricsPublisher()) - } else { - new PointWriter( - container, - partitionKeyDefinition, - cosmosWriteConfig, - diagnosticsConfig, - TaskContext.get(), - getOutputMetricsPublisher()) - } + private val commitAttempt = new AtomicInteger(1) + + private val writer: AtomicReference[AsyncItemWriter] = new AtomicReference( + if (cosmosWriteConfig.bulkEnabled) { + new BulkWriter( + container, + partitionKeyDefinition, + cosmosWriteConfig, + diagnosticsConfig, + getOutputMetricsPublisher(), + commitAttempt.getAndIncrement()) + } else { + new PointWriter( + container, + partitionKeyDefinition, + cosmosWriteConfig, + diagnosticsConfig, + TaskContext.get(), + getOutputMetricsPublisher()) + }) override def write(internalRow: InternalRow): Unit = { val objectNode = cosmosRowConverter.fromInternalRowToObjectNode(internalRow, inputSchema) @@ -91,27 +95,79 @@ private abstract class CosmosWriterBase( } val partitionKeyValue = PartitionKeyHelper.getPartitionKeyPath(objectNode, partitionKeyDefinition) - writer.scheduleWrite(partitionKeyValue, objectNode) + writer.get.scheduleWrite(partitionKeyValue, objectNode) } override def commit(): WriterCommitMessage = { log.logInfo("commit invoked!!!") - writer.flushAndClose() - + flushAndCloseWriterWithRetries("committing") new WriterCommitMessage {} } override def abort(): Unit = { log.logInfo("abort invoked!!!") - writer.abort() - if (cacheItemReleasedCount.incrementAndGet() == 1) { - clientCacheItem.close() + try { + writer.get.abort(true) + } finally { + closeClients() } } override def close(): Unit = { log.logInfo("close invoked!!!") - writer.flushAndClose() + try { + flushAndCloseWriterWithRetries("closing") + } finally { + closeClients() + } + } + + private def flushAndCloseWriterWithRetries(operationName: String) = { + try { + writer.get.flushAndClose() + } catch { + case bulkWriterStaleError: BulkWriterNoProgressException => + bulkWriterStaleError.activeBulkWriteOperations match { + case Some(remainingWriteOperations) => + log.logWarning(s"Error indicating stuck writer when $operationName write job. 
Retry will be attempted for " + + s"the outstanding ${remainingWriteOperations.size} write operations.", bulkWriterStaleError) + + val bulkWriterForRetry = + new BulkWriter( + container, + partitionKeyDefinition, + cosmosWriteConfig, + diagnosticsConfig, + getOutputMetricsPublisher(), + commitAttempt.getAndIncrement()) + val oldBulkWriter = writer.getAndSet(bulkWriterForRetry) + + cosmosWriteConfig.retryCommitInterceptor match { + case Some(onRetryCommitInterceptor) => + log.logInfo("Invoking custom on-retry-commit interceptor...") + onRetryCommitInterceptor.beforeRetryCommit() + case None => + } + + for (operation <- remainingWriteOperations) { + bulkWriterForRetry.scheduleWrite(operation.getPartitionKeyValue, operation.getItem[ObjectNode]) + } + oldBulkWriter.abort(false) + bulkWriterForRetry.flushAndClose() + // None means not just write operations but also read-many are outstanding we can't retry + case None => + log.logError(s"Error indicating stuck writer when $operationName write job. No retry possible because " + + "of outstanding read-many operations.", bulkWriterStaleError) + + throw bulkWriterStaleError + } + case e: Throwable => + log.logError(s"Unexpected error when $operationName write job.", e) + throw e + } + } + + private def closeClients() = { if (cacheItemReleasedCount.incrementAndGet() == 1) { clientCacheItem.close() if (throughputControlClientCacheItemOpt.isDefined) { diff --git a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/PointWriter.scala b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/PointWriter.scala index 6acf3678a1c40..55be01c0481fa 100644 --- a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/PointWriter.scala +++ b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/PointWriter.scala @@ -679,11 +679,13 @@ private class PointWriter(container: CosmosAsyncContainer, * Don't wait for any remaining work but signal to the writer the ungraceful close * Should not throw any exceptions */ - override def abort(): Unit = { - // signal an exception that will be thrown for any pending work/flushAndClose if no other exception has - // been registered - captureIfFirstFailure( - new IllegalStateException(s"The Spark task was aborted, Context: ${taskDiagnosticsContext.toString}")) + override def abort(shouldThrow: Boolean): Unit = { + if (shouldThrow) { + // signal an exception that will be thrown for any pending work/flushAndClose if no other exception has + // been registered + captureIfFirstFailure( + new IllegalStateException(s"The Spark task was aborted, Context: ${taskDiagnosticsContext.toString}")) + } closed.set(true) diff --git a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/WriteOnRetryCommitInterceptor.scala b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/WriteOnRetryCommitInterceptor.scala new file mode 100644 index 0000000000000..91b7d7644b7b0 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/main/scala/com/azure/cosmos/spark/WriteOnRetryCommitInterceptor.scala @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
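Editorial note: the retry path above invokes an optional beforeRetryCommit() hook before re-scheduling the outstanding operations on the fresh BulkWriter; the trait defining that hook is introduced in the new file below. A hedged sketch of a user-provided implementation - class name and body are illustrative:

    package com.example.cosmos

    import com.azure.cosmos.spark.WriteOnRetryCommitInterceptor

    // Illustrative implementation: emits a log line so operators can see that a bulk commit was retried.
    class LoggingRetryCommitInterceptor extends WriteOnRetryCommitInterceptor {
      override def beforeRetryCommit(): Unit = {
        println("Bulk write commit saw no progress; retrying on a new BulkWriter instance")
      }
    }

Like the other interceptors it would be registered via ServiceLoader (a META-INF/services/com.azure.cosmos.spark.WriteOnRetryCommitInterceptor entry) and enabled by setting spark.cosmos.write.onRetryCommitInterceptor to the fully qualified class name.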
+ +package com.azure.cosmos.spark + +/** + * The WriteOnRetryCommitInterceptor trait is used to allow spark environments to customize the behavior when a write + * commit is retried - for example to reset fault-injection state between commit attempts + */ +trait WriteOnRetryCommitInterceptor { + + /** + * This method will be invoked by the Cosmos DB Spark connector before retrying a commit during writes (currently + * only when bulk mode is enabled). + */ + def beforeRetryCommit(): Unit +} diff --git a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala index e9aca812f130a..27e04763be2b4 100644 --- a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala +++ b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/CosmosClientCacheITest.scala @@ -63,7 +63,8 @@ class CosmosClientCacheITest resourceGroupName = None, azureEnvironmentEndpoints = AzureEnvironment.AZURE.getEndpoints, sparkEnvironmentInfo = "", - clientBuilderInterceptors = None) + clientBuilderInterceptors = None, + clientInterceptors = None) ), ( "StandardCtorWithEmptyPreferredRegions", @@ -88,7 +89,8 @@ class CosmosClientCacheITest resourceGroupName = None, azureEnvironmentEndpoints = AzureEnvironment.AZURE.getEndpoints, sparkEnvironmentInfo = "", - clientBuilderInterceptors = None) + clientBuilderInterceptors = None, + clientInterceptors = None) ), ( "StandardCtorWithOnePreferredRegion", @@ -113,7 +115,8 @@ class CosmosClientCacheITest resourceGroupName = None, azureEnvironmentEndpoints = AzureEnvironment.AZURE.getEndpoints, sparkEnvironmentInfo = "", - clientBuilderInterceptors = None) + clientBuilderInterceptors = None, + clientInterceptors = None) ), ( "StandardCtorWithTwoPreferredRegions", @@ -138,7 +141,8 @@ class CosmosClientCacheITest resourceGroupName = None, azureEnvironmentEndpoints = AzureEnvironment.AZURE.getEndpoints, sparkEnvironmentInfo = "", - clientBuilderInterceptors = None) + clientBuilderInterceptors = None, + clientInterceptors = None) ) ) @@ -170,7 +174,8 @@ class CosmosClientCacheITest userConfig.resourceGroupName, userConfig.azureEnvironmentEndpoints, sparkEnvironmentInfo = "", - clientBuilderInterceptors = None + clientBuilderInterceptors = None, + clientInterceptors = None ) logInfo(s"TestCase: {$testCaseName}") @@ -185,9 +190,9 @@ class CosmosClientCacheITest Some(CosmosClientCache(userConfigShallowCopy, None, s"$testCaseName-CosmosClientCacheITest-02")) )) .to(clients2 => { - clients2(0).get.cosmosClient should be theSameInstanceAs clients(0).get.cosmosClient - clients2(0).get.sparkCatalogClient.isInstanceOf[CosmosCatalogCosmosSDKClient] should be - clients(0).get.sparkCatalogClient.isInstanceOf[CosmosCatalogCosmosSDKClient] should be + clients2.head.get.cosmosClient should be theSameInstanceAs clients.head.get.cosmosClient + clients2.head.get.sparkCatalogClient.isInstanceOf[CosmosCatalogCosmosSDKClient] should be + clients.head.get.sparkCatalogClient.isInstanceOf[CosmosCatalogCosmosSDKClient] should be val ownerInfo = CosmosClientCache.ownerInformation(userConfig) logInfo(s"$testCaseName-OwnerInfo $ownerInfo") @@ -218,7 +223,7 @@ class CosmosClientCacheITest )) .to(clients2 => { - clients2(0).get shouldNot be theSameInstanceAs clients(0).get + clients2.head.get shouldNot be theSameInstanceAs clients.head.get
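Editorial note: the test fixtures above now construct CosmosAccountConfig with the extra clientInterceptors field of type Option[List[CosmosAsyncClient => CosmosAsyncClient]], populated earlier in the comma-separated order of spark.cosmos.account.clientInterceptors. A minimal sketch of how such a list can be composed over a freshly built client - the helper below is illustrative and not the connector's actual wiring:

    import com.azure.cosmos.CosmosAsyncClient

    object ClientInterceptorComposition {
      // Applies each interceptor function in list order, threading the (possibly replaced) client through.
      def applyClientInterceptors(
          client: CosmosAsyncClient,
          interceptors: Option[List[CosmosAsyncClient => CosmosAsyncClient]]): CosmosAsyncClient =
        interceptors match {
          case Some(functions) => functions.foldLeft(client)((current, intercept) => intercept(current))
          case None => client
        }
    }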
CosmosClientCache.purge(userConfig) }) }) @@ -236,8 +241,8 @@ class CosmosClientCacheITest Some(CosmosClientCache(userConfig, Option(cosmosClientCacheSnapshot), "CosmosClientCacheITest-05")) )) .to(clients => { - clients(0).get shouldBe a[CosmosClientCacheItem] - clients(0).get.cosmosClient shouldBe a[CosmosAsyncClient] + clients.head.get shouldBe a[CosmosClientCacheItem] + clients.head.get.cosmosClient shouldBe a[CosmosAsyncClient] CosmosClientCache.purge(userConfig) }) } diff --git a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala index 71e650feae727..c004a9879ca35 100644 --- a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala +++ b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/CosmosPartitionPlannerSpec.scala @@ -36,7 +36,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { resourceGroupName = None, azureEnvironmentEndpoints = AzureEnvironment.AZURE.getEndpoints, sparkEnvironmentInfo = "", - clientBuilderInterceptors = None) + clientBuilderInterceptors = None, + clientInterceptors = None) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString) val normalizedRange = NormalizedRange(UUID.randomUUID().toString, UUID.randomUUID().toString) @@ -110,7 +111,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { resourceGroupName = None, azureEnvironmentEndpoints = AzureEnvironment.AZURE.getEndpoints, sparkEnvironmentInfo = "", - clientBuilderInterceptors = None) + clientBuilderInterceptors = None, + clientInterceptors = None) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString) val normalizedRange = NormalizedRange(UUID.randomUUID().toString, UUID.randomUUID().toString) @@ -184,7 +186,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { resourceGroupName = None, azureEnvironmentEndpoints = AzureEnvironment.AZURE.getEndpoints, sparkEnvironmentInfo = "", - clientBuilderInterceptors = None) + clientBuilderInterceptors = None, + clientInterceptors = None) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString) val normalizedRange = NormalizedRange(UUID.randomUUID().toString, UUID.randomUUID().toString) @@ -258,7 +261,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { resourceGroupName = None, azureEnvironmentEndpoints = AzureEnvironment.AZURE.getEndpoints, sparkEnvironmentInfo = "", - clientBuilderInterceptors = None) + clientBuilderInterceptors = None, + clientInterceptors = None) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString) val normalizedRange = NormalizedRange(UUID.randomUUID().toString, UUID.randomUUID().toString) @@ -330,7 +334,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { resourceGroupName = None, azureEnvironmentEndpoints = AzureEnvironment.AZURE.getEndpoints, sparkEnvironmentInfo = "", - clientBuilderInterceptors = None) + clientBuilderInterceptors = None, + clientInterceptors = None) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString) val normalizedRange = NormalizedRange(UUID.randomUUID().toString, UUID.randomUUID().toString) @@ -418,7 +423,8 @@ class CosmosPartitionPlannerSpec extends UnitSpec { resourceGroupName = None, azureEnvironmentEndpoints = 
AzureEnvironment.AZURE.getEndpoints, sparkEnvironmentInfo = "", - clientBuilderInterceptors = None) + clientBuilderInterceptors = None, + clientInterceptors = None) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString) val normalizedRange = NormalizedRange(UUID.randomUUID().toString, UUID.randomUUID().toString) diff --git a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala index 54c777d4a3836..ee67409974a0e 100644 --- a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala +++ b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/PartitionMetadataSpec.scala @@ -35,7 +35,8 @@ class PartitionMetadataSpec extends UnitSpec { resourceGroupName = None, azureEnvironmentEndpoints = AzureEnvironment.AZURE.getEndpoints, sparkEnvironmentInfo = "", - clientBuilderInterceptors = None + clientBuilderInterceptors = None, + clientInterceptors = None ) private[this] val contCfg = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString) @@ -80,7 +81,8 @@ class PartitionMetadataSpec extends UnitSpec { resourceGroupName = None, azureEnvironmentEndpoints = AzureEnvironment.AZURE.getEndpoints, sparkEnvironmentInfo = "", - clientBuilderInterceptors = None + clientBuilderInterceptors = None, + clientInterceptors = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString) @@ -164,7 +166,8 @@ class PartitionMetadataSpec extends UnitSpec { resourceGroupName = None, azureEnvironmentEndpoints = AzureEnvironment.AZURE.getEndpoints, sparkEnvironmentInfo = "", - clientBuilderInterceptors = None + clientBuilderInterceptors = None, + clientInterceptors = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString) @@ -248,7 +251,8 @@ class PartitionMetadataSpec extends UnitSpec { resourceGroupName = None, azureEnvironmentEndpoints = AzureEnvironment.AZURE.getEndpoints, sparkEnvironmentInfo = "", - clientBuilderInterceptors = None + clientBuilderInterceptors = None, + clientInterceptors = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString) @@ -314,7 +318,8 @@ class PartitionMetadataSpec extends UnitSpec { resourceGroupName = None, azureEnvironmentEndpoints = AzureEnvironment.AZURE.getEndpoints, sparkEnvironmentInfo = "", - clientBuilderInterceptors = None + clientBuilderInterceptors = None, + clientInterceptors = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString) @@ -375,7 +380,8 @@ class PartitionMetadataSpec extends UnitSpec { resourceGroupName = None, azureEnvironmentEndpoints = AzureEnvironment.AZURE.getEndpoints, sparkEnvironmentInfo = "", - clientBuilderInterceptors = None + clientBuilderInterceptors = None, + clientInterceptors = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString) @@ -430,7 +436,8 @@ class PartitionMetadataSpec extends UnitSpec { resourceGroupName = None, azureEnvironmentEndpoints = AzureEnvironment.AZURE.getEndpoints, sparkEnvironmentInfo = "", - clientBuilderInterceptors = None + clientBuilderInterceptors = None, + clientInterceptors = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString) @@ -485,7 
+492,8 @@ class PartitionMetadataSpec extends UnitSpec { resourceGroupName = None, azureEnvironmentEndpoints = AzureEnvironment.AZURE.getEndpoints, sparkEnvironmentInfo = "", - clientBuilderInterceptors = None + clientBuilderInterceptors = None, + clientInterceptors = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString) @@ -540,7 +548,8 @@ class PartitionMetadataSpec extends UnitSpec { resourceGroupName = None, azureEnvironmentEndpoints = AzureEnvironment.AZURE.getEndpoints, sparkEnvironmentInfo = "", - clientBuilderInterceptors = None + clientBuilderInterceptors = None, + clientInterceptors = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString) @@ -595,7 +604,8 @@ class PartitionMetadataSpec extends UnitSpec { resourceGroupName = None, azureEnvironmentEndpoints = AzureEnvironment.AZURE.getEndpoints, sparkEnvironmentInfo = "", - clientBuilderInterceptors = None + clientBuilderInterceptors = None, + clientInterceptors = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString) @@ -650,7 +660,8 @@ class PartitionMetadataSpec extends UnitSpec { resourceGroupName = None, azureEnvironmentEndpoints = AzureEnvironment.AZURE.getEndpoints, sparkEnvironmentInfo = "", - clientBuilderInterceptors = None + clientBuilderInterceptors = None, + clientInterceptors = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString) @@ -722,7 +733,8 @@ class PartitionMetadataSpec extends UnitSpec { resourceGroupName = None, azureEnvironmentEndpoints = AzureEnvironment.AZURE.getEndpoints, sparkEnvironmentInfo = "", - clientBuilderInterceptors = None + clientBuilderInterceptors = None, + clientInterceptors = None ) val containerConfig = CosmosContainerConfig(UUID.randomUUID().toString, UUID.randomUUID().toString) diff --git a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/PointWriterITest.scala b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/PointWriterITest.scala index 432c749be624a..ea3a72f3580d0 100644 --- a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/PointWriterITest.scala +++ b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/PointWriterITest.scala @@ -801,7 +801,8 @@ class PointWriterITest extends IntegrationSpec with CosmosClient with AutoCleana partitionKeyDefinition, writeConfig, DiagnosticsConfig(Option.empty, isClientTelemetryEnabled = false, None), - new TestOutputMetricsPublisher) + new TestOutputMetricsPublisher, + 1) // First create one item, as patch can only operate on existing items val itemWithFullSchema = CosmosPatchTestHelper.getPatchItemWithFullSchema(UUID.randomUUID().toString, strippedPartitionKeyPath) @@ -1318,7 +1319,8 @@ class PointWriterITest extends IntegrationSpec with CosmosClient with AutoCleana partitionKeyDefinition, writeConfig, DiagnosticsConfig(Option.empty, isClientTelemetryEnabled = false, None), - new TestOutputMetricsPublisher) + new TestOutputMetricsPublisher, + 1) // First create one item val itemWithFullSchema = CosmosPatchTestHelper.getPatchItemWithFullSchema(UUID.randomUUID().toString, strippedPartitionKeyPath) diff --git a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/PointWriterSubpartitionITest.scala b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/PointWriterSubpartitionITest.scala index 
f891d964c9860..2fa6c21233f02 100644 --- a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/PointWriterSubpartitionITest.scala +++ b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/PointWriterSubpartitionITest.scala @@ -4,7 +4,7 @@ package com.azure.cosmos.spark import com.azure.cosmos.implementation.apachecommons.lang.StringUtils -import com.azure.cosmos.models.{CosmosContainerProperties, PartitionKey, PartitionKeyBuilder, PartitionKeyDefinition, PartitionKeyDefinitionVersion, PartitionKind, ThroughputProperties} +import com.azure.cosmos.models.{CosmosContainerProperties, PartitionKeyBuilder, PartitionKeyDefinition, PartitionKeyDefinitionVersion, PartitionKind, ThroughputProperties} import com.azure.cosmos.spark.utils.{CosmosPatchTestHelper, TestOutputMetricsPublisher} import com.azure.cosmos.{CosmosAsyncContainer, CosmosException} import com.fasterxml.jackson.databind.ObjectMapper @@ -761,7 +761,7 @@ class PointWriterSubpartitionITest extends IntegrationSpec with CosmosClient wit bulkMaxPendingOperations = Some(900) ) - val bulkWriter = new BulkWriter(container, subpartitionKeyDefinition, writeConfig, DiagnosticsConfig(Option.empty, false, None),new TestOutputMetricsPublisher) + val bulkWriter = new BulkWriter(container, subpartitionKeyDefinition, writeConfig, DiagnosticsConfig(Option.empty, false, None),new TestOutputMetricsPublisher, 1) // First create one item, as patch can only operate on existing items val itemWithFullSchema = CosmosPatchTestHelper.getPatchItemWithFullSchemaSubpartitions(UUID.randomUUID().toString) @@ -1293,7 +1293,7 @@ class PointWriterSubpartitionITest extends IntegrationSpec with CosmosClient wit bulkMaxPendingOperations = Some(900) ) - val bulkWriter = new BulkWriter(container, subpartitionKeyDefinition, writeConfig, DiagnosticsConfig(Option.empty, false, None),new TestOutputMetricsPublisher) + val bulkWriter = new BulkWriter(container, subpartitionKeyDefinition, writeConfig, DiagnosticsConfig(Option.empty, false, None),new TestOutputMetricsPublisher, 1) // First create one item val itemWithFullSchema = CosmosPatchTestHelper.getPatchItemWithFullSchemaSubpartitions(UUID.randomUUID().toString) diff --git a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/SparkE2EBulkWriteITest.scala b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/SparkE2EBulkWriteITest.scala new file mode 100644 index 0000000000000..0ad43de74e305 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/SparkE2EBulkWriteITest.scala @@ -0,0 +1,207 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+package com.azure.cosmos.spark + +import com.azure.core.util.Context +import com.azure.cosmos.{CosmosDiagnosticsContext, CosmosDiagnosticsHandler, CosmosDiagnosticsThresholds} +import com.azure.cosmos.implementation.TestConfigurations +import com.azure.cosmos.models.{CosmosClientTelemetryConfig, FeedRange, PartitionKey, ShowQueryMode} +import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait +import com.azure.cosmos.test.faultinjection.{CosmosFaultInjectionHelper, FaultInjectionConditionBuilder, FaultInjectionConnectionType, FaultInjectionEndpointBuilder, FaultInjectionOperationType, FaultInjectionResultBuilders, FaultInjectionRule, FaultInjectionRuleBuilder, FaultInjectionServerErrorType} +import org.apache.spark.scheduler.{AccumulableInfo, SparkListener, SparkListenerTaskEnd} +import org.scalatest.concurrent.Eventually.eventually +import org.scalatest.concurrent.Waiters.{interval, timeout} +import org.scalatest.time.SpanSugar.convertIntToGrainOfTime + +// scalastyle:off underscore.import +import scala.collection.JavaConverters._ +// scalastyle:on underscore.import + +import java.time.Duration + +class SparkE2EBulkWriteITest + extends IntegrationSpec + with SparkWithJustDropwizardAndNoSlf4jMetrics + with CosmosClient + with AutoCleanableCosmosContainer + with BasicLoggingTrait + with MetricAssertions { + + //scalastyle:off multiple.string.literals + //scalastyle:off magic.number + //scalastyle:off null + + it should s"support bulk ingestion when BulkWriter needs to get restarted" in { + val cosmosEndpoint = TestConfigurations.HOST + val cosmosMasterKey = TestConfigurations.MASTER_KEY + + val configMapBuilder = scala.collection.mutable.Map( + "spark.cosmos.accountEndpoint" -> cosmosEndpoint, + "spark.cosmos.accountKey" -> cosmosMasterKey, + "spark.cosmos.database" -> cosmosDatabase, + "spark.cosmos.container" -> cosmosContainer, + "spark.cosmos.serialization.inclusionMode" -> "NonDefault" + ) + + var faultInjectionRuleOption : Option[FaultInjectionRule] = None + + try { + // set-up logging + val logs = scala.collection.mutable.ListBuffer[CosmosDiagnosticsContext]() + + configMapBuilder += "spark.cosmos.account.clientBuilderInterceptors" -> "com.azure.cosmos.spark.TestCosmosClientBuilderInterceptor" + TestCosmosClientBuilderInterceptor.setCallback(builder => { + val thresholds = new CosmosDiagnosticsThresholds() + .setPointOperationLatencyThreshold(Duration.ZERO) + .setNonPointOperationLatencyThreshold(Duration.ZERO) + val telemetryCfg = new CosmosClientTelemetryConfig() + .showQueryMode(ShowQueryMode.ALL) + .diagnosticsHandler(new CompositeLoggingHandler(logs)) + .diagnosticsThresholds(thresholds) + builder.clientTelemetryConfig(telemetryCfg) + }) + + // set-up fault injection + configMapBuilder += "spark.cosmos.account.clientInterceptors" -> "com.azure.cosmos.spark.TestFaultInjectionClientInterceptor" + configMapBuilder += "spark.cosmos.write.flush.intervalInSeconds" -> "10" + configMapBuilder += "spark.cosmos.write.flush.noProgress.maxIntervalInSeconds" -> "30" + configMapBuilder += "spark.cosmos.write.flush.noProgress.maxRetryIntervalInSeconds" -> "300" + configMapBuilder += "spark.cosmos.write.onRetryCommitInterceptor" -> "com.azure.cosmos.spark.TestWriteOnRetryCommitInterceptor" + TestFaultInjectionClientInterceptor.setCallback(client => { + val faultInjectionResultBuilder = FaultInjectionResultBuilders + .getResultBuilder(FaultInjectionServerErrorType.RESPONSE_DELAY) + .delay(Duration.ofHours(10000)) + .times(1) + + val endpoints = new FaultInjectionEndpointBuilder( + 
FeedRange.forLogicalPartition(new PartitionKey("range_1"))) + .build() + + val result = faultInjectionResultBuilder.build + val condition = new FaultInjectionConditionBuilder() + .operationType(FaultInjectionOperationType.BATCH_ITEM) + .connectionType(FaultInjectionConnectionType.DIRECT) + .endpoints(endpoints) + .build + + faultInjectionRuleOption = Some(new FaultInjectionRuleBuilder("InjectedEndlessResponseDelay") + .condition(condition) + .result(result) + .build) + + TestWriteOnRetryCommitInterceptor.setCallback(() => faultInjectionRuleOption.get.disable()) + + CosmosFaultInjectionHelper.configureFaultInjectionRules( + client.getDatabase(cosmosDatabase).getContainer(cosmosContainer), + List(faultInjectionRuleOption.get).asJava).block + + client + }) + + val cfg = configMapBuilder.toMap + + val newSpark = getSpark + + // scalastyle:off underscore.import + // scalastyle:off import.grouping + import spark.implicits._ + val spark = newSpark + // scalastyle:on underscore.import + // scalastyle:on import.grouping + + val toBeIngested = scala.collection.mutable.ListBuffer[String]() + for (i <- 1 to 100) { + toBeIngested += s"record_$i" + } + + val df = toBeIngested.toDF("id") + + var bytesWrittenSnapshot = 0L + var recordsWrittenSnapshot = 0L + var totalRequestChargeSnapshot: Option[AccumulableInfo] = None + + val statusStore = spark.sharedState.statusStore + val oldCount = statusStore.executionsCount() + + spark.sparkContext + .addSparkListener( + new SparkListener { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + val outputMetrics = taskEnd.taskMetrics.outputMetrics + logInfo(s"ON_TASK_END - Records written: ${outputMetrics.recordsWritten}, " + + s"Bytes written: ${outputMetrics.bytesWritten}, " + + s"${taskEnd.taskInfo.accumulables.mkString(", ")}") + bytesWrittenSnapshot = outputMetrics.bytesWritten + + recordsWrittenSnapshot = outputMetrics.recordsWritten + + taskEnd + .taskInfo + .accumulables + .filter(accumulableInfo => accumulableInfo.name.isDefined && + accumulableInfo.name.get.equals(CosmosConstants.MetricNames.TotalRequestCharge)) + .foreach( + accumulableInfo => { + totalRequestChargeSnapshot = Some(accumulableInfo) + } + ) + } + }) + + df.write.format("cosmos.oltp").mode("Append").options(cfg).save() + + // Wait until the new execution is started and being tracked. + eventually(timeout(10.seconds), interval(10.milliseconds)) { + assert(statusStore.executionsCount() > oldCount) + } + + // Wait for listener to finish computing the metrics for the execution. + eventually(timeout(10.seconds), interval(10.milliseconds)) { + assert(statusStore.executionsList().nonEmpty && + statusStore.executionsList().last.metricValues != null) + } + + recordsWrittenSnapshot shouldEqual 100 + bytesWrittenSnapshot > 0 shouldEqual true + + // that the write by spark is visible by the client query + // wait for a second to allow replication is completed. 
+ Thread.sleep(1000) + + // the new item will be always persisted + val ids = queryItems("SELECT c.id FROM c ORDER by c.id").toArray + ids should have size 100 + val firstDoc = ids(0) + firstDoc.get("id").asText() shouldEqual "record_1" + + // validate logs + logs.nonEmpty shouldEqual true + } finally { + TestCosmosClientBuilderInterceptor.resetCallback() + TestFaultInjectionClientInterceptor.resetCallback() + faultInjectionRuleOption match { + case Some(rule) => rule.disable() + case None => + } + } + } + + class CompositeLoggingHandler(logs: scala.collection.mutable.ListBuffer[CosmosDiagnosticsContext]) extends CosmosDiagnosticsHandler { + /** + * This method will be invoked when an operation completed (successfully or failed) to allow diagnostic handlers to + * emit the diagnostics NOTE: Any code in handleDiagnostics should not execute any I/O operations, do thread + * switches or execute CPU intense work - if needed a diagnostics handler should queue this work asynchronously. The + * method handleDiagnostics will be invoked on the hot path - so, any inefficient diagnostics handler will increase + * end-to-end latency perceived by the application + * + * @param diagnosticsContext the Cosmos DB diagnostic context with metadata for the operation + * @param traceContext the Azure trace context + */ + override def handleDiagnostics(diagnosticsContext: CosmosDiagnosticsContext, traceContext: Context): Unit = { + logs += diagnosticsContext + + CosmosDiagnosticsHandler.DEFAULT_LOGGING_HANDLER.handleDiagnostics(diagnosticsContext, traceContext) + } + } +} + diff --git a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/TestCosmosClientBuilderInterceptor.scala b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/TestCosmosClientBuilderInterceptor.scala index a30a8e70b4109..bdb914f90861c 100644 --- a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/TestCosmosClientBuilderInterceptor.scala +++ b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/TestCosmosClientBuilderInterceptor.scala @@ -5,8 +5,8 @@ package com.azure.cosmos.spark import com.azure.cosmos.CosmosClientBuilder class TestCosmosClientBuilderInterceptor extends CosmosClientBuilderInterceptor { - override def process(cosmosClientBuilder: CosmosClientBuilder): CosmosClientBuilder = { - TestCosmosClientBuilderInterceptor.callback(cosmosClientBuilder) + override def getClientBuilderInterceptor(configs: Map[String, String]): Option[CosmosClientBuilder => CosmosClientBuilder] = { + Some(TestCosmosClientBuilderInterceptor.callback) } } diff --git a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/TestFaultInjectionClientInterceptor.scala b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/TestFaultInjectionClientInterceptor.scala new file mode 100644 index 0000000000000..9356bb823412b --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/TestFaultInjectionClientInterceptor.scala @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
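Editorial note: the updated TestCosmosClientBuilderInterceptor above shows the new Option-returning contract, which lets an implementation opt out per client. A hedged sketch of a config-driven builder interceptor using that contract - the class name and the marker option are made up (the marker deliberately sits outside the spark.cosmos prefix so it is not rejected by config-name validation), and userAgentSuffix is an existing CosmosClientBuilder setter:

    package com.example.cosmos

    import com.azure.cosmos.CosmosClientBuilder
    import com.azure.cosmos.spark.CosmosClientBuilderInterceptor

    class ConditionalBuilderInterceptor extends CosmosClientBuilderInterceptor {
      override def getClientBuilderInterceptor(configs: Map[String, String]): Option[CosmosClientBuilder => CosmosClientBuilder] = {
        // only tag clients when the (hypothetical) marker option is present; returning None skips this interceptor
        if (configs.keys.exists(_.equalsIgnoreCase("myapp.tagCosmosClients"))) {
          Some((builder: CosmosClientBuilder) => builder.userAgentSuffix("myapp-tagged"))
        } else {
          None
        }
      }
    }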
+package com.azure.cosmos.spark + +import com.azure.cosmos.CosmosAsyncClient + +class TestFaultInjectionClientInterceptor extends CosmosClientInterceptor { + override def getClientInterceptor(configs: Map[String, String]): Option[CosmosAsyncClient => CosmosAsyncClient] = { + Some(TestFaultInjectionClientInterceptor.callback) + } +} + +private[spark] object TestFaultInjectionClientInterceptor { + val defaultImplementation: CosmosAsyncClient => CosmosAsyncClient = client => client + var callback: CosmosAsyncClient => CosmosAsyncClient = defaultImplementation + def setCallback(interceptorCallback: CosmosAsyncClient => CosmosAsyncClient): Unit = { + callback = interceptorCallback + } + + def resetCallback(): Unit = { + callback = defaultImplementation + } +} diff --git a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/TestWriteOnRetryCommitInterceptor.scala b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/TestWriteOnRetryCommitInterceptor.scala new file mode 100644 index 0000000000000..096aa570eab29 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/TestWriteOnRetryCommitInterceptor.scala @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.cosmos.spark + +class TestWriteOnRetryCommitInterceptor extends WriteOnRetryCommitInterceptor { + /** + * This method will be invoked by the Cosmos DB Spark connector before retrying a commit during writes (currently + * only when bulk mode is enabled). + * + * @return A function that will be invoked before retrying a commit on the write path. + */ + override def beforeRetryCommit(): Unit = + { + TestWriteOnRetryCommitInterceptor.callback.apply() + } +} + +private[spark] object TestWriteOnRetryCommitInterceptor { + val defaultImplementation: () => Unit = () => {} + var callback: () => Unit = defaultImplementation + def setCallback(interceptorCallback: () => Unit): Unit = { + callback = interceptorCallback + } + + def resetCallback(): Unit = { + callback = defaultImplementation + } +} diff --git a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/utils/CosmosPatchTestHelper.scala b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/utils/CosmosPatchTestHelper.scala index 1f0d831df932e..b4ccc65a66a61 100644 --- a/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/utils/CosmosPatchTestHelper.scala +++ b/sdk/cosmos/azure-cosmos-spark_3_2-12/src/test/scala/com/azure/cosmos/spark/utils/CosmosPatchTestHelper.scala @@ -177,7 +177,8 @@ def getPatchFullTestSchemaWithSubpartitions(): StructType = { partitionKeyDefinition, writeConfigForPatch, DiagnosticsConfig(Option.empty, isClientTelemetryEnabled = false, None), - metricsPublisher) + metricsPublisher, + 1) } def getBulkWriterForPatchBulkUpdate(columnConfigsMap: TrieMap[String, CosmosPatchColumnConfig], @@ -196,7 +197,8 @@ def getPatchFullTestSchemaWithSubpartitions(): StructType = { partitionKeyDefinition, writeConfigForPatch, DiagnosticsConfig(Option.empty, isClientTelemetryEnabled = false, None), - new TestOutputMetricsPublisher) + new TestOutputMetricsPublisher, + 1) } def getPointWriterForPatch(columnConfigsMap: TrieMap[String, CosmosPatchColumnConfig], diff --git a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CosmosBulkAsyncTest.java b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CosmosBulkAsyncTest.java index 
caeb63d7d4169..e7d7ea67c0cef 100644 --- a/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CosmosBulkAsyncTest.java +++ b/sdk/cosmos/azure-cosmos-tests/src/test/java/com/azure/cosmos/CosmosBulkAsyncTest.java @@ -289,8 +289,8 @@ public void createItem_withBulk_and_operationLevelContext() { CosmosBulkExecutionOptions bulkExecutionOptions = new CosmosBulkExecutionOptions(); ImplementationBridgeHelpers.CosmosBulkExecutionOptionsHelper .getCosmosBulkExecutionOptionsAccessor() + .getImpl(bulkExecutionOptions) .setTargetedMicroBatchRetryRate( - bulkExecutionOptions, 0.25, 0.5); diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/CosmosBulkExecutionOptionsImpl.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/CosmosBulkExecutionOptionsImpl.java index 1bba5bd5a8df0..31c0ea08cbcfa 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/CosmosBulkExecutionOptionsImpl.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/CosmosBulkExecutionOptionsImpl.java @@ -53,6 +53,8 @@ public class CosmosBulkExecutionOptionsImpl implements OverridableRequestOptions private Set keywordIdentifiers; private Scheduler schedulerOverride = null; + private CosmosEndToEndOperationLatencyPolicyConfig e2ePolicy = null; + public CosmosBulkExecutionOptionsImpl(CosmosBulkExecutionOptionsImpl toBeCloned) { this.schedulerOverride = toBeCloned.schedulerOverride; this.initialMicroBatchSize = toBeCloned.initialMicroBatchSize; @@ -69,6 +71,7 @@ public CosmosBulkExecutionOptionsImpl(CosmosBulkExecutionOptionsImpl toBeCloned) this.diagnosticsTracker = toBeCloned.diagnosticsTracker; this.customSerializer = toBeCloned.customSerializer; this.customOptions = toBeCloned.customOptions; + this.e2ePolicy = toBeCloned.e2ePolicy; if (toBeCloned.excludeRegions != null) { this.excludeRegions = new ArrayList<>(toBeCloned.excludeRegions); @@ -278,7 +281,11 @@ public void setExcludedRegions(List excludeRegions) { @Override public CosmosEndToEndOperationLatencyPolicyConfig getCosmosEndToEndLatencyPolicyConfig() { - return null; + return this.e2ePolicy; + } + + public void setCosmosEndToEndLatencyPolicyConfig(CosmosEndToEndOperationLatencyPolicyConfig cfg) { + this.e2ePolicy = cfg; } @Override diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ImplementationBridgeHelpers.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ImplementationBridgeHelpers.java index a8624059aa20d..7b3115def5c4b 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ImplementationBridgeHelpers.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/ImplementationBridgeHelpers.java @@ -465,54 +465,6 @@ public static CosmosBulkExecutionOptionsAccessor getCosmosBulkExecutionOptionsAc } public interface CosmosBulkExecutionOptionsAccessor { - - void setOperationContext(CosmosBulkExecutionOptions options, - OperationContextAndListenerTuple operationContextAndListenerTuple); - - OperationContextAndListenerTuple getOperationContext(CosmosBulkExecutionOptions options); - - Duration getMaxMicroBatchInterval(CosmosBulkExecutionOptions options); - - CosmosBulkExecutionOptions setTargetedMicroBatchRetryRate( - CosmosBulkExecutionOptions options, - double minRetryRate, - double maxRetryRate); - - @SuppressWarnings({"unchecked"}) - T getLegacyBatchScopedContext(CosmosBulkExecutionOptions options); - - double 
getMinTargetedMicroBatchRetryRate(CosmosBulkExecutionOptions options); - - double getMaxTargetedMicroBatchRetryRate(CosmosBulkExecutionOptions options); - - int getMaxMicroBatchPayloadSizeInBytes(CosmosBulkExecutionOptions options); - - CosmosBulkExecutionOptions setMaxMicroBatchPayloadSizeInBytes( - CosmosBulkExecutionOptions options, - int maxMicroBatchPayloadSizeInBytes); - - int getMaxMicroBatchConcurrency(CosmosBulkExecutionOptions options); - - Integer getMaxConcurrentCosmosPartitions(CosmosBulkExecutionOptions options); - - CosmosBulkExecutionOptions setMaxConcurrentCosmosPartitions( - CosmosBulkExecutionOptions options, int mxConcurrentCosmosPartitions); - - CosmosBulkExecutionOptions setHeader(CosmosBulkExecutionOptions cosmosBulkExecutionOptions, - String name, String value); - - Map getHeader(CosmosBulkExecutionOptions cosmosBulkExecutionOptions); - - Map getCustomOptions(CosmosBulkExecutionOptions cosmosBulkExecutionOptions); - - int getMaxMicroBatchSize(CosmosBulkExecutionOptions cosmosBulkExecutionOptions); - - void setDiagnosticsTracker(CosmosBulkExecutionOptions cosmosBulkExecutionOptions, BulkExecutorDiagnosticsTracker tracker); - - BulkExecutorDiagnosticsTracker getDiagnosticsTracker(CosmosBulkExecutionOptions cosmosBulkExecutionOptions); - - CosmosBulkExecutionOptions setSchedulerOverride(CosmosBulkExecutionOptions cosmosBulkExecutionOptions, Scheduler customScheduler); - CosmosBulkExecutionOptions clone(CosmosBulkExecutionOptions toBeCloned); CosmosBulkExecutionOptionsImpl getImpl(CosmosBulkExecutionOptions options); } diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java index c1e38014c8b05..93e4f00af18c1 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/RxDocumentClientImpl.java @@ -4493,9 +4493,25 @@ public Mono executeBatchRequest(String collectionLink, boolean disableAutomaticIdGeneration) { DocumentClientRetryPolicy documentClientRetryPolicy = this.resetSessionTokenRetryPolicy.getRequestPolicy(null); AtomicReference requestReference = new AtomicReference<>(); - return handleCircuitBreakingFeedbackForPointOperation(ObservableHelper - .inlineIfPossibleAsObs(() -> executeBatchRequestInternal( - collectionLink, serverBatchRequest, options, documentClientRetryPolicy, disableAutomaticIdGeneration, requestReference), documentClientRetryPolicy), requestReference); + RequestOptions nonNullRequestOptions = options != null ? 
options : new RequestOptions(); + CosmosEndToEndOperationLatencyPolicyConfig endToEndPolicyConfig = + getEndToEndOperationLatencyPolicyConfig(nonNullRequestOptions, ResourceType.Document, OperationType.Batch); + ScopedDiagnosticsFactory scopedDiagnosticsFactory = new ScopedDiagnosticsFactory(this, false); + return handleCircuitBreakingFeedbackForPointOperation( + getPointOperationResponseMonoWithE2ETimeout( + nonNullRequestOptions, + endToEndPolicyConfig, + ObservableHelper + .inlineIfPossibleAsObs(() -> executeBatchRequestInternal( + collectionLink, + serverBatchRequest, + options, + documentClientRetryPolicy, + disableAutomaticIdGeneration, + requestReference), documentClientRetryPolicy), + scopedDiagnosticsFactory + ), + requestReference); } private Mono executeStoredProcedureInternal(String storedProcedureLink, diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/batch/BulkExecutor.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/batch/BulkExecutor.java index 2b65a04743ea6..78ad42db6fe36 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/batch/BulkExecutor.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/implementation/batch/BulkExecutor.java @@ -7,6 +7,7 @@ import com.azure.cosmos.CosmosAsyncClient; import com.azure.cosmos.CosmosAsyncContainer; import com.azure.cosmos.CosmosBridgeInternal; +import com.azure.cosmos.CosmosEndToEndOperationLatencyPolicyConfig; import com.azure.cosmos.CosmosException; import com.azure.cosmos.CosmosItemSerializer; import com.azure.cosmos.ThrottlingRetryOptions; @@ -46,6 +47,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.UUID; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.TimeUnit; @@ -173,7 +175,9 @@ public BulkExecutor(CosmosAsyncContainer container, TimeUnit.MILLISECONDS)); Scheduler schedulerSnapshotFromOptions = cosmosBulkOptions.getSchedulerOverride(); - this.executionScheduler = schedulerSnapshotFromOptions != null ? schedulerSnapshotFromOptions : CosmosSchedulers.BULK_EXECUTOR_BOUNDED_ELASTIC; + this.executionScheduler = schedulerSnapshotFromOptions != null + ? schedulerSnapshotFromOptions + : CosmosSchedulers.BULK_EXECUTOR_BOUNDED_ELASTIC; logger.debug("Instantiated BulkExecutor, Context: {}", this.operationContextText); @@ -232,6 +236,14 @@ private void logInfoOrWarning(String msg, Object... args) { } } + private void logTraceOrWarning(String msg, Object... args) { + if (this.diagnosticsTracker == null || !this.diagnosticsTracker.verboseLoggingAfterReEnqueueingRetriesEnabled()) { + logger.trace(msg, args); + } else { + logger.warn(msg, args); + } + } + private void logDebugOrWarning(String msg, Object... args) { if (this.diagnosticsTracker == null || !this.diagnosticsTracker.verboseLoggingAfterReEnqueueingRetriesEnabled()) { logger.debug(msg, args); @@ -315,7 +327,7 @@ private Flux> executeCore() { return this.inputOperations .publishOn(this.executionScheduler) .onErrorMap(throwable -> { - logger.error("{}: Skipping an error operation while processing. Cause: {}, Context: {}", + logger.warn("{}: Error observed when processing inputOperations. 
Cause: {}, Context: {}", getThreadInfo(), throwable.getMessage(), this.operationContextText, @@ -414,7 +426,7 @@ private Flux> executeCore() { this.operationContextText, getThreadInfo()); } - logger.trace( + logTraceOrWarning( "Work left - TotalCount after decrement: {}, main sink completed {}, {}, Context: {} {}", totalCountAfterDecrement, mainSourceCompletedSnapshot, @@ -584,10 +596,22 @@ private Flux> executePartitionKeyRangeServ FluxSink<CosmosItemOperation> groupSink, PartitionScopeThresholds thresholds) { + String batchTrackingId = UUID.randomUUID().toString(); + logTraceOrWarning( + "Executing batch of PKRangeId {} - batch TrackingId: {}", + serverRequest.getPartitionKeyRangeId(), + batchTrackingId); + return this.executeBatchRequest(serverRequest) .subscribeOn(this.executionScheduler) .flatMapMany(response -> { + logTraceOrWarning( + "Response for batch of PKRangeId {} - status code {}, ActivityId: {}, batch TrackingId {}", + serverRequest.getPartitionKeyRangeId(), + response.getStatusCode(), + response.getActivityId(), + batchTrackingId); if (diagnosticsTracker != null && response.getDiagnostics() != null) { diagnosticsTracker.trackDiagnostics(response.getDiagnostics().getDiagnosticsContext()); } @@ -820,6 +844,12 @@ private Mono executeBatchRequest(PartitionKeyRangeServerBat options.setExcludedRegions(cosmosBulkExecutionOptions.getExcludedRegions()); options.setKeywordIdentifiers(cosmosBulkExecutionOptions.getKeywordIdentifiers()); + CosmosEndToEndOperationLatencyPolicyConfig e2eLatencyPolicySnapshot = + cosmosBulkExecutionOptions.getCosmosEndToEndLatencyPolicyConfig(); + if (e2eLatencyPolicySnapshot != null) { + options.setCosmosEndToEndLatencyPolicyConfig(e2eLatencyPolicySnapshot); + } + // This logic is to handle custom bulk options which can be passed through encryption or through some other project Map<String, String> customOptions = cosmosBulkExecutionOptions.getHeaders(); if (customOptions != null && !customOptions.isEmpty()) { diff --git a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java index ec749d4265d33..dc866924cbef1 100644 --- a/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java +++ b/sdk/cosmos/azure-cosmos/src/main/java/com/azure/cosmos/models/CosmosBulkExecutionOptions.java @@ -3,6 +3,7 @@ package com.azure.cosmos.models; +import com.azure.cosmos.CosmosEndToEndOperationLatencyPolicyConfig; import com.azure.cosmos.CosmosItemSerializer; import com.azure.cosmos.implementation.CosmosBulkExecutionOptionsImpl; import com.azure.cosmos.implementation.ImplementationBridgeHelpers; @@ -359,124 +360,23 @@ CosmosBulkExecutionOptions setSchedulerOverride(Scheduler customScheduler) { return this; } + CosmosEndToEndOperationLatencyPolicyConfig getEndToEndOperationLatencyPolicyConfig() { + return this.actualRequestOptions.getCosmosEndToEndLatencyPolicyConfig(); + } + + CosmosBulkExecutionOptions setEndToEndOperationLatencyPolicyConfig(CosmosEndToEndOperationLatencyPolicyConfig cfg) { + this.actualRequestOptions.setCosmosEndToEndLatencyPolicyConfig(cfg); + + return this; + } + + /////////////////////////////////////////////////////////////////////////////////////////// // the following helper/accessor only helps to access this class outside of this package.// /////////////////////////////////////////////////////////////////////////////////////////// static void initialize() { 
ImplementationBridgeHelpers.CosmosBulkExecutionOptionsHelper.setCosmosBulkExecutionOptionsAccessor( new ImplementationBridgeHelpers.CosmosBulkExecutionOptionsHelper.CosmosBulkExecutionOptionsAccessor() { - - @Override - public void setOperationContext(CosmosBulkExecutionOptions options, - OperationContextAndListenerTuple operationContextAndListenerTuple) { - options.setOperationContextAndListenerTuple(operationContextAndListenerTuple); - } - - @Override - public OperationContextAndListenerTuple getOperationContext(CosmosBulkExecutionOptions options) { - return options.getOperationContextAndListenerTuple(); - } - - @Override - @SuppressWarnings({"unchecked"}) - public T getLegacyBatchScopedContext(CosmosBulkExecutionOptions options) { - return (T)options.getLegacyBatchScopedContext(); - } - - @Override - public double getMinTargetedMicroBatchRetryRate(CosmosBulkExecutionOptions options) { - return options.getMinTargetedMicroBatchRetryRate(); - } - - @Override - public double getMaxTargetedMicroBatchRetryRate(CosmosBulkExecutionOptions options) { - return options.getMaxTargetedMicroBatchRetryRate(); - } - - @Override - public int getMaxMicroBatchPayloadSizeInBytes(CosmosBulkExecutionOptions options) { - return options.getMaxMicroBatchPayloadSizeInBytes(); - } - - @Override - public CosmosBulkExecutionOptions setMaxMicroBatchPayloadSizeInBytes( - CosmosBulkExecutionOptions options, - int maxMicroBatchPayloadSizeInBytes) { - - return options.setMaxMicroBatchPayloadSizeInBytes(maxMicroBatchPayloadSizeInBytes); - } - - @Override - public int getMaxMicroBatchConcurrency(CosmosBulkExecutionOptions options) { - return options.getMaxMicroBatchConcurrency(); - } - - @Override - public Integer getMaxConcurrentCosmosPartitions(CosmosBulkExecutionOptions options) { - return options.getMaxConcurrentCosmosPartitions(); - } - - @Override - public CosmosBulkExecutionOptions setMaxConcurrentCosmosPartitions( - CosmosBulkExecutionOptions options, int maxConcurrentCosmosPartitions) { - return options.setMaxConcurrentCosmosPartitions(maxConcurrentCosmosPartitions); - } - - @Override - public Duration getMaxMicroBatchInterval(CosmosBulkExecutionOptions options) { - return options.getMaxMicroBatchInterval(); - } - - @Override - public CosmosBulkExecutionOptions setTargetedMicroBatchRetryRate( - CosmosBulkExecutionOptions options, - double minRetryRate, - double maxRetryRate) { - - return options.setTargetedMicroBatchRetryRate(minRetryRate, maxRetryRate); - } - - @Override - public CosmosBulkExecutionOptions setHeader(CosmosBulkExecutionOptions cosmosBulkExecutionOptions, - String name, String value) { - return cosmosBulkExecutionOptions.setHeader(name, value); - } - - @Override - public Map getHeader(CosmosBulkExecutionOptions cosmosBulkExecutionOptions) { - return cosmosBulkExecutionOptions.getHeaders(); - } - - @Override - public Map getCustomOptions(CosmosBulkExecutionOptions cosmosBulkExecutionOptions) { - return cosmosBulkExecutionOptions.getHeaders(); - } - - @Override - public int getMaxMicroBatchSize(CosmosBulkExecutionOptions cosmosBulkExecutionOptions) { - if (cosmosBulkExecutionOptions == null) { - return BatchRequestResponseConstants.MAX_OPERATIONS_IN_DIRECT_MODE_BATCH_REQUEST; - } - - return cosmosBulkExecutionOptions.getMaxMicroBatchSize(); - } - - @Override - public void setDiagnosticsTracker(CosmosBulkExecutionOptions cosmosBulkExecutionOptions, BulkExecutorDiagnosticsTracker tracker) { - cosmosBulkExecutionOptions.setDiagnosticsTracker(tracker); - } - - @Override - public 
BulkExecutorDiagnosticsTracker getDiagnosticsTracker(CosmosBulkExecutionOptions cosmosBulkExecutionOptions) { - return cosmosBulkExecutionOptions.getDiagnosticsTracker(); - } - - @Override - public CosmosBulkExecutionOptions setSchedulerOverride(CosmosBulkExecutionOptions cosmosBulkExecutionOptions, Scheduler customScheduler) { - return cosmosBulkExecutionOptions.setSchedulerOverride(customScheduler); - } - - @Override public CosmosBulkExecutionOptions clone(CosmosBulkExecutionOptions toBeCloned) { return new CosmosBulkExecutionOptions(toBeCloned); }
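With the CosmosBulkExecutionOptionsAccessor trimmed down to clone and getImpl, internal callers now resolve the backing CosmosBulkExecutionOptionsImpl once and invoke the tuning methods directly on it, as the updated CosmosBulkAsyncTest does above. A rough sketch of the resulting call pattern, assuming the impl keeps the setter names used elsewhere in this change:

```java
import com.azure.cosmos.implementation.CosmosBulkExecutionOptionsImpl;
import com.azure.cosmos.implementation.ImplementationBridgeHelpers;
import com.azure.cosmos.models.CosmosBulkExecutionOptions;

public final class BulkOptionsAccessorUsage {

    // Sketch only: resolve the impl once through the slimmed-down accessor,
    // then apply internal tuning directly on it instead of going through one
    // bridge-accessor method per setting.
    static CosmosBulkExecutionOptions tunedOptions() {
        CosmosBulkExecutionOptions options = new CosmosBulkExecutionOptions();

        CosmosBulkExecutionOptionsImpl impl = ImplementationBridgeHelpers
            .CosmosBulkExecutionOptionsHelper
            .getCosmosBulkExecutionOptionsAccessor()
            .getImpl(options);

        impl.setTargetedMicroBatchRetryRate(0.25, 0.5); // retry-rate window, as in CosmosBulkAsyncTest
        impl.setMaxConcurrentCosmosPartitions(4);       // illustrative value

        return options;
    }
}
```

The design keeps ImplementationBridgeHelpers as the single bridge into the models package while moving all per-setting plumbing onto the impl class.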
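The other functional piece of this change is end-to-end latency policy support on the bulk path: CosmosBulkExecutionOptionsImpl now carries a CosmosEndToEndOperationLatencyPolicyConfig, BulkExecutor copies it onto the per-batch RequestOptions, and executeBatchRequest in RxDocumentClientImpl wraps the call with the end-to-end timeout. The setter added on CosmosBulkExecutionOptions here is package-private, so from application code such a policy would still be built with the public builder and applied, for example, at the client level; a hedged sketch (the five-second budget is illustrative):

```java
import com.azure.cosmos.CosmosClientBuilder;
import com.azure.cosmos.CosmosEndToEndOperationLatencyPolicyConfig;
import com.azure.cosmos.CosmosEndToEndOperationLatencyPolicyConfigBuilder;

import java.time.Duration;

public final class EndToEndPolicyExample {

    // Sketch only: build an end-to-end latency policy and attach it to the client
    // so operations exceeding the budget are cancelled instead of waiting indefinitely.
    static CosmosClientBuilder withEndToEndTimeout(CosmosClientBuilder builder) {
        CosmosEndToEndOperationLatencyPolicyConfig e2eConfig =
            new CosmosEndToEndOperationLatencyPolicyConfigBuilder(Duration.ofSeconds(5)) // illustrative budget
                .build();

        return builder.endToEndOperationLatencyPolicyConfig(e2eConfig);
    }
}
```

With the RxDocumentClientImpl change above, a policy configured this way should then also bound the individual batch requests issued by the bulk executor.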