diff --git a/Sources/ClientRuntime/Config/DefaultSDKRuntimeConfiguration.swift b/Sources/ClientRuntime/Config/DefaultSDKRuntimeConfiguration.swift index f9e8fd2c9..3d2ad5f62 100644 --- a/Sources/ClientRuntime/Config/DefaultSDKRuntimeConfiguration.swift +++ b/Sources/ClientRuntime/Config/DefaultSDKRuntimeConfiguration.swift @@ -15,10 +15,17 @@ public struct DefaultSDKRuntimeConfiguration: SDKRuntimeConfiguration { public let retryer: SDKRetryer public var clientLogMode: ClientLogMode public var endpoint: String? + + /// The partition ID to be used for this configuration. + /// + /// Requests made with the same partition ID will be grouped together for retry throttling purposes. + /// If no partition ID is provided, requests will be partitioned based on the hostname. + public var partitionID: String? public init( _ clientName: String, - clientLogMode: ClientLogMode = .request + clientLogMode: ClientLogMode = .request, + partitionID: String? = nil ) throws { self.encoder = nil self.decoder = nil @@ -28,5 +35,6 @@ public struct DefaultSDKRuntimeConfiguration: SDKRuntimeConfiguration { self.retryer = try SDKRetryer() self.logger = SwiftLogger(label: clientName) self.clientLogMode = clientLogMode + self.partitionID = partitionID } } diff --git a/Sources/ClientRuntime/Config/SDKRuntimeConfiguration.swift b/Sources/ClientRuntime/Config/SDKRuntimeConfiguration.swift index 3b84663a5..b09ecab29 100644 --- a/Sources/ClientRuntime/Config/SDKRuntimeConfiguration.swift +++ b/Sources/ClientRuntime/Config/SDKRuntimeConfiguration.swift @@ -18,4 +18,10 @@ public protocol SDKRuntimeConfiguration { var clientLogMode: ClientLogMode {get} var retryer: SDKRetryer {get} var endpoint: String? {get set} + + /// The partition ID to be used for this configuration. + /// + /// Requests made with the same partition ID will be grouped together for retry throttling purposes. + /// If no partition ID is provided, requests will be partitioned based on the hostname. + var partitionID: String? { get } } diff --git a/Sources/ClientRuntime/Middleware/RetryerMiddleware.swift b/Sources/ClientRuntime/Middleware/RetryerMiddleware.swift new file mode 100644 index 000000000..c72d7fa56 --- /dev/null +++ b/Sources/ClientRuntime/Middleware/RetryerMiddleware.swift @@ -0,0 +1,90 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + +public struct RetryerMiddleware: Middleware { + + public var id: String = "Retryer" + + let retryer: SDKRetryer + + public init(retryer: SDKRetryer) { + self.retryer = retryer + } + + public func handle( + context: Context, + input: SdkHttpRequestBuilder, + next: H + ) async throws -> OperationOutput where + H: Handler, + Self.MInput == H.Input, + Self.MOutput == H.Output, + Self.Context == H.Context { + + // Select a partition ID to be used for throttling retry requests. Requests with the + // same partition ID will be "pooled" together for throttling purposes. + let partitionID: String + if let customPartitionID = context.getPartitionID(), !customPartitionID.isEmpty { + // use custom partition ID provided by context + partitionID = customPartitionID + } else if !input.host.isEmpty { + // fall back to the hostname for partition ID, which is a "commonsense" default + partitionID = input.host + } else { + throw SdkError.client(ClientError.unknownError("Partition ID could not be determined")) + } + + do { + let token = try await retryer.acquireToken(partitionId: partitionID) + return try await tryRequest( + token: token, + partitionID: partitionID, + context: context, + input: input, + next: next + ) + } catch { + throw SdkError.client(ClientError.retryError(error)) + } + } + + func tryRequest( + token: RetryToken, + errorType: RetryError? = nil, + partitionID: String, + context: Context, + input: SdkHttpRequestBuilder, + next: H + ) async throws -> OperationOutput where + H: Handler, + Self.MInput == H.Input, + Self.MOutput == H.Output, + Self.Context == H.Context { + + do { + let serviceResponse = try await next.handle(context: context, input: input) + retryer.recordSuccess(token: token) + return serviceResponse + } catch let error as SdkError where retryer.isErrorRetryable(error: error) { + let errorType = retryer.getErrorType(error: error) + let newToken = try await retryer.scheduleRetry(token: token, error: errorType) + // TODO: rewind the stream once streaming is properly implemented + return try await tryRequest( + token: newToken, + partitionID: partitionID, + context: context, + input: input, + next: next + ) + } + } + + public typealias MInput = SdkHttpRequestBuilder + public typealias MOutput = OperationOutput + public typealias Context = HttpContext +} diff --git a/Sources/ClientRuntime/Networking/Http/HttpContext.swift b/Sources/ClientRuntime/Networking/Http/HttpContext.swift index adb7f6e7a..bae2d5de1 100644 --- a/Sources/ClientRuntime/Networking/Http/HttpContext.swift +++ b/Sources/ClientRuntime/Networking/Http/HttpContext.swift @@ -45,6 +45,14 @@ public struct HttpContext: MiddlewareContext { public func getLogger() -> LogAgent? { return attributes.get(key: AttributeKey(name: "Logger")) } + + /// The partition ID to be used for this context. + /// + /// Requests made with the same partition ID will be grouped together for retry throttling purposes. + /// If no partition ID is provided, requests will be partitioned based on the hostname. + public func getPartitionID() -> String? { + return attributes.get(key: AttributeKey(name: "PartitionID")) + } } public class HttpContextBuilder { @@ -63,6 +71,7 @@ public class HttpContextBuilder { let idempotencyTokenGenerator = AttributeKey(name: "IdempotencyTokenGenerator") let hostPrefix = AttributeKey(name: "HostPrefix") let logger = AttributeKey(name: "Logger") + let partitionID = AttributeKey(name: "PartitionID") // We follow the convention of returning the builder object // itself from any configuration methods, and by adding the @@ -140,6 +149,18 @@ public class HttpContextBuilder { self.attributes.set(key: logger, value: value) return self } + + /// Sets the partition ID on the context builder. + /// + /// Requests made with the same partition ID will be grouped together for retry throttling purposes. + /// If no partition ID is provided, requests will be partitioned based on the hostname. + /// - Parameter value: The partition ID to be set on this builder, or `nil`. + /// - Returns: `self`, after the partition ID is set as specified. + @discardableResult + public func withPartitionID(value: String?) -> HttpContextBuilder { + self.attributes.set(key: partitionID, value: value) + return self + } public func build() -> HttpContext { return HttpContext(attributes: attributes) diff --git a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/ClientRuntimeTypes.kt b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/ClientRuntimeTypes.kt index d65f8c9ca..20b096adf 100644 --- a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/ClientRuntimeTypes.kt +++ b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/ClientRuntimeTypes.kt @@ -62,6 +62,7 @@ object ClientRuntimeTypes { val QueryItemMiddleware = runtimeSymbol("QueryItemMiddleware") val HeaderMiddleware = runtimeSymbol("HeaderMiddleware") val SerializableBodyMiddleware = runtimeSymbol("SerializableBodyMiddleware") + val RetryerMiddleware = runtimeSymbol("RetryerMiddleware") val NoopHandler = runtimeSymbol("NoopHandler") object Providers { diff --git a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/HttpBindingProtocolGenerator.kt b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/HttpBindingProtocolGenerator.kt index d5a3726a6..d2aaf4675 100644 --- a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/HttpBindingProtocolGenerator.kt +++ b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/HttpBindingProtocolGenerator.kt @@ -47,6 +47,7 @@ import software.amazon.smithy.swift.codegen.integration.middlewares.OperationInp import software.amazon.smithy.swift.codegen.integration.middlewares.OperationInputQueryItemMiddleware import software.amazon.smithy.swift.codegen.integration.middlewares.OperationInputUrlHostMiddleware import software.amazon.smithy.swift.codegen.integration.middlewares.OperationInputUrlPathMiddleware +import software.amazon.smithy.swift.codegen.integration.middlewares.RetryMiddleware import software.amazon.smithy.swift.codegen.integration.middlewares.handlers.HttpBodyMiddleware import software.amazon.smithy.swift.codegen.integration.middlewares.providers.HttpHeaderProvider import software.amazon.smithy.swift.codegen.integration.middlewares.providers.HttpQueryItemProvider @@ -404,6 +405,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { operationMiddleware.appendMiddleware(operation, LoggingMiddleware(ctx.model, ctx.symbolProvider)) operationMiddleware.appendMiddleware(operation, DeserializeMiddleware(ctx.model, ctx.symbolProvider)) + operationMiddleware.appendMiddleware(operation, RetryMiddleware(ctx.model, ctx.symbolProvider)) addProtocolSpecificMiddleware(ctx, operation) diff --git a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/middlewares/RetryMiddleware.kt b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/middlewares/RetryMiddleware.kt new file mode 100644 index 000000000..35d2d57a5 --- /dev/null +++ b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/middlewares/RetryMiddleware.kt @@ -0,0 +1,34 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +package software.amazon.smithy.swift.codegen.integration.middlewares + +import software.amazon.smithy.codegen.core.SymbolProvider +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.swift.codegen.ClientRuntimeTypes +import software.amazon.smithy.swift.codegen.SwiftWriter +import software.amazon.smithy.swift.codegen.integration.middlewares.handlers.MiddlewareShapeUtils +import software.amazon.smithy.swift.codegen.middleware.MiddlewarePosition +import software.amazon.smithy.swift.codegen.middleware.MiddlewareRenderable +import software.amazon.smithy.swift.codegen.middleware.MiddlewareStep + +class RetryMiddleware( + val model: Model, + val symbolProvider: SymbolProvider +) : MiddlewareRenderable { + + override val name = "RetryMiddleware" + + override val middlewareStep = MiddlewareStep.FINALIZESTEP + + override val position = MiddlewarePosition.AFTER + + override fun render(writer: SwiftWriter, op: OperationShape, operationStackName: String) { + val output = MiddlewareShapeUtils.outputSymbol(symbolProvider, model, op) + val outputError = MiddlewareShapeUtils.outputErrorSymbol(op) + writer.write("$operationStackName.${middlewareStep.stringValue()}.intercept(position: ${position.stringValue()}, middleware: \$N<\$N, \$N>(retryer: config.retryer))", ClientRuntimeTypes.Middleware.RetryerMiddleware, output, outputError) + } +} diff --git a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/middleware/MiddlewareExecutionGenerator.kt b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/middleware/MiddlewareExecutionGenerator.kt index 0c37b746b..a90c3cf61 100644 --- a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/middleware/MiddlewareExecutionGenerator.kt +++ b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/middleware/MiddlewareExecutionGenerator.kt @@ -49,6 +49,7 @@ class MiddlewareExecutionGenerator( writer.write(" .withOperation(value: \"${op.toLowerCamelCase()}\")") writer.write(" .withIdempotencyTokenGenerator(value: config.idempotencyTokenGenerator)") writer.write(" .withLogger(value: config.logger)") + writer.write(" .withPartitionID(value: config.partitionID)") val serviceShape = ctx.service httpProtocolCustomizable.renderContextAttributes(ctx, writer, serviceShape, op) diff --git a/smithy-swift-codegen/src/test/kotlin/ContentMd5MiddlewareTests.kt b/smithy-swift-codegen/src/test/kotlin/ContentMd5MiddlewareTests.kt index e9adb6982..c4d3bae1b 100644 --- a/smithy-swift-codegen/src/test/kotlin/ContentMd5MiddlewareTests.kt +++ b/smithy-swift-codegen/src/test/kotlin/ContentMd5MiddlewareTests.kt @@ -19,6 +19,7 @@ class ContentMd5MiddlewareTests { .withOperation(value: "idempotencyTokenWithStructure") .withIdempotencyTokenGenerator(value: config.idempotencyTokenGenerator) .withLogger(value: config.logger) + .withPartitionID(value: config.partitionID) var operation = ClientRuntime.OperationStack(id: "idempotencyTokenWithStructure") operation.initializeStep.intercept(position: .after, id: "IdempotencyTokenMiddleware") { (context, input, next) -> ClientRuntime.OperationOutput in let idempotencyTokenGenerator = context.getIdempotencyTokenGenerator() @@ -34,6 +35,7 @@ class ContentMd5MiddlewareTests { operation.serializeStep.intercept(position: .after, middleware: ContentTypeMiddleware(contentType: "application/xml")) operation.serializeStep.intercept(position: .after, middleware: ClientRuntime.SerializableBodyMiddleware(xmlName: "IdempotencyToken")) operation.finalizeStep.intercept(position: .before, middleware: ClientRuntime.ContentLengthMiddleware()) + operation.finalizeStep.intercept(position: .after, middleware: ClientRuntime.RetryerMiddleware(retryer: config.retryer)) operation.deserializeStep.intercept(position: .before, middleware: ClientRuntime.LoggerMiddleware(clientLogMode: config.clientLogMode)) operation.deserializeStep.intercept(position: .after, middleware: ClientRuntime.DeserializeMiddleware()) let result = try await operation.handleMiddleware(context: context.build(), input: input, next: client.getHandler()) diff --git a/smithy-swift-codegen/src/test/kotlin/HttpProtocolClientGeneratorTests.kt b/smithy-swift-codegen/src/test/kotlin/HttpProtocolClientGeneratorTests.kt index 9e338d3dc..88dc4cacc 100644 --- a/smithy-swift-codegen/src/test/kotlin/HttpProtocolClientGeneratorTests.kt +++ b/smithy-swift-codegen/src/test/kotlin/HttpProtocolClientGeneratorTests.kt @@ -128,6 +128,7 @@ class HttpProtocolClientGeneratorTests { .withOperation(value: "allocateWidget") .withIdempotencyTokenGenerator(value: config.idempotencyTokenGenerator) .withLogger(value: config.logger) + .withPartitionID(value: config.partitionID) var operation = ClientRuntime.OperationStack(id: "allocateWidget") operation.initializeStep.intercept(position: .after, id: "IdempotencyTokenMiddleware") { (context, input, next) -> ClientRuntime.OperationOutput in let idempotencyTokenGenerator = context.getIdempotencyTokenGenerator() @@ -142,6 +143,7 @@ class HttpProtocolClientGeneratorTests { operation.serializeStep.intercept(position: .after, middleware: ContentTypeMiddleware(contentType: "application/json")) operation.serializeStep.intercept(position: .after, middleware: ClientRuntime.SerializableBodyMiddleware(xmlName: "AllocateWidgetInput")) operation.finalizeStep.intercept(position: .before, middleware: ClientRuntime.ContentLengthMiddleware()) + operation.finalizeStep.intercept(position: .after, middleware: ClientRuntime.RetryerMiddleware(retryer: config.retryer)) operation.deserializeStep.intercept(position: .before, middleware: ClientRuntime.LoggerMiddleware(clientLogMode: config.clientLogMode)) operation.deserializeStep.intercept(position: .after, middleware: ClientRuntime.DeserializeMiddleware()) let result = try await operation.handleMiddleware(context: context.build(), input: input, next: client.getHandler()) diff --git a/smithy-swift-codegen/src/test/kotlin/IdempotencyTokenTraitTests.kt b/smithy-swift-codegen/src/test/kotlin/IdempotencyTokenTraitTests.kt index 546f7b5ae..2d87491f5 100644 --- a/smithy-swift-codegen/src/test/kotlin/IdempotencyTokenTraitTests.kt +++ b/smithy-swift-codegen/src/test/kotlin/IdempotencyTokenTraitTests.kt @@ -19,6 +19,7 @@ class IdempotencyTokenTraitTests { .withOperation(value: "idempotencyTokenWithStructure") .withIdempotencyTokenGenerator(value: config.idempotencyTokenGenerator) .withLogger(value: config.logger) + .withPartitionID(value: config.partitionID) var operation = ClientRuntime.OperationStack(id: "idempotencyTokenWithStructure") operation.initializeStep.intercept(position: .after, id: "IdempotencyTokenMiddleware") { (context, input, next) -> ClientRuntime.OperationOutput in let idempotencyTokenGenerator = context.getIdempotencyTokenGenerator() @@ -33,6 +34,7 @@ class IdempotencyTokenTraitTests { operation.serializeStep.intercept(position: .after, middleware: ContentTypeMiddleware(contentType: "application/xml")) operation.serializeStep.intercept(position: .after, middleware: ClientRuntime.SerializableBodyMiddleware(xmlName: "IdempotencyToken")) operation.finalizeStep.intercept(position: .before, middleware: ClientRuntime.ContentLengthMiddleware()) + operation.finalizeStep.intercept(position: .after, middleware: ClientRuntime.RetryerMiddleware(retryer: config.retryer)) operation.deserializeStep.intercept(position: .before, middleware: ClientRuntime.LoggerMiddleware(clientLogMode: config.clientLogMode)) operation.deserializeStep.intercept(position: .after, middleware: ClientRuntime.DeserializeMiddleware()) let result = try await operation.handleMiddleware(context: context.build(), input: input, next: client.getHandler()) diff --git a/smithy-swift-codegen/src/test/kotlin/RetryMiddlewareTests.kt b/smithy-swift-codegen/src/test/kotlin/RetryMiddlewareTests.kt new file mode 100644 index 000000000..23ab07f3e --- /dev/null +++ b/smithy-swift-codegen/src/test/kotlin/RetryMiddlewareTests.kt @@ -0,0 +1,24 @@ +import io.kotest.matchers.string.shouldContainOnlyOnce +import org.junit.jupiter.api.Test + +class RetryMiddlewareTests { + + @Test + fun `generates operation with retry middleware`() { + val context = setupTests("Isolated/contentmd5checksum.smithy", "aws.protocoltests.restxml#RestXml") + val contents = getFileContents(context.manifest, "/RestXml/RestXmlProtocolClient.swift") + val expectedContents = """ + operation.finalizeStep.intercept(position: .after, middleware: ClientRuntime.RetryerMiddleware(retryer: config.retryer)) + """.trimIndent() + contents.shouldContainOnlyOnce(expectedContents) + } + private fun setupTests(smithyFile: String, serviceShapeId: String): TestContext { + val context = TestContext.initContextFrom(smithyFile, serviceShapeId, MockHttpRestXMLProtocolGenerator()) { model -> + model.defaultSettings(serviceShapeId, "RestXml", "2019-12-16", "Rest Xml Protocol") + } + context.generator.initializeMiddleware(context.generationCtx) + context.generator.generateProtocolClient(context.generationCtx) + context.generationCtx.delegator.flushWriters() + return context + } +}