diff --git a/Sources/ClientRuntime/Orchestrator/Orchestrator.swift b/Sources/ClientRuntime/Orchestrator/Orchestrator.swift index 84e0ba69b..0898d9f6f 100644 --- a/Sources/ClientRuntime/Orchestrator/Orchestrator.swift +++ b/Sources/ClientRuntime/Orchestrator/Orchestrator.swift @@ -265,6 +265,13 @@ public struct Orchestrator< // If we can't get errorInfo, we definitely can't retry guard let errorInfo = retryErrorInfoProvider(error) else { return } + // If the body is a nonseekable stream, we also can't retry + do { + guard try readyBodyForRetry(request: copiedRequest) else { return } + } catch { + return + } + // When refreshing fails it throws, indicating we're done retrying do { try await strategy.refreshRetryTokenForRetry(tokenToRenew: token, errorInfo: errorInfo) @@ -277,6 +284,25 @@ public struct Orchestrator< } } + /// Readies the body for retry, and indicates whether the request body may be safely used in a retry. + /// - Parameter request: The request to be retried. + /// - Returns: `true` if the body of the request is safe to retry, `false` otherwise. In general, a request body is retriable if it is not a stream, or + /// if the stream is seekable and successfully seeks to the start position / offset zero. + private func readyBodyForRetry(request: RequestType) throws -> Bool { + switch request.body { + case .stream(let stream): + guard stream.isSeekable else { return false } + do { + try stream.seek(toOffset: 0) + return true + } catch { + return false + } + case .data, .noStream: + return true + } + } + private func attempt(context: InterceptorContextType, attemptCount: Int) async { // If anything in here fails, the attempt short-circuits and we go to modifyBeforeAttemptCompletion, // with the thrown error in context.result diff --git a/Tests/ClientRuntimeTests/OrchestratorTests/OrchestratorTests.swift b/Tests/ClientRuntimeTests/OrchestratorTests/OrchestratorTests.swift index 2e63544b6..084519220 100644 --- a/Tests/ClientRuntimeTests/OrchestratorTests/OrchestratorTests.swift +++ b/Tests/ClientRuntimeTests/OrchestratorTests/OrchestratorTests.swift @@ -15,6 +15,7 @@ import SmithyRetriesAPI import SmithyRetries @_spi(SmithyReadWrite) import SmithyJSON @_spi(SmithyReadWrite) import SmithyReadWrite +import SmithyStreams class OrchestratorTests: XCTestCase { struct TestInput { @@ -167,9 +168,11 @@ class OrchestratorTests: XCTestCase { } class TraceExecuteRequest: ExecuteRequest { - var succeedAfter: Int + let succeedAfter: Int var trace: Trace + private(set) var requestCount = 0 + init(succeedAfter: Int = 0, trace: Trace) { self.succeedAfter = succeedAfter self.trace = trace @@ -177,10 +180,11 @@ class OrchestratorTests: XCTestCase { public func execute(request: HTTPRequest, attributes: Context) async throws -> HTTPResponse { trace.append("executeRequest") - if succeedAfter <= 0 { + if succeedAfter - requestCount <= 0 { + requestCount += 1 return HTTPResponse(body: request.body, statusCode: .ok) } else { - succeedAfter -= 1 + requestCount += 1 return HTTPResponse(body: request.body, statusCode: .internalServerError) } } @@ -233,7 +237,7 @@ class OrchestratorTests: XCTestCase { throw try UnknownHTTPServiceError.makeError(baseError: baseError) } }) - .retryStrategy(DefaultRetryStrategy(options: RetryStrategyOptions(backoffStrategy: ExponentialBackoffStrategy()))) + .retryStrategy(DefaultRetryStrategy(options: RetryStrategyOptions(backoffStrategy: ImmediateBackoffStrategy()))) .retryErrorInfoProvider({ e in trace.append("errorInfo") return DefaultRetryErrorInfoProvider.errorInfo(for: e) @@ -530,7 +534,7 @@ class OrchestratorTests: XCTestCase { let initialTokenTrace = Trace() let initialToken = await asyncResult { return try await self.traceOrchestrator(trace: initialTokenTrace) - .retryStrategy(ThrowingRetryStrategy(options: RetryStrategyOptions(backoffStrategy: ExponentialBackoffStrategy()))) + .retryStrategy(ThrowingRetryStrategy(options: RetryStrategyOptions(backoffStrategy: ImmediateBackoffStrategy()))) .build() .execute(input: TestInput(foo: "")) } @@ -1315,4 +1319,60 @@ class OrchestratorTests: XCTestCase { } } } + + /// Used in retry tests to perform the next retry without waiting, so that tests complete without delay. + private struct ImmediateBackoffStrategy: RetryBackoffStrategy { + func computeNextBackoffDelay(attempt: Int) -> TimeInterval { 0.0 } + } + + func test_retry_retriesDataBody() async throws { + let input = TestInput(foo: "bar") + let trace = Trace() + let executeRequest = TraceExecuteRequest(succeedAfter: 2, trace: trace) + let orchestrator = traceOrchestrator(trace: trace) + .retryStrategy(DefaultRetryStrategy(options: RetryStrategyOptions(backoffStrategy: ImmediateBackoffStrategy()))) + .serialize({ (input: TestInput, builder: HTTPRequestBuilder, context) in + builder.withBody(.data(Data("\"\(input.foo)\"".utf8))) + }) + .executeRequest(executeRequest) + let result = await asyncResult { + return try await orchestrator.build().execute(input: input) + } + XCTAssertNoThrow(try result.get()) + XCTAssertEqual(executeRequest.requestCount, 3) + } + + func test_retry_doesntRetryNonSeekableStreamBody() async throws { + let input = TestInput(foo: "bar") + let trace = Trace() + let executeRequest = TraceExecuteRequest(succeedAfter: 2, trace: trace) + let orchestrator = traceOrchestrator(trace: trace) + .retryStrategy(DefaultRetryStrategy(options: RetryStrategyOptions(backoffStrategy: ImmediateBackoffStrategy()))) + .serialize({ (input: TestInput, builder: HTTPRequestBuilder, context) in + builder.withBody(.stream(BufferedStream(data: Data("\"\(input.foo)\"".utf8), isClosed: true))) + }) + .executeRequest(executeRequest) + let result = await asyncResult { + return try await orchestrator.build().execute(input: input) + } + XCTAssertThrowsError(try result.get()) + XCTAssertEqual(executeRequest.requestCount, 1) + } + + func test_retry_nonSeekableStreamBodySucceeds() async throws { + let input = TestInput(foo: "bar") + let trace = Trace() + let executeRequest = TraceExecuteRequest(succeedAfter: 0, trace: trace) + let orchestrator = traceOrchestrator(trace: trace) + .retryStrategy(DefaultRetryStrategy(options: RetryStrategyOptions(backoffStrategy: ImmediateBackoffStrategy()))) + .serialize({ (input: TestInput, builder: HTTPRequestBuilder, context) in + builder.withBody(.stream(BufferedStream(data: Data("\"\(input.foo)\"".utf8), isClosed: true))) + }) + .executeRequest(executeRequest) + let result = await asyncResult { + return try await orchestrator.build().execute(input: input) + } + XCTAssertNoThrow(try result.get()) + XCTAssertEqual(executeRequest.requestCount, 1) + } }