Skip to content

Commit

Permalink
Merge pull request #238 from weaviate/grpc-channel-shutdown-fix
Browse files Browse the repository at this point in the history
Fix gRPC channel shutdown error
  • Loading branch information
antas-marcin authored Nov 14, 2023
2 parents 79e1ec3 + 065285d commit 0bc8bf6
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 51 deletions.
14 changes: 3 additions & 11 deletions src/main/java/io/weaviate/client/WeaviateClient.java
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
package io.weaviate.client;

import io.weaviate.client.base.grpc.GrpcClient;
import io.weaviate.client.base.http.HttpClient;
import io.weaviate.client.base.http.builder.HttpApacheClientBuilder;
import io.weaviate.client.base.http.impl.CommonsHttpClientImpl;
import io.weaviate.client.base.util.DbVersionProvider;
import io.weaviate.client.base.util.DbVersionSupport;
import io.weaviate.client.grpc.protocol.v1.WeaviateGrpc;
import io.weaviate.client.v1.auth.provider.AccessTokenProvider;
import io.weaviate.client.v1.backup.Backup;
import io.weaviate.client.v1.batch.Batch;
Expand All @@ -18,16 +16,14 @@
import io.weaviate.client.v1.misc.Misc;
import io.weaviate.client.v1.misc.api.MetaGetter;
import io.weaviate.client.v1.schema.Schema;
import java.util.Calendar;
import java.util.Date;
import java.util.Optional;

public class WeaviateClient {
private final Config config;
private final DbVersionProvider dbVersionProvider;
private final DbVersionSupport dbVersionSupport;
private final HttpClient httpClient;
private final WeaviateGrpc.WeaviateBlockingStub grpcClient;
private final AccessTokenProvider tokenProvider;

public WeaviateClient(Config config) {
this(config, new CommonsHttpClientImpl(config.getHeaders(), null, HttpApacheClientBuilder.build(config)), null);
Expand All @@ -42,11 +38,7 @@ public WeaviateClient(Config config, HttpClient httpClient, AccessTokenProvider
this.httpClient = httpClient;
dbVersionProvider = initDbVersionProvider();
dbVersionSupport = new DbVersionSupport(dbVersionProvider);
if (this.config.useGRPC()) {
this.grpcClient = GrpcClient.create(config, tokenProvider);
} else {
this.grpcClient = null;
}
this.tokenProvider = tokenProvider;
}

public Misc misc() {
Expand All @@ -64,7 +56,7 @@ public Data data() {

public Batch batch() {
dbVersionProvider.refresh();
return new Batch(httpClient, config, dbVersionSupport, grpcClient, data());
return new Batch(httpClient, config, dbVersionSupport, tokenProvider, data());
}

public Backup backup() {
Expand Down
24 changes: 22 additions & 2 deletions src/main/java/io/weaviate/client/base/grpc/GrpcClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,31 @@
import io.grpc.stub.MetadataUtils;
import io.weaviate.client.Config;
import io.weaviate.client.grpc.protocol.v1.WeaviateGrpc;
import io.weaviate.client.grpc.protocol.v1.WeaviateProtoBatch;
import io.weaviate.client.v1.auth.provider.AccessTokenProvider;
import java.util.Map;
import lombok.AccessLevel;
import lombok.experimental.FieldDefaults;

@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public class GrpcClient {
WeaviateGrpc.WeaviateBlockingStub client;
ManagedChannel channel;

public static WeaviateGrpc.WeaviateBlockingStub create(Config config, AccessTokenProvider tokenProvider) {
private GrpcClient(WeaviateGrpc.WeaviateBlockingStub client, ManagedChannel channel) {
this.client = client;
this.channel = channel;
}

public WeaviateProtoBatch.BatchObjectsReply batchObjects(WeaviateProtoBatch.BatchObjectsRequest request) {
return this.client.batchObjects(request);
}

public void shutdown() {
this.channel.shutdown();
}

public static GrpcClient create(Config config, AccessTokenProvider tokenProvider) {
Metadata headers = new Metadata();
if (config.getHeaders() != null) {
for (Map.Entry<String, String> e : config.getHeaders().entrySet()) {
Expand All @@ -29,7 +48,8 @@ public static WeaviateGrpc.WeaviateBlockingStub create(Config config, AccessToke
}
ManagedChannel channel = channelBuilder.build();
WeaviateGrpc.WeaviateBlockingStub blockingStub = WeaviateGrpc.newBlockingStub(channel);
return blockingStub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(headers));
WeaviateGrpc.WeaviateBlockingStub client = blockingStub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(headers));
return new GrpcClient(client, channel);
}

private static String getAddress(Config config) {
Expand Down
14 changes: 7 additions & 7 deletions src/main/java/io/weaviate/client/v1/batch/Batch.java
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
package io.weaviate.client.v1.batch;

import io.weaviate.client.Config;
import io.weaviate.client.base.http.HttpClient;
import io.weaviate.client.base.util.BeaconPath;
import io.weaviate.client.base.util.DbVersionSupport;
import io.weaviate.client.grpc.protocol.v1.WeaviateGrpc;
import io.weaviate.client.v1.auth.provider.AccessTokenProvider;
import io.weaviate.client.v1.batch.api.ObjectsBatchDeleter;
import io.weaviate.client.v1.batch.api.ObjectsBatcher;
import io.weaviate.client.v1.batch.api.ReferencePayloadBuilder;
import io.weaviate.client.v1.batch.api.ReferencesBatcher;
import io.weaviate.client.v1.batch.util.ObjectsPath;
import io.weaviate.client.v1.batch.util.ReferencesPath;
import io.weaviate.client.Config;
import io.weaviate.client.v1.data.Data;

public class Batch {
private final Config config;
private final HttpClient httpClient;
private final WeaviateGrpc.WeaviateBlockingStub grpcClient;
private final AccessTokenProvider tokenProvider;
private final BeaconPath beaconPath;
private final ObjectsPath objectsPath;
private final ReferencesPath referencesPath;
private final Data data;

public Batch(HttpClient httpClient, Config config, DbVersionSupport dbVersionSupport,
WeaviateGrpc.WeaviateBlockingStub grpcClient, Data data) {
AccessTokenProvider tokenProvider, Data data) {
this.config = config;
this.httpClient = httpClient;
this.grpcClient = grpcClient;
this.tokenProvider = tokenProvider;
this.beaconPath = new BeaconPath(dbVersionSupport);
this.objectsPath = new ObjectsPath();
this.referencesPath = new ReferencesPath();
Expand All @@ -38,7 +38,7 @@ public ObjectsBatcher objectsBatcher() {
}

public ObjectsBatcher objectsBatcher(ObjectsBatcher.BatchRetriesConfig batchRetriesConfig) {
return ObjectsBatcher.create(httpClient, config, data, objectsPath, grpcClient, batchRetriesConfig);
return ObjectsBatcher.create(httpClient, config, data, objectsPath, tokenProvider, batchRetriesConfig);
}

public ObjectsBatcher objectsAutoBatcher() {
Expand All @@ -64,7 +64,7 @@ public ObjectsBatcher objectsAutoBatcher(ObjectsBatcher.AutoBatchConfig autoBatc

public ObjectsBatcher objectsAutoBatcher(ObjectsBatcher.BatchRetriesConfig batchRetriesConfig,
ObjectsBatcher.AutoBatchConfig autoBatchConfig) {
return ObjectsBatcher.createAuto(httpClient, config, data, objectsPath, grpcClient, batchRetriesConfig, autoBatchConfig);
return ObjectsBatcher.createAuto(httpClient, config, data, objectsPath, tokenProvider, batchRetriesConfig, autoBatchConfig);
}

public ObjectsBatchDeleter objectsBatchDeleter() {
Expand Down
69 changes: 38 additions & 31 deletions src/main/java/io/weaviate/client/v1/batch/api/ObjectsBatcher.java
Original file line number Diff line number Diff line change
@@ -1,39 +1,27 @@
package io.weaviate.client.v1.batch.api;

import io.weaviate.client.base.grpc.GrpcClient;
import io.weaviate.client.grpc.protocol.v1.WeaviateGrpc;
import io.weaviate.client.v1.batch.grpc.BatchObjectConverter;
import io.weaviate.client.v1.batch.model.ObjectGetResponse;
import io.weaviate.client.v1.batch.model.ObjectGetResponseStatus;
import io.weaviate.client.v1.batch.model.ObjectsBatchRequestBody;
import io.weaviate.client.v1.batch.model.ObjectsGetResponseAO2Result;
import io.weaviate.client.v1.batch.util.ObjectsPath;
import io.weaviate.client.v1.data.replication.model.ConsistencyLevel;
import io.weaviate.client.grpc.protocol.v1.WeaviateProtoBase;
import io.weaviate.client.grpc.protocol.v1.WeaviateProtoBatch;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;
import lombok.experimental.FieldDefaults;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.ObjectUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import io.weaviate.client.Config;
import io.weaviate.client.base.BaseClient;
import io.weaviate.client.base.ClientResult;
import io.weaviate.client.base.Response;
import io.weaviate.client.base.Result;
import io.weaviate.client.base.WeaviateErrorMessage;
import io.weaviate.client.base.WeaviateErrorResponse;
import io.weaviate.client.base.grpc.GrpcClient;
import io.weaviate.client.base.http.HttpClient;
import io.weaviate.client.base.util.Assert;
import io.weaviate.client.grpc.protocol.v1.WeaviateProtoBase;
import io.weaviate.client.grpc.protocol.v1.WeaviateProtoBatch;
import io.weaviate.client.v1.auth.provider.AccessTokenProvider;
import io.weaviate.client.v1.batch.grpc.BatchObjectConverter;
import io.weaviate.client.v1.batch.model.ObjectGetResponse;
import io.weaviate.client.v1.batch.model.ObjectGetResponseStatus;
import io.weaviate.client.v1.batch.model.ObjectsBatchRequestBody;
import io.weaviate.client.v1.batch.model.ObjectsGetResponseAO2Result;
import io.weaviate.client.v1.batch.util.ObjectsPath;
import io.weaviate.client.v1.data.Data;
import io.weaviate.client.v1.data.model.WeaviateObject;

import io.weaviate.client.v1.data.replication.model.ConsistencyLevel;
import java.io.Closeable;
import java.net.ConnectException;
import java.net.SocketTimeoutException;
Expand All @@ -53,6 +41,17 @@
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;
import lombok.experimental.FieldDefaults;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.ObjectUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;

public class ObjectsBatcher extends BaseClient<ObjectGetResponse[]>
implements ClientResult<ObjectGetResponse[]>, Closeable {
Expand All @@ -68,16 +67,18 @@ public class ObjectsBatcher extends BaseClient<ObjectGetResponse[]>
private final List<WeaviateObject> objects;
private String consistencyLevel;
private final List<CompletableFuture<Result<ObjectGetResponse[]>>> undoneFutures;
private final WeaviateGrpc.WeaviateBlockingStub grpcClient;
private final boolean useGRPC;
private final AccessTokenProvider tokenProvider;
private final Config config;


private ObjectsBatcher(HttpClient httpClient, Config config, Data data, ObjectsPath objectsPath,
WeaviateGrpc.WeaviateBlockingStub grpcClient,
AccessTokenProvider tokenProvider,
BatchRetriesConfig batchRetriesConfig, AutoBatchConfig autoBatchConfig) {
super(httpClient, config);
this.config = config;
this.useGRPC = config.useGRPC();
this.grpcClient = grpcClient;
this.tokenProvider = tokenProvider;
this.data = data;
this.objectsPath = objectsPath;
this.objects = new ArrayList<>();
Expand All @@ -99,18 +100,18 @@ private ObjectsBatcher(HttpClient httpClient, Config config, Data data, ObjectsP
}

public static ObjectsBatcher create(HttpClient httpClient, Config config, Data data, ObjectsPath objectsPath,
WeaviateGrpc.WeaviateBlockingStub grpcClient,
AccessTokenProvider tokenProvider,
BatchRetriesConfig batchRetriesConfig) {
Assert.requiredNotNull(batchRetriesConfig, "batchRetriesConfig");
return new ObjectsBatcher(httpClient, config, data, objectsPath, grpcClient, batchRetriesConfig, null);
return new ObjectsBatcher(httpClient, config, data, objectsPath, tokenProvider, batchRetriesConfig, null);
}

public static ObjectsBatcher createAuto(HttpClient httpClient, Config config, Data data, ObjectsPath objectsPath,
WeaviateGrpc.WeaviateBlockingStub grpcClient,
AccessTokenProvider tokenProvider,
BatchRetriesConfig batchRetriesConfig, AutoBatchConfig autoBatchConfig) {
Assert.requiredNotNull(batchRetriesConfig, "batchRetriesConfig");
Assert.requiredNotNull(autoBatchConfig, "autoBatchConfig");
return new ObjectsBatcher(httpClient, config, data, objectsPath, grpcClient, batchRetriesConfig, autoBatchConfig);
return new ObjectsBatcher(httpClient, config, data, objectsPath, tokenProvider, batchRetriesConfig, autoBatchConfig);
}


Expand Down Expand Up @@ -306,7 +307,13 @@ private Result<ObjectGetResponse[]> internalGrpcRun(List<WeaviateObject> batch)
}

WeaviateProtoBatch.BatchObjectsRequest batchObjectsRequest = batchObjectsRequestBuilder.build();
WeaviateProtoBatch.BatchObjectsReply batchObjectsReply = this.grpcClient.batchObjects(batchObjectsRequest);
WeaviateProtoBatch.BatchObjectsReply batchObjectsReply;
GrpcClient grpcClient = GrpcClient.create(this.config, this.tokenProvider);
try {
batchObjectsReply = grpcClient.batchObjects(batchObjectsRequest);
} finally {
grpcClient.shutdown();
}

List<WeaviateErrorMessage> weaviateErrorMessages = batchObjectsReply.getErrorsList().stream()
.map(WeaviateProtoBatch.BatchObjectsReply.BatchError::getError)
Expand Down

0 comments on commit 0bc8bf6

Please sign in to comment.