From d100de3cd51a13a1c7d03eb66043f0b394130942 Mon Sep 17 00:00:00 2001 From: Marcin Antas Date: Tue, 14 Nov 2023 13:28:25 +0100 Subject: [PATCH] Add gRPC authorization suppport --- .../io/weaviate/client/WeaviateClient.java | 16 +++- .../weaviate/client/base/grpc/GrpcClient.java | 6 +- .../io/weaviate/client/v1/batch/Batch.java | 10 ++- .../client/v1/batch/api/ObjectsBatcher.java | 17 ++-- .../client/auth/AuthWCSUsersApiKeyTest.java | 81 +++++++++++++++++-- 5 files changed, 109 insertions(+), 21 deletions(-) diff --git a/src/main/java/io/weaviate/client/WeaviateClient.java b/src/main/java/io/weaviate/client/WeaviateClient.java index c3efb8a7..f4b7356b 100644 --- a/src/main/java/io/weaviate/client/WeaviateClient.java +++ b/src/main/java/io/weaviate/client/WeaviateClient.java @@ -1,10 +1,12 @@ package io.weaviate.client; +import io.weaviate.client.base.grpc.GrpcClient; import io.weaviate.client.base.http.HttpClient; import io.weaviate.client.base.http.builder.HttpApacheClientBuilder; import io.weaviate.client.base.http.impl.CommonsHttpClientImpl; import io.weaviate.client.base.util.DbVersionProvider; import io.weaviate.client.base.util.DbVersionSupport; +import io.weaviate.client.grpc.protocol.v1.WeaviateGrpc; import io.weaviate.client.v1.auth.provider.AccessTokenProvider; import io.weaviate.client.v1.backup.Backup; import io.weaviate.client.v1.batch.Batch; @@ -25,20 +27,26 @@ public class WeaviateClient { private final DbVersionProvider dbVersionProvider; private final DbVersionSupport dbVersionSupport; private final HttpClient httpClient; + private final WeaviateGrpc.WeaviateBlockingStub grpcClient; public WeaviateClient(Config config) { - this(config, new CommonsHttpClientImpl(config.getHeaders(), null, HttpApacheClientBuilder.build(config))); + this(config, new CommonsHttpClientImpl(config.getHeaders(), null, HttpApacheClientBuilder.build(config)), null); } public WeaviateClient(Config config, AccessTokenProvider tokenProvider) { - this(config, new CommonsHttpClientImpl(config.getHeaders(), tokenProvider, HttpApacheClientBuilder.build(config))); + this(config, new CommonsHttpClientImpl(config.getHeaders(), tokenProvider, HttpApacheClientBuilder.build(config)), tokenProvider); } - public WeaviateClient(Config config, HttpClient httpClient) { + public WeaviateClient(Config config, HttpClient httpClient, AccessTokenProvider tokenProvider) { this.config = config; this.httpClient = httpClient; dbVersionProvider = initDbVersionProvider(); dbVersionSupport = new DbVersionSupport(dbVersionProvider); + if (this.config.useGRPC()) { + this.grpcClient = GrpcClient.create(config, tokenProvider); + } else { + this.grpcClient = null; + } } public Misc misc() { @@ -56,7 +64,7 @@ public Data data() { public Batch batch() { dbVersionProvider.refresh(); - return new Batch(httpClient, config, dbVersionSupport, data()); + return new Batch(httpClient, config, dbVersionSupport, grpcClient, data()); } public Backup backup() { diff --git a/src/main/java/io/weaviate/client/base/grpc/GrpcClient.java b/src/main/java/io/weaviate/client/base/grpc/GrpcClient.java index 97f2fd81..8ef41f64 100644 --- a/src/main/java/io/weaviate/client/base/grpc/GrpcClient.java +++ b/src/main/java/io/weaviate/client/base/grpc/GrpcClient.java @@ -6,17 +6,21 @@ import io.grpc.stub.MetadataUtils; import io.weaviate.client.Config; import io.weaviate.client.grpc.protocol.v1.WeaviateGrpc; +import io.weaviate.client.v1.auth.provider.AccessTokenProvider; import java.util.Map; public class GrpcClient { - public static WeaviateGrpc.WeaviateBlockingStub create(Config config) { + public static WeaviateGrpc.WeaviateBlockingStub create(Config config, AccessTokenProvider tokenProvider) { Metadata headers = new Metadata(); if (config.getHeaders() != null) { for (Map.Entry e : config.getHeaders().entrySet()) { headers.put(Metadata.Key.of(e.getKey(), Metadata.ASCII_STRING_MARSHALLER), e.getValue()); } } + if (tokenProvider != null) { + headers.put(Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER), String.format("Bearer %s", tokenProvider.getAccessToken())); + } ManagedChannelBuilder channelBuilder = ManagedChannelBuilder.forTarget(getAddress(config)); if (config.isGRPCSecured()) { channelBuilder = channelBuilder.useTransportSecurity(); diff --git a/src/main/java/io/weaviate/client/v1/batch/Batch.java b/src/main/java/io/weaviate/client/v1/batch/Batch.java index ce5d6f02..246bf885 100644 --- a/src/main/java/io/weaviate/client/v1/batch/Batch.java +++ b/src/main/java/io/weaviate/client/v1/batch/Batch.java @@ -3,6 +3,7 @@ import io.weaviate.client.base.http.HttpClient; import io.weaviate.client.base.util.BeaconPath; import io.weaviate.client.base.util.DbVersionSupport; +import io.weaviate.client.grpc.protocol.v1.WeaviateGrpc; import io.weaviate.client.v1.batch.api.ObjectsBatchDeleter; import io.weaviate.client.v1.batch.api.ObjectsBatcher; import io.weaviate.client.v1.batch.api.ReferencePayloadBuilder; @@ -15,14 +16,17 @@ public class Batch { private final Config config; private final HttpClient httpClient; + private final WeaviateGrpc.WeaviateBlockingStub grpcClient; private final BeaconPath beaconPath; private final ObjectsPath objectsPath; private final ReferencesPath referencesPath; private final Data data; - public Batch(HttpClient httpClient, Config config, DbVersionSupport dbVersionSupport, Data data) { + public Batch(HttpClient httpClient, Config config, DbVersionSupport dbVersionSupport, + WeaviateGrpc.WeaviateBlockingStub grpcClient, Data data) { this.config = config; this.httpClient = httpClient; + this.grpcClient = grpcClient; this.beaconPath = new BeaconPath(dbVersionSupport); this.objectsPath = new ObjectsPath(); this.referencesPath = new ReferencesPath(); @@ -34,7 +38,7 @@ public ObjectsBatcher objectsBatcher() { } public ObjectsBatcher objectsBatcher(ObjectsBatcher.BatchRetriesConfig batchRetriesConfig) { - return ObjectsBatcher.create(httpClient, config, data, objectsPath, batchRetriesConfig); + return ObjectsBatcher.create(httpClient, config, data, objectsPath, grpcClient, batchRetriesConfig); } public ObjectsBatcher objectsAutoBatcher() { @@ -60,7 +64,7 @@ public ObjectsBatcher objectsAutoBatcher(ObjectsBatcher.AutoBatchConfig autoBatc public ObjectsBatcher objectsAutoBatcher(ObjectsBatcher.BatchRetriesConfig batchRetriesConfig, ObjectsBatcher.AutoBatchConfig autoBatchConfig) { - return ObjectsBatcher.createAuto(httpClient, config, data, objectsPath, batchRetriesConfig, autoBatchConfig); + return ObjectsBatcher.createAuto(httpClient, config, data, objectsPath, grpcClient, batchRetriesConfig, autoBatchConfig); } public ObjectsBatchDeleter objectsBatchDeleter() { diff --git a/src/main/java/io/weaviate/client/v1/batch/api/ObjectsBatcher.java b/src/main/java/io/weaviate/client/v1/batch/api/ObjectsBatcher.java index d7b90fb3..907edd33 100644 --- a/src/main/java/io/weaviate/client/v1/batch/api/ObjectsBatcher.java +++ b/src/main/java/io/weaviate/client/v1/batch/api/ObjectsBatcher.java @@ -68,17 +68,16 @@ public class ObjectsBatcher extends BaseClient private final List objects; private String consistencyLevel; private final List>> undoneFutures; - private WeaviateGrpc.WeaviateBlockingStub grpcClient; - private boolean useGRPC; + private final WeaviateGrpc.WeaviateBlockingStub grpcClient; + private final boolean useGRPC; private ObjectsBatcher(HttpClient httpClient, Config config, Data data, ObjectsPath objectsPath, + WeaviateGrpc.WeaviateBlockingStub grpcClient, BatchRetriesConfig batchRetriesConfig, AutoBatchConfig autoBatchConfig) { super(httpClient, config); - if (config.useGRPC()) { - this.useGRPC = config.useGRPC(); - this.grpcClient = GrpcClient.create(config); - } + this.useGRPC = config.useGRPC(); + this.grpcClient = grpcClient; this.data = data; this.objectsPath = objectsPath; this.objects = new ArrayList<>(); @@ -100,16 +99,18 @@ private ObjectsBatcher(HttpClient httpClient, Config config, Data data, ObjectsP } public static ObjectsBatcher create(HttpClient httpClient, Config config, Data data, ObjectsPath objectsPath, + WeaviateGrpc.WeaviateBlockingStub grpcClient, BatchRetriesConfig batchRetriesConfig) { Assert.requiredNotNull(batchRetriesConfig, "batchRetriesConfig"); - return new ObjectsBatcher(httpClient, config, data, objectsPath, batchRetriesConfig, null); + return new ObjectsBatcher(httpClient, config, data, objectsPath, grpcClient, batchRetriesConfig, null); } public static ObjectsBatcher createAuto(HttpClient httpClient, Config config, Data data, ObjectsPath objectsPath, + WeaviateGrpc.WeaviateBlockingStub grpcClient, BatchRetriesConfig batchRetriesConfig, AutoBatchConfig autoBatchConfig) { Assert.requiredNotNull(batchRetriesConfig, "batchRetriesConfig"); Assert.requiredNotNull(autoBatchConfig, "autoBatchConfig"); - return new ObjectsBatcher(httpClient, config, data, objectsPath, batchRetriesConfig, autoBatchConfig); + return new ObjectsBatcher(httpClient, config, data, objectsPath, grpcClient, batchRetriesConfig, autoBatchConfig); } diff --git a/src/test/java/io/weaviate/integration/client/auth/AuthWCSUsersApiKeyTest.java b/src/test/java/io/weaviate/integration/client/auth/AuthWCSUsersApiKeyTest.java index d210402c..90c30f2e 100644 --- a/src/test/java/io/weaviate/integration/client/auth/AuthWCSUsersApiKeyTest.java +++ b/src/test/java/io/weaviate/integration/client/auth/AuthWCSUsersApiKeyTest.java @@ -6,7 +6,19 @@ import io.weaviate.client.base.Result; import io.weaviate.client.base.WeaviateError; import io.weaviate.client.v1.auth.exception.AuthException; +import io.weaviate.client.v1.batch.model.ObjectGetResponse; +import io.weaviate.client.v1.data.model.WeaviateObject; import io.weaviate.client.v1.misc.model.Meta; +import io.weaviate.client.v1.schema.model.DataType; +import io.weaviate.client.v1.schema.model.Property; +import io.weaviate.client.v1.schema.model.WeaviateClass; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import static org.assertj.core.api.InstanceOfAssertFactories.ARRAY; import org.junit.Before; import org.junit.ClassRule; import org.junit.Test; @@ -17,27 +29,34 @@ import static io.weaviate.integration.client.WeaviateVersion.EXPECTED_WEAVIATE_VERSION; import static org.assertj.core.api.Assertions.assertThat; +import org.testcontainers.shaded.com.fasterxml.jackson.databind.annotation.JsonAppend; public class AuthWCSUsersApiKeyTest { - private Config config; + private static String host; + private static Integer port; + private static String grpcHost; + private static Integer grpcPort; private static final String API_KEY = "my-secret-key"; private static final String INVALID_API_KEY = "my-not-so-secret-key"; @ClassRule public static DockerComposeContainer compose = new DockerComposeContainer( new File("src/test/resources/docker-compose-wcs.yaml") - ).withExposedService("weaviate-auth-wcs_1", 8085, Wait.forHttp("/v1/.well-known/ready").forStatusCode(200)); + ).withExposedService("weaviate-auth-wcs_1", 8085, Wait.forListeningPorts(8085)) + .withExposedService("weaviate-auth-wcs_1", 50051, Wait.forListeningPorts(50051)); @Before public void before() { - String host = compose.getServiceHost("weaviate-auth-wcs_1", 8085); - Integer port = compose.getServicePort("weaviate-auth-wcs_1", 8085); - config = new Config("http", host + ":" + port); + host = compose.getServiceHost("weaviate-auth-wcs_1", 8085); + port = compose.getServicePort("weaviate-auth-wcs_1", 8085); + grpcHost = compose.getServiceHost("weaviate-auth-wcs_1", 50051); + grpcPort = compose.getServicePort("weaviate-auth-wcs_1", 50051); } @Test public void shouldAuthenticateWithValidApiKey() throws AuthException { + Config config = new Config("http", host + ":" + port); WeaviateClient client = WeaviateAuthClient.apiKey(config, API_KEY); Result meta = client.misc().metaGetter().run(); @@ -50,6 +69,7 @@ public void shouldAuthenticateWithValidApiKey() throws AuthException { @Test public void shouldNotAuthenticateWithInvalidApiKey() throws AuthException { + Config config = new Config("http", host + ":" + port); WeaviateClient client = WeaviateAuthClient.apiKey(config, INVALID_API_KEY); Result meta = client.misc().metaGetter().run(); @@ -59,4 +79,55 @@ public void shouldNotAuthenticateWithInvalidApiKey() throws AuthException { .extracting(Result::getError) .returns(401, WeaviateError::getStatusCode); } + + @Test + public void shouldAuthenticateWithValidApiKeyUsingGRPC() throws AuthException { + Config config = new Config("http", host + ":" + port); + config.setGRPCHost(grpcHost + ":" + grpcPort); + WeaviateClient client = WeaviateAuthClient.apiKey(config, API_KEY); + + Result deleteAll = client.schema().allDeleter().run(); + assertThat(deleteAll).isNotNull() + .returns(false, Result::hasErrors) + .extracting(Result::getResult).isEqualTo(Boolean.TRUE); + + String id = "00000000-0000-0000-0000-000000000001"; + String className = "TestGRPC"; + String propertyName = "name"; + List properties = new ArrayList<>(); + properties.add(Property.builder().name("name").dataType(Collections.singletonList(DataType.TEXT)).build()); + WeaviateClass clazz = WeaviateClass.builder().className(className).properties(properties).build(); + Result createClass = client.schema().classCreator().withClass(clazz).run(); + + assertThat(createClass).isNotNull() + .returns(false, Result::hasErrors) + .returns(true, Result::getResult); + + Map props = new HashMap<>(); + props.put("name", "John Doe"); + + WeaviateObject obj = WeaviateObject.builder().id(id).className(className).properties(props).build(); + + Result result = client.batch().objectsBatcher() + .withObjects(obj) + .run(); + assertThat(result).isNotNull() + .returns(false, Result::hasErrors) + .extracting(Result::getResult).asInstanceOf(ARRAY) + .hasSize(1); + + Result> resultObj = client.data().objectsGetter().withClassName(className).withID(id).run(); + assertThat(resultObj).isNotNull() + .returns(false, Result::hasErrors) + .extracting(Result::getResult).isNotNull() + .extracting(r -> r.get(0)).isNotNull() + .satisfies(o -> { + assertThat(o.getId()).isEqualTo(obj.getId()); + assertThat(o.getProperties()).isNotNull() + .extracting(Map::size).isEqualTo(obj.getProperties().size()); + assertThat(o.getProperties()).isNotEmpty().satisfies(p -> { + assertThat(p.get(propertyName)).isNotNull(); + }); + }); + } }