Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix gRPC channel shutdown error #238

Merged
merged 1 commit into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 3 additions & 11 deletions src/main/java/io/weaviate/client/WeaviateClient.java
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
package io.weaviate.client;

import io.weaviate.client.base.grpc.GrpcClient;
import io.weaviate.client.base.http.HttpClient;
import io.weaviate.client.base.http.builder.HttpApacheClientBuilder;
import io.weaviate.client.base.http.impl.CommonsHttpClientImpl;
import io.weaviate.client.base.util.DbVersionProvider;
import io.weaviate.client.base.util.DbVersionSupport;
import io.weaviate.client.grpc.protocol.v1.WeaviateGrpc;
import io.weaviate.client.v1.auth.provider.AccessTokenProvider;
import io.weaviate.client.v1.backup.Backup;
import io.weaviate.client.v1.batch.Batch;
Expand All @@ -18,16 +16,14 @@
import io.weaviate.client.v1.misc.Misc;
import io.weaviate.client.v1.misc.api.MetaGetter;
import io.weaviate.client.v1.schema.Schema;
import java.util.Calendar;
import java.util.Date;
import java.util.Optional;

public class WeaviateClient {
private final Config config;
private final DbVersionProvider dbVersionProvider;
private final DbVersionSupport dbVersionSupport;
private final HttpClient httpClient;
private final WeaviateGrpc.WeaviateBlockingStub grpcClient;
private final AccessTokenProvider tokenProvider;

public WeaviateClient(Config config) {
this(config, new CommonsHttpClientImpl(config.getHeaders(), null, HttpApacheClientBuilder.build(config)), null);
Expand All @@ -42,11 +38,7 @@ public WeaviateClient(Config config, HttpClient httpClient, AccessTokenProvider
this.httpClient = httpClient;
dbVersionProvider = initDbVersionProvider();
dbVersionSupport = new DbVersionSupport(dbVersionProvider);
if (this.config.useGRPC()) {
this.grpcClient = GrpcClient.create(config, tokenProvider);
} else {
this.grpcClient = null;
}
this.tokenProvider = tokenProvider;
}

public Misc misc() {
Expand All @@ -64,7 +56,7 @@ public Data data() {

public Batch batch() {
dbVersionProvider.refresh();
return new Batch(httpClient, config, dbVersionSupport, grpcClient, data());
return new Batch(httpClient, config, dbVersionSupport, tokenProvider, data());
}

public Backup backup() {
Expand Down
24 changes: 22 additions & 2 deletions src/main/java/io/weaviate/client/base/grpc/GrpcClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,31 @@
import io.grpc.stub.MetadataUtils;
import io.weaviate.client.Config;
import io.weaviate.client.grpc.protocol.v1.WeaviateGrpc;
import io.weaviate.client.grpc.protocol.v1.WeaviateProtoBatch;
import io.weaviate.client.v1.auth.provider.AccessTokenProvider;
import java.util.Map;
import lombok.AccessLevel;
import lombok.experimental.FieldDefaults;

@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public class GrpcClient {
WeaviateGrpc.WeaviateBlockingStub client;
ManagedChannel channel;

public static WeaviateGrpc.WeaviateBlockingStub create(Config config, AccessTokenProvider tokenProvider) {
private GrpcClient(WeaviateGrpc.WeaviateBlockingStub client, ManagedChannel channel) {
this.client = client;
this.channel = channel;
}

public WeaviateProtoBatch.BatchObjectsReply batchObjects(WeaviateProtoBatch.BatchObjectsRequest request) {
return this.client.batchObjects(request);
}

public void shutdown() {
this.channel.shutdown();
}

public static GrpcClient create(Config config, AccessTokenProvider tokenProvider) {
Metadata headers = new Metadata();
if (config.getHeaders() != null) {
for (Map.Entry<String, String> e : config.getHeaders().entrySet()) {
Expand All @@ -29,7 +48,8 @@ public static WeaviateGrpc.WeaviateBlockingStub create(Config config, AccessToke
}
ManagedChannel channel = channelBuilder.build();
WeaviateGrpc.WeaviateBlockingStub blockingStub = WeaviateGrpc.newBlockingStub(channel);
return blockingStub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(headers));
WeaviateGrpc.WeaviateBlockingStub client = blockingStub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(headers));
return new GrpcClient(client, channel);
}

private static String getAddress(Config config) {
Expand Down
14 changes: 7 additions & 7 deletions src/main/java/io/weaviate/client/v1/batch/Batch.java
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
package io.weaviate.client.v1.batch;

import io.weaviate.client.Config;
import io.weaviate.client.base.http.HttpClient;
import io.weaviate.client.base.util.BeaconPath;
import io.weaviate.client.base.util.DbVersionSupport;
import io.weaviate.client.grpc.protocol.v1.WeaviateGrpc;
import io.weaviate.client.v1.auth.provider.AccessTokenProvider;
import io.weaviate.client.v1.batch.api.ObjectsBatchDeleter;
import io.weaviate.client.v1.batch.api.ObjectsBatcher;
import io.weaviate.client.v1.batch.api.ReferencePayloadBuilder;
import io.weaviate.client.v1.batch.api.ReferencesBatcher;
import io.weaviate.client.v1.batch.util.ObjectsPath;
import io.weaviate.client.v1.batch.util.ReferencesPath;
import io.weaviate.client.Config;
import io.weaviate.client.v1.data.Data;

public class Batch {
private final Config config;
private final HttpClient httpClient;
private final WeaviateGrpc.WeaviateBlockingStub grpcClient;
private final AccessTokenProvider tokenProvider;
private final BeaconPath beaconPath;
private final ObjectsPath objectsPath;
private final ReferencesPath referencesPath;
private final Data data;

public Batch(HttpClient httpClient, Config config, DbVersionSupport dbVersionSupport,
WeaviateGrpc.WeaviateBlockingStub grpcClient, Data data) {
AccessTokenProvider tokenProvider, Data data) {
this.config = config;
this.httpClient = httpClient;
this.grpcClient = grpcClient;
this.tokenProvider = tokenProvider;
this.beaconPath = new BeaconPath(dbVersionSupport);
this.objectsPath = new ObjectsPath();
this.referencesPath = new ReferencesPath();
Expand All @@ -38,7 +38,7 @@ public ObjectsBatcher objectsBatcher() {
}

public ObjectsBatcher objectsBatcher(ObjectsBatcher.BatchRetriesConfig batchRetriesConfig) {
return ObjectsBatcher.create(httpClient, config, data, objectsPath, grpcClient, batchRetriesConfig);
return ObjectsBatcher.create(httpClient, config, data, objectsPath, tokenProvider, batchRetriesConfig);
}

public ObjectsBatcher objectsAutoBatcher() {
Expand All @@ -64,7 +64,7 @@ public ObjectsBatcher objectsAutoBatcher(ObjectsBatcher.AutoBatchConfig autoBatc

public ObjectsBatcher objectsAutoBatcher(ObjectsBatcher.BatchRetriesConfig batchRetriesConfig,
ObjectsBatcher.AutoBatchConfig autoBatchConfig) {
return ObjectsBatcher.createAuto(httpClient, config, data, objectsPath, grpcClient, batchRetriesConfig, autoBatchConfig);
return ObjectsBatcher.createAuto(httpClient, config, data, objectsPath, tokenProvider, batchRetriesConfig, autoBatchConfig);
}

public ObjectsBatchDeleter objectsBatchDeleter() {
Expand Down
69 changes: 38 additions & 31 deletions src/main/java/io/weaviate/client/v1/batch/api/ObjectsBatcher.java
Original file line number Diff line number Diff line change
@@ -1,39 +1,27 @@
package io.weaviate.client.v1.batch.api;

import io.weaviate.client.base.grpc.GrpcClient;
import io.weaviate.client.grpc.protocol.v1.WeaviateGrpc;
import io.weaviate.client.v1.batch.grpc.BatchObjectConverter;
import io.weaviate.client.v1.batch.model.ObjectGetResponse;
import io.weaviate.client.v1.batch.model.ObjectGetResponseStatus;
import io.weaviate.client.v1.batch.model.ObjectsBatchRequestBody;
import io.weaviate.client.v1.batch.model.ObjectsGetResponseAO2Result;
import io.weaviate.client.v1.batch.util.ObjectsPath;
import io.weaviate.client.v1.data.replication.model.ConsistencyLevel;
import io.weaviate.client.grpc.protocol.v1.WeaviateProtoBase;
import io.weaviate.client.grpc.protocol.v1.WeaviateProtoBatch;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;
import lombok.experimental.FieldDefaults;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.ObjectUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import io.weaviate.client.Config;
import io.weaviate.client.base.BaseClient;
import io.weaviate.client.base.ClientResult;
import io.weaviate.client.base.Response;
import io.weaviate.client.base.Result;
import io.weaviate.client.base.WeaviateErrorMessage;
import io.weaviate.client.base.WeaviateErrorResponse;
import io.weaviate.client.base.grpc.GrpcClient;
import io.weaviate.client.base.http.HttpClient;
import io.weaviate.client.base.util.Assert;
import io.weaviate.client.grpc.protocol.v1.WeaviateProtoBase;
import io.weaviate.client.grpc.protocol.v1.WeaviateProtoBatch;
import io.weaviate.client.v1.auth.provider.AccessTokenProvider;
import io.weaviate.client.v1.batch.grpc.BatchObjectConverter;
import io.weaviate.client.v1.batch.model.ObjectGetResponse;
import io.weaviate.client.v1.batch.model.ObjectGetResponseStatus;
import io.weaviate.client.v1.batch.model.ObjectsBatchRequestBody;
import io.weaviate.client.v1.batch.model.ObjectsGetResponseAO2Result;
import io.weaviate.client.v1.batch.util.ObjectsPath;
import io.weaviate.client.v1.data.Data;
import io.weaviate.client.v1.data.model.WeaviateObject;

import io.weaviate.client.v1.data.replication.model.ConsistencyLevel;
import java.io.Closeable;
import java.net.ConnectException;
import java.net.SocketTimeoutException;
Expand All @@ -53,6 +41,17 @@
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;
import lombok.experimental.FieldDefaults;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.ObjectUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;

public class ObjectsBatcher extends BaseClient<ObjectGetResponse[]>
implements ClientResult<ObjectGetResponse[]>, Closeable {
Expand All @@ -68,16 +67,18 @@ public class ObjectsBatcher extends BaseClient<ObjectGetResponse[]>
private final List<WeaviateObject> objects;
private String consistencyLevel;
private final List<CompletableFuture<Result<ObjectGetResponse[]>>> undoneFutures;
private final WeaviateGrpc.WeaviateBlockingStub grpcClient;
private final boolean useGRPC;
private final AccessTokenProvider tokenProvider;
private final Config config;


private ObjectsBatcher(HttpClient httpClient, Config config, Data data, ObjectsPath objectsPath,
WeaviateGrpc.WeaviateBlockingStub grpcClient,
AccessTokenProvider tokenProvider,
BatchRetriesConfig batchRetriesConfig, AutoBatchConfig autoBatchConfig) {
super(httpClient, config);
this.config = config;
this.useGRPC = config.useGRPC();
this.grpcClient = grpcClient;
this.tokenProvider = tokenProvider;
this.data = data;
this.objectsPath = objectsPath;
this.objects = new ArrayList<>();
Expand All @@ -99,18 +100,18 @@ private ObjectsBatcher(HttpClient httpClient, Config config, Data data, ObjectsP
}

public static ObjectsBatcher create(HttpClient httpClient, Config config, Data data, ObjectsPath objectsPath,
WeaviateGrpc.WeaviateBlockingStub grpcClient,
AccessTokenProvider tokenProvider,
BatchRetriesConfig batchRetriesConfig) {
Assert.requiredNotNull(batchRetriesConfig, "batchRetriesConfig");
return new ObjectsBatcher(httpClient, config, data, objectsPath, grpcClient, batchRetriesConfig, null);
return new ObjectsBatcher(httpClient, config, data, objectsPath, tokenProvider, batchRetriesConfig, null);
}

public static ObjectsBatcher createAuto(HttpClient httpClient, Config config, Data data, ObjectsPath objectsPath,
WeaviateGrpc.WeaviateBlockingStub grpcClient,
AccessTokenProvider tokenProvider,
BatchRetriesConfig batchRetriesConfig, AutoBatchConfig autoBatchConfig) {
Assert.requiredNotNull(batchRetriesConfig, "batchRetriesConfig");
Assert.requiredNotNull(autoBatchConfig, "autoBatchConfig");
return new ObjectsBatcher(httpClient, config, data, objectsPath, grpcClient, batchRetriesConfig, autoBatchConfig);
return new ObjectsBatcher(httpClient, config, data, objectsPath, tokenProvider, batchRetriesConfig, autoBatchConfig);
}


Expand Down Expand Up @@ -306,7 +307,13 @@ private Result<ObjectGetResponse[]> internalGrpcRun(List<WeaviateObject> batch)
}

WeaviateProtoBatch.BatchObjectsRequest batchObjectsRequest = batchObjectsRequestBuilder.build();
WeaviateProtoBatch.BatchObjectsReply batchObjectsReply = this.grpcClient.batchObjects(batchObjectsRequest);
WeaviateProtoBatch.BatchObjectsReply batchObjectsReply;
GrpcClient grpcClient = GrpcClient.create(this.config, this.tokenProvider);
try {
batchObjectsReply = grpcClient.batchObjects(batchObjectsRequest);
} finally {
grpcClient.shutdown();
}

List<WeaviateErrorMessage> weaviateErrorMessages = batchObjectsReply.getErrorsList().stream()
.map(WeaviateProtoBatch.BatchObjectsReply.BatchError::getError)
Expand Down