Skip to content

Commit

Permalink
feat: Move retry middleware from SDK (#502)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbelkins authored Jan 5, 2023
1 parent 56aa830 commit 03ec067
Show file tree
Hide file tree
Showing 12 changed files with 194 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,17 @@ public struct DefaultSDKRuntimeConfiguration: SDKRuntimeConfiguration {
public let retryer: SDKRetryer
public var clientLogMode: ClientLogMode
public var endpoint: String?

/// The partition ID to be used for this configuration.
///
/// Requests made with the same partition ID will be grouped together for retry throttling purposes.
/// If no partition ID is provided, requests will be partitioned based on the hostname.
public var partitionID: String?

public init(
_ clientName: String,
clientLogMode: ClientLogMode = .request
clientLogMode: ClientLogMode = .request,
partitionID: String? = nil
) throws {
self.encoder = nil
self.decoder = nil
Expand All @@ -28,5 +35,6 @@ public struct DefaultSDKRuntimeConfiguration: SDKRuntimeConfiguration {
self.retryer = try SDKRetryer()
self.logger = SwiftLogger(label: clientName)
self.clientLogMode = clientLogMode
self.partitionID = partitionID
}
}
6 changes: 6 additions & 0 deletions Sources/ClientRuntime/Config/SDKRuntimeConfiguration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,10 @@ public protocol SDKRuntimeConfiguration {
var clientLogMode: ClientLogMode {get}
var retryer: SDKRetryer {get}
var endpoint: String? {get set}

/// The partition ID to be used for this configuration.
///
/// Requests made with the same partition ID will be grouped together for retry throttling purposes.
/// If no partition ID is provided, requests will be partitioned based on the hostname.
var partitionID: String? { get }
}
90 changes: 90 additions & 0 deletions Sources/ClientRuntime/Middleware/RetryerMiddleware.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
//
// Copyright Amazon.com Inc. or its affiliates.
// All Rights Reserved.
//
// SPDX-License-Identifier: Apache-2.0
//

public struct RetryerMiddleware<Output: HttpResponseBinding,
OutputError: HttpResponseBinding>: Middleware {

public var id: String = "Retryer"

let retryer: SDKRetryer

public init(retryer: SDKRetryer) {
self.retryer = retryer
}

public func handle<H>(
context: Context,
input: SdkHttpRequestBuilder,
next: H
) async throws -> OperationOutput<Output> where
H: Handler,
Self.MInput == H.Input,
Self.MOutput == H.Output,
Self.Context == H.Context {

// Select a partition ID to be used for throttling retry requests. Requests with the
// same partition ID will be "pooled" together for throttling purposes.
let partitionID: String
if let customPartitionID = context.getPartitionID(), !customPartitionID.isEmpty {
// use custom partition ID provided by context
partitionID = customPartitionID
} else if !input.host.isEmpty {
// fall back to the hostname for partition ID, which is a "commonsense" default
partitionID = input.host
} else {
throw SdkError<OutputError>.client(ClientError.unknownError("Partition ID could not be determined"))
}

do {
let token = try await retryer.acquireToken(partitionId: partitionID)
return try await tryRequest(
token: token,
partitionID: partitionID,
context: context,
input: input,
next: next
)
} catch {
throw SdkError<OutputError>.client(ClientError.retryError(error))
}
}

func tryRequest<H>(
token: RetryToken,
errorType: RetryError? = nil,
partitionID: String,
context: Context,
input: SdkHttpRequestBuilder,
next: H
) async throws -> OperationOutput<Output> where
H: Handler,
Self.MInput == H.Input,
Self.MOutput == H.Output,
Self.Context == H.Context {

do {
let serviceResponse = try await next.handle(context: context, input: input)
retryer.recordSuccess(token: token)
return serviceResponse
} catch let error as SdkError<OutputError> where retryer.isErrorRetryable(error: error) {
let errorType = retryer.getErrorType(error: error)
let newToken = try await retryer.scheduleRetry(token: token, error: errorType)
// TODO: rewind the stream once streaming is properly implemented
return try await tryRequest(
token: newToken,
partitionID: partitionID,
context: context,
input: input,
next: next
)
}
}

public typealias MInput = SdkHttpRequestBuilder
public typealias MOutput = OperationOutput<Output>
public typealias Context = HttpContext
}
21 changes: 21 additions & 0 deletions Sources/ClientRuntime/Networking/Http/HttpContext.swift
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@ public struct HttpContext: MiddlewareContext {
public func getLogger() -> LogAgent? {
return attributes.get(key: AttributeKey<LogAgent>(name: "Logger"))
}

/// The partition ID to be used for this context.
///
/// Requests made with the same partition ID will be grouped together for retry throttling purposes.
/// If no partition ID is provided, requests will be partitioned based on the hostname.
public func getPartitionID() -> String? {
return attributes.get(key: AttributeKey<String>(name: "PartitionID"))
}
}

public class HttpContextBuilder {
Expand All @@ -63,6 +71,7 @@ public class HttpContextBuilder {
let idempotencyTokenGenerator = AttributeKey<IdempotencyTokenGenerator>(name: "IdempotencyTokenGenerator")
let hostPrefix = AttributeKey<String>(name: "HostPrefix")
let logger = AttributeKey<LogAgent>(name: "Logger")
let partitionID = AttributeKey<String>(name: "PartitionID")

// We follow the convention of returning the builder object
// itself from any configuration methods, and by adding the
Expand Down Expand Up @@ -140,6 +149,18 @@ public class HttpContextBuilder {
self.attributes.set(key: logger, value: value)
return self
}

/// Sets the partition ID on the context builder.
///
/// Requests made with the same partition ID will be grouped together for retry throttling purposes.
/// If no partition ID is provided, requests will be partitioned based on the hostname.
/// - Parameter value: The partition ID to be set on this builder, or `nil`.
/// - Returns: `self`, after the partition ID is set as specified.
@discardableResult
public func withPartitionID(value: String?) -> HttpContextBuilder {
self.attributes.set(key: partitionID, value: value)
return self
}

public func build() -> HttpContext {
return HttpContext(attributes: attributes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ object ClientRuntimeTypes {
val QueryItemMiddleware = runtimeSymbol("QueryItemMiddleware")
val HeaderMiddleware = runtimeSymbol("HeaderMiddleware")
val SerializableBodyMiddleware = runtimeSymbol("SerializableBodyMiddleware")
val RetryerMiddleware = runtimeSymbol("RetryerMiddleware")
val NoopHandler = runtimeSymbol("NoopHandler")

object Providers {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import software.amazon.smithy.swift.codegen.integration.middlewares.OperationInp
import software.amazon.smithy.swift.codegen.integration.middlewares.OperationInputQueryItemMiddleware
import software.amazon.smithy.swift.codegen.integration.middlewares.OperationInputUrlHostMiddleware
import software.amazon.smithy.swift.codegen.integration.middlewares.OperationInputUrlPathMiddleware
import software.amazon.smithy.swift.codegen.integration.middlewares.RetryMiddleware
import software.amazon.smithy.swift.codegen.integration.middlewares.handlers.HttpBodyMiddleware
import software.amazon.smithy.swift.codegen.integration.middlewares.providers.HttpHeaderProvider
import software.amazon.smithy.swift.codegen.integration.middlewares.providers.HttpQueryItemProvider
Expand Down Expand Up @@ -404,6 +405,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {

operationMiddleware.appendMiddleware(operation, LoggingMiddleware(ctx.model, ctx.symbolProvider))
operationMiddleware.appendMiddleware(operation, DeserializeMiddleware(ctx.model, ctx.symbolProvider))
operationMiddleware.appendMiddleware(operation, RetryMiddleware(ctx.model, ctx.symbolProvider))

addProtocolSpecificMiddleware(ctx, operation)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/

package software.amazon.smithy.swift.codegen.integration.middlewares

import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.swift.codegen.ClientRuntimeTypes
import software.amazon.smithy.swift.codegen.SwiftWriter
import software.amazon.smithy.swift.codegen.integration.middlewares.handlers.MiddlewareShapeUtils
import software.amazon.smithy.swift.codegen.middleware.MiddlewarePosition
import software.amazon.smithy.swift.codegen.middleware.MiddlewareRenderable
import software.amazon.smithy.swift.codegen.middleware.MiddlewareStep

class RetryMiddleware(
val model: Model,
val symbolProvider: SymbolProvider
) : MiddlewareRenderable {

override val name = "RetryMiddleware"

override val middlewareStep = MiddlewareStep.FINALIZESTEP

override val position = MiddlewarePosition.AFTER

override fun render(writer: SwiftWriter, op: OperationShape, operationStackName: String) {
val output = MiddlewareShapeUtils.outputSymbol(symbolProvider, model, op)
val outputError = MiddlewareShapeUtils.outputErrorSymbol(op)
writer.write("$operationStackName.${middlewareStep.stringValue()}.intercept(position: ${position.stringValue()}, middleware: \$N<\$N, \$N>(retryer: config.retryer))", ClientRuntimeTypes.Middleware.RetryerMiddleware, output, outputError)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class MiddlewareExecutionGenerator(
writer.write(" .withOperation(value: \"${op.toLowerCamelCase()}\")")
writer.write(" .withIdempotencyTokenGenerator(value: config.idempotencyTokenGenerator)")
writer.write(" .withLogger(value: config.logger)")
writer.write(" .withPartitionID(value: config.partitionID)")

val serviceShape = ctx.service
httpProtocolCustomizable.renderContextAttributes(ctx, writer, serviceShape, op)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class ContentMd5MiddlewareTests {
.withOperation(value: "idempotencyTokenWithStructure")
.withIdempotencyTokenGenerator(value: config.idempotencyTokenGenerator)
.withLogger(value: config.logger)
.withPartitionID(value: config.partitionID)
var operation = ClientRuntime.OperationStack<IdempotencyTokenWithStructureInput, IdempotencyTokenWithStructureOutputResponse, IdempotencyTokenWithStructureOutputError>(id: "idempotencyTokenWithStructure")
operation.initializeStep.intercept(position: .after, id: "IdempotencyTokenMiddleware") { (context, input, next) -> ClientRuntime.OperationOutput<IdempotencyTokenWithStructureOutputResponse> in
let idempotencyTokenGenerator = context.getIdempotencyTokenGenerator()
Expand All @@ -34,6 +35,7 @@ class ContentMd5MiddlewareTests {
operation.serializeStep.intercept(position: .after, middleware: ContentTypeMiddleware<IdempotencyTokenWithStructureInput, IdempotencyTokenWithStructureOutputResponse>(contentType: "application/xml"))
operation.serializeStep.intercept(position: .after, middleware: ClientRuntime.SerializableBodyMiddleware<IdempotencyTokenWithStructureInput, IdempotencyTokenWithStructureOutputResponse>(xmlName: "IdempotencyToken"))
operation.finalizeStep.intercept(position: .before, middleware: ClientRuntime.ContentLengthMiddleware())
operation.finalizeStep.intercept(position: .after, middleware: ClientRuntime.RetryerMiddleware<IdempotencyTokenWithStructureOutputResponse, IdempotencyTokenWithStructureOutputError>(retryer: config.retryer))
operation.deserializeStep.intercept(position: .before, middleware: ClientRuntime.LoggerMiddleware<IdempotencyTokenWithStructureOutputResponse, IdempotencyTokenWithStructureOutputError>(clientLogMode: config.clientLogMode))
operation.deserializeStep.intercept(position: .after, middleware: ClientRuntime.DeserializeMiddleware<IdempotencyTokenWithStructureOutputResponse, IdempotencyTokenWithStructureOutputError>())
let result = try await operation.handleMiddleware(context: context.build(), input: input, next: client.getHandler())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ class HttpProtocolClientGeneratorTests {
.withOperation(value: "allocateWidget")
.withIdempotencyTokenGenerator(value: config.idempotencyTokenGenerator)
.withLogger(value: config.logger)
.withPartitionID(value: config.partitionID)
var operation = ClientRuntime.OperationStack<AllocateWidgetInput, AllocateWidgetOutputResponse, AllocateWidgetOutputError>(id: "allocateWidget")
operation.initializeStep.intercept(position: .after, id: "IdempotencyTokenMiddleware") { (context, input, next) -> ClientRuntime.OperationOutput<AllocateWidgetOutputResponse> in
let idempotencyTokenGenerator = context.getIdempotencyTokenGenerator()
Expand All @@ -142,6 +143,7 @@ class HttpProtocolClientGeneratorTests {
operation.serializeStep.intercept(position: .after, middleware: ContentTypeMiddleware<AllocateWidgetInput, AllocateWidgetOutputResponse>(contentType: "application/json"))
operation.serializeStep.intercept(position: .after, middleware: ClientRuntime.SerializableBodyMiddleware<AllocateWidgetInput, AllocateWidgetOutputResponse>(xmlName: "AllocateWidgetInput"))
operation.finalizeStep.intercept(position: .before, middleware: ClientRuntime.ContentLengthMiddleware())
operation.finalizeStep.intercept(position: .after, middleware: ClientRuntime.RetryerMiddleware<AllocateWidgetOutputResponse, AllocateWidgetOutputError>(retryer: config.retryer))
operation.deserializeStep.intercept(position: .before, middleware: ClientRuntime.LoggerMiddleware<AllocateWidgetOutputResponse, AllocateWidgetOutputError>(clientLogMode: config.clientLogMode))
operation.deserializeStep.intercept(position: .after, middleware: ClientRuntime.DeserializeMiddleware<AllocateWidgetOutputResponse, AllocateWidgetOutputError>())
let result = try await operation.handleMiddleware(context: context.build(), input: input, next: client.getHandler())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class IdempotencyTokenTraitTests {
.withOperation(value: "idempotencyTokenWithStructure")
.withIdempotencyTokenGenerator(value: config.idempotencyTokenGenerator)
.withLogger(value: config.logger)
.withPartitionID(value: config.partitionID)
var operation = ClientRuntime.OperationStack<IdempotencyTokenWithStructureInput, IdempotencyTokenWithStructureOutputResponse, IdempotencyTokenWithStructureOutputError>(id: "idempotencyTokenWithStructure")
operation.initializeStep.intercept(position: .after, id: "IdempotencyTokenMiddleware") { (context, input, next) -> ClientRuntime.OperationOutput<IdempotencyTokenWithStructureOutputResponse> in
let idempotencyTokenGenerator = context.getIdempotencyTokenGenerator()
Expand All @@ -33,6 +34,7 @@ class IdempotencyTokenTraitTests {
operation.serializeStep.intercept(position: .after, middleware: ContentTypeMiddleware<IdempotencyTokenWithStructureInput, IdempotencyTokenWithStructureOutputResponse>(contentType: "application/xml"))
operation.serializeStep.intercept(position: .after, middleware: ClientRuntime.SerializableBodyMiddleware<IdempotencyTokenWithStructureInput, IdempotencyTokenWithStructureOutputResponse>(xmlName: "IdempotencyToken"))
operation.finalizeStep.intercept(position: .before, middleware: ClientRuntime.ContentLengthMiddleware())
operation.finalizeStep.intercept(position: .after, middleware: ClientRuntime.RetryerMiddleware<IdempotencyTokenWithStructureOutputResponse, IdempotencyTokenWithStructureOutputError>(retryer: config.retryer))
operation.deserializeStep.intercept(position: .before, middleware: ClientRuntime.LoggerMiddleware<IdempotencyTokenWithStructureOutputResponse, IdempotencyTokenWithStructureOutputError>(clientLogMode: config.clientLogMode))
operation.deserializeStep.intercept(position: .after, middleware: ClientRuntime.DeserializeMiddleware<IdempotencyTokenWithStructureOutputResponse, IdempotencyTokenWithStructureOutputError>())
let result = try await operation.handleMiddleware(context: context.build(), input: input, next: client.getHandler())
Expand Down
24 changes: 24 additions & 0 deletions smithy-swift-codegen/src/test/kotlin/RetryMiddlewareTests.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import io.kotest.matchers.string.shouldContainOnlyOnce
import org.junit.jupiter.api.Test

class RetryMiddlewareTests {

@Test
fun `generates operation with retry middleware`() {
val context = setupTests("Isolated/contentmd5checksum.smithy", "aws.protocoltests.restxml#RestXml")
val contents = getFileContents(context.manifest, "/RestXml/RestXmlProtocolClient.swift")
val expectedContents = """
operation.finalizeStep.intercept(position: .after, middleware: ClientRuntime.RetryerMiddleware<IdempotencyTokenWithStructureOutputResponse, IdempotencyTokenWithStructureOutputError>(retryer: config.retryer))
""".trimIndent()
contents.shouldContainOnlyOnce(expectedContents)
}
private fun setupTests(smithyFile: String, serviceShapeId: String): TestContext {
val context = TestContext.initContextFrom(smithyFile, serviceShapeId, MockHttpRestXMLProtocolGenerator()) { model ->
model.defaultSettings(serviceShapeId, "RestXml", "2019-12-16", "Rest Xml Protocol")
}
context.generator.initializeMiddleware(context.generationCtx)
context.generator.generateProtocolClient(context.generationCtx)
context.generationCtx.delegator.flushWriters()
return context
}
}

0 comments on commit 03ec067

Please sign in to comment.