Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve testing of custom HTTP client #56

Merged
merged 1 commit into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,32 @@
import static org.assertj.core.api.Assertions.assertThat;

import com.google.common.collect.ImmutableMap;
import java.util.Collection;
import java.util.Map;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.stream.Stream;
import lombok.Builder;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.Value;
import org.apache.kafka.common.header.Header;
import org.apache.kafka.common.header.Headers;
import org.apache.kafka.common.header.internals.RecordHeaders;
import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.common.serialization.Serializer;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import software.amazon.awssdk.core.internal.http.loader.DefaultSdkHttpClientBuilder;
import software.amazon.awssdk.http.ExecutableHttpRequest;
import software.amazon.awssdk.http.HttpExecuteRequest;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.SdkHttpMethod;
import software.amazon.awssdk.http.SdkHttpRequest;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.CreateBucketRequest;
import software.amazon.awssdk.utils.AttributeMap;

class AmazonS3LargeMessageClientRoundtripTest extends AmazonS3IntegrationTest {

Expand Down Expand Up @@ -94,6 +106,42 @@ void shouldRoundtrip(final RoundtripArgument argument) {
}
}

@Test
void shouldUseConfiguredSdkHttpClientBuilder() {
final String bucket = "bucket";
final String basePath = "s3://" + bucket + "/base/";
final Map<String, Object> properties = ImmutableMap.<String, Object>builder()
.put(AbstractLargeMessageConfig.MAX_BYTE_SIZE_CONFIG, 0)
.put(AbstractLargeMessageConfig.BASE_PATH_CONFIG, basePath)
.put(AbstractLargeMessageConfig.S3_SDK_HTTP_CLIENT_BUILDER_CONFIG, RecordingHttpClientBuilder.class)
.build();
final S3Client s3 = this.getS3Client();
s3.createBucket(CreateBucketRequest.builder().bucket(bucket).build());
final Map<String, Object> fullProperties = this.createStorerProperties(properties);
final AbstractLargeMessageConfig config = new AbstractLargeMessageConfig(fullProperties);
try (final LargeMessageStoringClient storer = config.getStorer();
final LargeMessageRetrievingClient retriever = config.getRetriever()) {

final Headers headers = new RecordHeaders();
final byte[] obj = serialize("foo");
final boolean isKey = false;
final byte[] data = storer.storeBytes(TOPIC, obj, isKey, headers);

final byte[] result = retriever.retrieveBytes(data, headers, isKey);
assertThat(result).isEqualTo(obj);
assertThat(RecordingHttpClient.REQUESTS)
.hasSize(2)
.anySatisfy(request -> {
assertThat(request.method()).isEqualTo(SdkHttpMethod.PUT);
assertThat(request.encodedPath()).startsWith("/" + bucket + "/base/" + TOPIC + "/values/");
})
.anySatisfy(request -> {
assertThat(request.method()).isEqualTo(SdkHttpMethod.GET);
assertThat(request.encodedPath()).startsWith("/" + bucket + "/base/" + TOPIC + "/values/");
});
}
}

private Map<String, Object> createStorerProperties(final Map<String, Object> properties) {
return ImmutableMap.<String, Object>builder()
.putAll(properties)
Expand All @@ -120,4 +168,36 @@ static class RoundtripArgument {
boolean isPathStyleAccess;
String compressionType;
}

@RequiredArgsConstructor
public static class RecordingHttpClientBuilder<T extends SdkHttpClient.Builder<T>>
implements SdkHttpClient.Builder<T> {

@Override
public SdkHttpClient buildWithDefaults(final AttributeMap attributeMap) {
return new RecordingHttpClient(new DefaultSdkHttpClientBuilder().buildWithDefaults(attributeMap));
}
}

@RequiredArgsConstructor
private static class RecordingHttpClient implements SdkHttpClient {
private static final Collection<SdkHttpRequest> REQUESTS = new ConcurrentLinkedQueue<>();
private final @NonNull SdkHttpClient wrapped;

@Override
public ExecutableHttpRequest prepareRequest(final HttpExecuteRequest httpExecuteRequest) {
REQUESTS.add(httpExecuteRequest.httpRequest());
return this.wrapped.prepareRequest(httpExecuteRequest);
}

@Override
public String clientName() {
return "MockSdkHttpClient";
}

@Override
public void close() {
this.wrapped.close();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,26 +26,16 @@

import static com.bakdata.kafka.LargeMessageRetrievingClientTest.serializeUri;
import static org.assertj.core.api.Assertions.assertThat;
import static software.amazon.awssdk.core.client.config.SdkClientOption.CONFIGURED_SYNC_HTTP_CLIENT_BUILDER;
import static software.amazon.awssdk.core.client.config.SdkClientOption.SYNC_HTTP_CLIENT;

import com.google.common.collect.ImmutableMap;
import java.lang.reflect.Field;
import java.util.Map;
import java.util.function.Supplier;
import org.apache.kafka.common.header.internals.RecordHeaders;
import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.common.serialization.Serializer;
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.core.client.config.SdkClientConfiguration;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.http.ExecutableHttpRequest;
import software.amazon.awssdk.http.HttpExecuteRequest;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.CreateBucketRequest;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
import software.amazon.awssdk.utils.AttributeMap;

class LargeMessageRetrievingClientS3IntegrationTest extends AmazonS3IntegrationTest {

Expand All @@ -69,31 +59,6 @@ void shouldReadBackedText() {
}
}

@Test
void shouldUseConfiguredSdkHttpClientBuilder() {
final String bucket = "bucket";
final String basePath = "s3://" + bucket + "/base/";
final Map<String, Object> properties = ImmutableMap.<String, Object>builder()
.put(AbstractLargeMessageConfig.S3_REGION_CONFIG, "us-east-1")
.put(AbstractLargeMessageConfig.MAX_BYTE_SIZE_CONFIG, 0)
.put(AbstractLargeMessageConfig.BASE_PATH_CONFIG, basePath)
.put(AbstractLargeMessageConfig.S3_SDK_HTTP_CLIENT_BUILDER_CONFIG, MockSdkHttpClientBuilder.class.getName())
.build();
AbstractLargeMessageConfig config = new AbstractLargeMessageConfig(properties);
LargeMessageRetrievingClient retriever = config.getRetriever();
// Get private field clientFactories
Map<String, Supplier<BlobStorageClient>> clientFactories = getPrivateField(retriever, "clientFactories", Map.class);
BlobStorageClient blobStorageClient = clientFactories.get("s3").get();
// Get private field s3Client
S3Client s3Client = getPrivateField(blobStorageClient, "s3", S3Client.class);
// Get private field clientConfiguration
SdkClientConfiguration clientConfiguration = getPrivateField(s3Client, "clientConfiguration", SdkClientConfiguration.class);
// Get private field attributes
AttributeMap attributes = getPrivateField(clientConfiguration, "attributes", AttributeMap.class);
assertThat(attributes.get(SYNC_HTTP_CLIENT)).isExactlyInstanceOf(MockSdkHttpClient.class);
assertThat(attributes.get(CONFIGURED_SYNC_HTTP_CLIENT_BUILDER)).isExactlyInstanceOf(MockSdkHttpClientBuilder.class);
}

private LargeMessageRetrievingClient createRetriever() {
final Map<String, String> properties = this.getLargeMessageConfig();
final AbstractLargeMessageConfig config = new AbstractLargeMessageConfig(properties);
Expand All @@ -106,37 +71,4 @@ private void store(final String bucket, final String key, final String s) {
.key(key)
.build(), RequestBody.fromString(s));
}

private static <T> T getPrivateField(Object object, String fieldName, Class<T> fieldType) {
try {
Field field = object.getClass().getDeclaredField(fieldName);
field.setAccessible(true);
return fieldType.cast(field.get(object));
} catch (NoSuchFieldException | IllegalAccessException e) {
throw new RuntimeException(e);
}
}

public static class MockSdkHttpClientBuilder implements SdkHttpClient.Builder {
@Override
public SdkHttpClient buildWithDefaults(AttributeMap attributeMap) {
return new MockSdkHttpClient();
}
}

private static class MockSdkHttpClient implements SdkHttpClient {
@Override
public ExecutableHttpRequest prepareRequest(HttpExecuteRequest httpExecuteRequest) {
return null;
}

public String clientName() {
return "MockSdkHttpClient";
}

@Override
public void close() {

}
}
}