Skip to content

Commit

Permalink
Merge pull request #237 from weaviate/grpc-auth-test
Browse files Browse the repository at this point in the history
Add gRPC authorization suppport
  • Loading branch information
antas-marcin authored Nov 14, 2023
2 parents 5d266f2 + d100de3 commit 79e1ec3
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 21 deletions.
16 changes: 12 additions & 4 deletions src/main/java/io/weaviate/client/WeaviateClient.java
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package io.weaviate.client;

import io.weaviate.client.base.grpc.GrpcClient;
import io.weaviate.client.base.http.HttpClient;
import io.weaviate.client.base.http.builder.HttpApacheClientBuilder;
import io.weaviate.client.base.http.impl.CommonsHttpClientImpl;
import io.weaviate.client.base.util.DbVersionProvider;
import io.weaviate.client.base.util.DbVersionSupport;
import io.weaviate.client.grpc.protocol.v1.WeaviateGrpc;
import io.weaviate.client.v1.auth.provider.AccessTokenProvider;
import io.weaviate.client.v1.backup.Backup;
import io.weaviate.client.v1.batch.Batch;
Expand All @@ -25,20 +27,26 @@ public class WeaviateClient {
private final DbVersionProvider dbVersionProvider;
private final DbVersionSupport dbVersionSupport;
private final HttpClient httpClient;
private final WeaviateGrpc.WeaviateBlockingStub grpcClient;

public WeaviateClient(Config config) {
this(config, new CommonsHttpClientImpl(config.getHeaders(), null, HttpApacheClientBuilder.build(config)));
this(config, new CommonsHttpClientImpl(config.getHeaders(), null, HttpApacheClientBuilder.build(config)), null);
}

public WeaviateClient(Config config, AccessTokenProvider tokenProvider) {
this(config, new CommonsHttpClientImpl(config.getHeaders(), tokenProvider, HttpApacheClientBuilder.build(config)));
this(config, new CommonsHttpClientImpl(config.getHeaders(), tokenProvider, HttpApacheClientBuilder.build(config)), tokenProvider);
}

public WeaviateClient(Config config, HttpClient httpClient) {
public WeaviateClient(Config config, HttpClient httpClient, AccessTokenProvider tokenProvider) {
this.config = config;
this.httpClient = httpClient;
dbVersionProvider = initDbVersionProvider();
dbVersionSupport = new DbVersionSupport(dbVersionProvider);
if (this.config.useGRPC()) {
this.grpcClient = GrpcClient.create(config, tokenProvider);
} else {
this.grpcClient = null;
}
}

public Misc misc() {
Expand All @@ -56,7 +64,7 @@ public Data data() {

public Batch batch() {
dbVersionProvider.refresh();
return new Batch(httpClient, config, dbVersionSupport, data());
return new Batch(httpClient, config, dbVersionSupport, grpcClient, data());
}

public Backup backup() {
Expand Down
6 changes: 5 additions & 1 deletion src/main/java/io/weaviate/client/base/grpc/GrpcClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,21 @@
import io.grpc.stub.MetadataUtils;
import io.weaviate.client.Config;
import io.weaviate.client.grpc.protocol.v1.WeaviateGrpc;
import io.weaviate.client.v1.auth.provider.AccessTokenProvider;
import java.util.Map;

public class GrpcClient {

public static WeaviateGrpc.WeaviateBlockingStub create(Config config) {
public static WeaviateGrpc.WeaviateBlockingStub create(Config config, AccessTokenProvider tokenProvider) {
Metadata headers = new Metadata();
if (config.getHeaders() != null) {
for (Map.Entry<String, String> e : config.getHeaders().entrySet()) {
headers.put(Metadata.Key.of(e.getKey(), Metadata.ASCII_STRING_MARSHALLER), e.getValue());
}
}
if (tokenProvider != null) {
headers.put(Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER), String.format("Bearer %s", tokenProvider.getAccessToken()));
}
ManagedChannelBuilder<?> channelBuilder = ManagedChannelBuilder.forTarget(getAddress(config));
if (config.isGRPCSecured()) {
channelBuilder = channelBuilder.useTransportSecurity();
Expand Down
10 changes: 7 additions & 3 deletions src/main/java/io/weaviate/client/v1/batch/Batch.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import io.weaviate.client.base.http.HttpClient;
import io.weaviate.client.base.util.BeaconPath;
import io.weaviate.client.base.util.DbVersionSupport;
import io.weaviate.client.grpc.protocol.v1.WeaviateGrpc;
import io.weaviate.client.v1.batch.api.ObjectsBatchDeleter;
import io.weaviate.client.v1.batch.api.ObjectsBatcher;
import io.weaviate.client.v1.batch.api.ReferencePayloadBuilder;
Expand All @@ -15,14 +16,17 @@
public class Batch {
private final Config config;
private final HttpClient httpClient;
private final WeaviateGrpc.WeaviateBlockingStub grpcClient;
private final BeaconPath beaconPath;
private final ObjectsPath objectsPath;
private final ReferencesPath referencesPath;
private final Data data;

public Batch(HttpClient httpClient, Config config, DbVersionSupport dbVersionSupport, Data data) {
public Batch(HttpClient httpClient, Config config, DbVersionSupport dbVersionSupport,
WeaviateGrpc.WeaviateBlockingStub grpcClient, Data data) {
this.config = config;
this.httpClient = httpClient;
this.grpcClient = grpcClient;
this.beaconPath = new BeaconPath(dbVersionSupport);
this.objectsPath = new ObjectsPath();
this.referencesPath = new ReferencesPath();
Expand All @@ -34,7 +38,7 @@ public ObjectsBatcher objectsBatcher() {
}

public ObjectsBatcher objectsBatcher(ObjectsBatcher.BatchRetriesConfig batchRetriesConfig) {
return ObjectsBatcher.create(httpClient, config, data, objectsPath, batchRetriesConfig);
return ObjectsBatcher.create(httpClient, config, data, objectsPath, grpcClient, batchRetriesConfig);
}

public ObjectsBatcher objectsAutoBatcher() {
Expand All @@ -60,7 +64,7 @@ public ObjectsBatcher objectsAutoBatcher(ObjectsBatcher.AutoBatchConfig autoBatc

public ObjectsBatcher objectsAutoBatcher(ObjectsBatcher.BatchRetriesConfig batchRetriesConfig,
ObjectsBatcher.AutoBatchConfig autoBatchConfig) {
return ObjectsBatcher.createAuto(httpClient, config, data, objectsPath, batchRetriesConfig, autoBatchConfig);
return ObjectsBatcher.createAuto(httpClient, config, data, objectsPath, grpcClient, batchRetriesConfig, autoBatchConfig);
}

public ObjectsBatchDeleter objectsBatchDeleter() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,16 @@ public class ObjectsBatcher extends BaseClient<ObjectGetResponse[]>
private final List<WeaviateObject> objects;
private String consistencyLevel;
private final List<CompletableFuture<Result<ObjectGetResponse[]>>> undoneFutures;
private WeaviateGrpc.WeaviateBlockingStub grpcClient;
private boolean useGRPC;
private final WeaviateGrpc.WeaviateBlockingStub grpcClient;
private final boolean useGRPC;


private ObjectsBatcher(HttpClient httpClient, Config config, Data data, ObjectsPath objectsPath,
WeaviateGrpc.WeaviateBlockingStub grpcClient,
BatchRetriesConfig batchRetriesConfig, AutoBatchConfig autoBatchConfig) {
super(httpClient, config);
if (config.useGRPC()) {
this.useGRPC = config.useGRPC();
this.grpcClient = GrpcClient.create(config);
}
this.useGRPC = config.useGRPC();
this.grpcClient = grpcClient;
this.data = data;
this.objectsPath = objectsPath;
this.objects = new ArrayList<>();
Expand All @@ -100,16 +99,18 @@ private ObjectsBatcher(HttpClient httpClient, Config config, Data data, ObjectsP
}

public static ObjectsBatcher create(HttpClient httpClient, Config config, Data data, ObjectsPath objectsPath,
WeaviateGrpc.WeaviateBlockingStub grpcClient,
BatchRetriesConfig batchRetriesConfig) {
Assert.requiredNotNull(batchRetriesConfig, "batchRetriesConfig");
return new ObjectsBatcher(httpClient, config, data, objectsPath, batchRetriesConfig, null);
return new ObjectsBatcher(httpClient, config, data, objectsPath, grpcClient, batchRetriesConfig, null);
}

public static ObjectsBatcher createAuto(HttpClient httpClient, Config config, Data data, ObjectsPath objectsPath,
WeaviateGrpc.WeaviateBlockingStub grpcClient,
BatchRetriesConfig batchRetriesConfig, AutoBatchConfig autoBatchConfig) {
Assert.requiredNotNull(batchRetriesConfig, "batchRetriesConfig");
Assert.requiredNotNull(autoBatchConfig, "autoBatchConfig");
return new ObjectsBatcher(httpClient, config, data, objectsPath, batchRetriesConfig, autoBatchConfig);
return new ObjectsBatcher(httpClient, config, data, objectsPath, grpcClient, batchRetriesConfig, autoBatchConfig);
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,19 @@
import io.weaviate.client.base.Result;
import io.weaviate.client.base.WeaviateError;
import io.weaviate.client.v1.auth.exception.AuthException;
import io.weaviate.client.v1.batch.model.ObjectGetResponse;
import io.weaviate.client.v1.data.model.WeaviateObject;
import io.weaviate.client.v1.misc.model.Meta;
import io.weaviate.client.v1.schema.model.DataType;
import io.weaviate.client.v1.schema.model.Property;
import io.weaviate.client.v1.schema.model.WeaviateClass;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.assertj.core.api.InstanceOfAssertFactories.ARRAY;
import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Test;
Expand All @@ -17,27 +29,34 @@

import static io.weaviate.integration.client.WeaviateVersion.EXPECTED_WEAVIATE_VERSION;
import static org.assertj.core.api.Assertions.assertThat;
import org.testcontainers.shaded.com.fasterxml.jackson.databind.annotation.JsonAppend;

public class AuthWCSUsersApiKeyTest {

private Config config;
private static String host;
private static Integer port;
private static String grpcHost;
private static Integer grpcPort;
private static final String API_KEY = "my-secret-key";
private static final String INVALID_API_KEY = "my-not-so-secret-key";

@ClassRule
public static DockerComposeContainer compose = new DockerComposeContainer(
new File("src/test/resources/docker-compose-wcs.yaml")
).withExposedService("weaviate-auth-wcs_1", 8085, Wait.forHttp("/v1/.well-known/ready").forStatusCode(200));
).withExposedService("weaviate-auth-wcs_1", 8085, Wait.forListeningPorts(8085))
.withExposedService("weaviate-auth-wcs_1", 50051, Wait.forListeningPorts(50051));

@Before
public void before() {
String host = compose.getServiceHost("weaviate-auth-wcs_1", 8085);
Integer port = compose.getServicePort("weaviate-auth-wcs_1", 8085);
config = new Config("http", host + ":" + port);
host = compose.getServiceHost("weaviate-auth-wcs_1", 8085);
port = compose.getServicePort("weaviate-auth-wcs_1", 8085);
grpcHost = compose.getServiceHost("weaviate-auth-wcs_1", 50051);
grpcPort = compose.getServicePort("weaviate-auth-wcs_1", 50051);
}

@Test
public void shouldAuthenticateWithValidApiKey() throws AuthException {
Config config = new Config("http", host + ":" + port);
WeaviateClient client = WeaviateAuthClient.apiKey(config, API_KEY);
Result<Meta> meta = client.misc().metaGetter().run();

Expand All @@ -50,6 +69,7 @@ public void shouldAuthenticateWithValidApiKey() throws AuthException {

@Test
public void shouldNotAuthenticateWithInvalidApiKey() throws AuthException {
Config config = new Config("http", host + ":" + port);
WeaviateClient client = WeaviateAuthClient.apiKey(config, INVALID_API_KEY);
Result<Meta> meta = client.misc().metaGetter().run();

Expand All @@ -59,4 +79,55 @@ public void shouldNotAuthenticateWithInvalidApiKey() throws AuthException {
.extracting(Result::getError)
.returns(401, WeaviateError::getStatusCode);
}

@Test
public void shouldAuthenticateWithValidApiKeyUsingGRPC() throws AuthException {
Config config = new Config("http", host + ":" + port);
config.setGRPCHost(grpcHost + ":" + grpcPort);
WeaviateClient client = WeaviateAuthClient.apiKey(config, API_KEY);

Result<Boolean> deleteAll = client.schema().allDeleter().run();
assertThat(deleteAll).isNotNull()
.returns(false, Result::hasErrors)
.extracting(Result::getResult).isEqualTo(Boolean.TRUE);

String id = "00000000-0000-0000-0000-000000000001";
String className = "TestGRPC";
String propertyName = "name";
List<Property> properties = new ArrayList<>();
properties.add(Property.builder().name("name").dataType(Collections.singletonList(DataType.TEXT)).build());
WeaviateClass clazz = WeaviateClass.builder().className(className).properties(properties).build();
Result<Boolean> createClass = client.schema().classCreator().withClass(clazz).run();

assertThat(createClass).isNotNull()
.returns(false, Result::hasErrors)
.returns(true, Result::getResult);

Map<String, Object> props = new HashMap<>();
props.put("name", "John Doe");

WeaviateObject obj = WeaviateObject.builder().id(id).className(className).properties(props).build();

Result<ObjectGetResponse[]> result = client.batch().objectsBatcher()
.withObjects(obj)
.run();
assertThat(result).isNotNull()
.returns(false, Result::hasErrors)
.extracting(Result::getResult).asInstanceOf(ARRAY)
.hasSize(1);

Result<List<WeaviateObject>> resultObj = client.data().objectsGetter().withClassName(className).withID(id).run();
assertThat(resultObj).isNotNull()
.returns(false, Result::hasErrors)
.extracting(Result::getResult).isNotNull()
.extracting(r -> r.get(0)).isNotNull()
.satisfies(o -> {
assertThat(o.getId()).isEqualTo(obj.getId());
assertThat(o.getProperties()).isNotNull()
.extracting(Map::size).isEqualTo(obj.getProperties().size());
assertThat(o.getProperties()).isNotEmpty().satisfies(p -> {
assertThat(p.get(propertyName)).isNotNull();
});
});
}
}

0 comments on commit 79e1ec3

Please sign in to comment.