Skip to content

Commit

Permalink
Merge pull request #326 from weaviate/async-batch
Browse files Browse the repository at this point in the history
Implement batch async
  • Loading branch information
antas-marcin authored Nov 14, 2024
2 parents 03a2a13 + 77e10a8 commit 90a9549
Show file tree
Hide file tree
Showing 47 changed files with 2,018 additions and 721 deletions.
40 changes: 40 additions & 0 deletions src/main/java/io/weaviate/client/base/grpc/AsyncGrpcClient.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package io.weaviate.client.base.grpc;

import com.google.common.util.concurrent.ListenableFuture;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.stub.MetadataUtils;
import io.weaviate.client.Config;
import io.weaviate.client.base.grpc.base.BaseGrpcClient;
import io.weaviate.client.grpc.protocol.v1.WeaviateGrpc;
import io.weaviate.client.grpc.protocol.v1.WeaviateProtoBatch;
import io.weaviate.client.v1.auth.provider.AccessTokenProvider;
import lombok.AccessLevel;
import lombok.experimental.FieldDefaults;

@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public class AsyncGrpcClient extends BaseGrpcClient {
WeaviateGrpc.WeaviateFutureStub client;
ManagedChannel channel;

private AsyncGrpcClient(WeaviateGrpc.WeaviateFutureStub client, ManagedChannel channel) {
this.client = client;
this.channel = channel;
}

public ListenableFuture<WeaviateProtoBatch.BatchObjectsReply> batchObjects(WeaviateProtoBatch.BatchObjectsRequest request) {
return this.client.batchObjects(request);
}

public void shutdown() {
this.channel.shutdown();
}

public static AsyncGrpcClient create(Config config, AccessTokenProvider tokenProvider) {
Metadata headers = getHeaders(config, tokenProvider);
ManagedChannel channel = buildChannel(config);
WeaviateGrpc.WeaviateFutureStub stub = WeaviateGrpc.newFutureStub(channel);
WeaviateGrpc.WeaviateFutureStub client = stub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(headers));
return new AsyncGrpcClient(client, channel);
}
}
37 changes: 4 additions & 33 deletions src/main/java/io/weaviate/client/base/grpc/GrpcClient.java
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
package io.weaviate.client.base.grpc;

import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.Metadata;
import io.grpc.stub.MetadataUtils;
import io.weaviate.client.Config;
import io.weaviate.client.base.grpc.base.BaseGrpcClient;
import io.weaviate.client.grpc.protocol.v1.WeaviateGrpc;
import io.weaviate.client.grpc.protocol.v1.WeaviateProtoBatch;
import io.weaviate.client.v1.auth.provider.AccessTokenProvider;
import java.util.Map;
import lombok.AccessLevel;
import lombok.experimental.FieldDefaults;

@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public class GrpcClient {
public class GrpcClient extends BaseGrpcClient {
WeaviateGrpc.WeaviateBlockingStub client;
ManagedChannel channel;

Expand All @@ -31,38 +30,10 @@ public void shutdown() {
}

public static GrpcClient create(Config config, AccessTokenProvider tokenProvider) {
Metadata headers = new Metadata();
if (config.getHeaders() != null) {
for (Map.Entry<String, String> e : config.getHeaders().entrySet()) {
headers.put(Metadata.Key.of(e.getKey(), Metadata.ASCII_STRING_MARSHALLER), e.getValue());
}
}
if (tokenProvider != null) {
headers.put(Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER), String.format("Bearer %s", tokenProvider.getAccessToken()));
}
ManagedChannelBuilder<?> channelBuilder = ManagedChannelBuilder.forTarget(getAddress(config));
if (config.isGRPCSecured()) {
channelBuilder = channelBuilder.useTransportSecurity();
} else {
channelBuilder.usePlaintext();
}
ManagedChannel channel = channelBuilder.build();
Metadata headers = getHeaders(config, tokenProvider);
ManagedChannel channel = buildChannel(config);
WeaviateGrpc.WeaviateBlockingStub blockingStub = WeaviateGrpc.newBlockingStub(channel);
WeaviateGrpc.WeaviateBlockingStub client = blockingStub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(headers));
return new GrpcClient(client, channel);
}

private static String getAddress(Config config) {
if (config.getGRPCHost() != null) {
String host = config.getGRPCHost();
if (host.contains(":")) {
return host;
}
if (config.isGRPCSecured()) {
return String.format("%s:443", host);
}
return String.format("%s:80", host);
}
return "";
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package io.weaviate.client.base.grpc.base;

import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.Metadata;
import io.weaviate.client.Config;
import io.weaviate.client.v1.auth.provider.AccessTokenProvider;
import java.util.Map;

public class BaseGrpcClient {

protected static Metadata getHeaders(Config config, AccessTokenProvider tokenProvider) {
Metadata headers = new Metadata();
if (config.getHeaders() != null) {
for (Map.Entry<String, String> e : config.getHeaders().entrySet()) {
headers.put(Metadata.Key.of(e.getKey(), Metadata.ASCII_STRING_MARSHALLER), e.getValue());
}
}
if (tokenProvider != null) {
headers.put(Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER), String.format("Bearer %s", tokenProvider.getAccessToken()));
}
return headers;
}

protected static ManagedChannel buildChannel(Config config) {
ManagedChannelBuilder<?> channelBuilder = ManagedChannelBuilder.forTarget(getAddress(config));
if (config.isGRPCSecured()) {
channelBuilder = channelBuilder.useTransportSecurity();
} else {
channelBuilder.usePlaintext();
}
return channelBuilder.build();
}

private static String getAddress(Config config) {
if (config.getGRPCHost() != null) {
String host = config.getGRPCHost();
if (host.contains(":")) {
return host;
}
if (config.isGRPCSecured()) {
return String.format("%s:443", host);
}
return String.format("%s:80", host);
}
return "";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import io.weaviate.client.base.http.async.AsyncHttpClient;
import io.weaviate.client.base.util.DbVersionProvider;
import io.weaviate.client.base.util.DbVersionSupport;
import io.weaviate.client.v1.async.batch.Batch;
import io.weaviate.client.v1.async.classifications.Classifications;
import io.weaviate.client.v1.async.cluster.Cluster;
import io.weaviate.client.v1.async.data.Data;
Expand Down Expand Up @@ -42,6 +43,10 @@ public Data data() {
return new Data(client, config, dbVersionSupport);
}

public Batch batch() {
return new Batch(client, config, dbVersionSupport, data());
}

public Cluster cluster() {
return new Cluster(client, config);
}
Expand Down
104 changes: 104 additions & 0 deletions src/main/java/io/weaviate/client/v1/async/batch/Batch.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package io.weaviate.client.v1.async.batch;

import io.weaviate.client.Config;
import io.weaviate.client.base.util.BeaconPath;
import io.weaviate.client.base.util.DbVersionSupport;
import io.weaviate.client.v1.async.batch.api.ObjectsBatchDeleter;
import io.weaviate.client.v1.async.batch.api.ObjectsBatcher;
import io.weaviate.client.v1.async.data.Data;
import io.weaviate.client.v1.batch.api.ReferencePayloadBuilder;
import io.weaviate.client.v1.batch.util.ObjectsPath;
import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient;

public class Batch {
private final CloseableHttpAsyncClient client;
private final Config config;
private final ObjectsPath objectsPath;
private final BeaconPath beaconPath;
private final Data data;

public Batch(CloseableHttpAsyncClient client, Config config, DbVersionSupport dbVersionSupport, Data data) {
this.client = client;
this.config = config;
this.objectsPath = new ObjectsPath();
this.beaconPath = new BeaconPath(dbVersionSupport);
this.data = data;
}

public ObjectsBatcher objectsBatcher() {
return objectsBatcher(ObjectsBatcher.BatchRetriesConfig.defaultConfig().build());
}

public ObjectsBatcher objectsBatcher(ObjectsBatcher.BatchRetriesConfig batchRetriesConfig) {
// TODO: add support for missing arguments
// return ObjectsBatcher.create(client, config, data, objectsPath, tokenProvider, grpcVersionSupport, batchRetriesConfig);
return ObjectsBatcher.create(client, config, data, objectsPath, null, null, batchRetriesConfig);
}

public ObjectsBatcher objectsAutoBatcher() {
return objectsAutoBatcher(
ObjectsBatcher.BatchRetriesConfig.defaultConfig().build(),
ObjectsBatcher.AutoBatchConfig.defaultConfig().build()
);
}

public ObjectsBatcher objectsAutoBatcher(ObjectsBatcher.BatchRetriesConfig batchRetriesConfig) {
return objectsAutoBatcher(
batchRetriesConfig,
ObjectsBatcher.AutoBatchConfig.defaultConfig().build()
);
}

public ObjectsBatcher objectsAutoBatcher(ObjectsBatcher.AutoBatchConfig autoBatchConfig) {
return objectsAutoBatcher(
ObjectsBatcher.BatchRetriesConfig.defaultConfig().build(),
autoBatchConfig
);
}

public ObjectsBatcher objectsAutoBatcher(ObjectsBatcher.BatchRetriesConfig batchRetriesConfig,
ObjectsBatcher.AutoBatchConfig autoBatchConfig) {
// TODO: add support for missing arguments
// return ObjectsBatcher.create(client, config, data, objectsPath, tokenProvider, grpcVersionSupport, batchRetriesConfig);
return ObjectsBatcher.createAuto(client, config, data, objectsPath, null, null, batchRetriesConfig, autoBatchConfig);
}

public ObjectsBatchDeleter objectsBatchDeleter() {
return new ObjectsBatchDeleter(client, config, objectsPath);
}

public ReferencePayloadBuilder referencePayloadBuilder() {
return new ReferencePayloadBuilder(beaconPath);
}

// TODO: implement async ReferencesBatcher
// public ReferencesBatcher referencesBatcher(ReferencesBatcher.BatchRetriesConfig batchRetriesConfig) {
// return ReferencesBatcher.create(httpClient, config, referencesPath, batchRetriesConfig);
// }
//
// public ReferencesBatcher referencesAutoBatcher() {
// return referencesAutoBatcher(
// ReferencesBatcher.BatchRetriesConfig.defaultConfig().build(),
// ReferencesBatcher.AutoBatchConfig.defaultConfig().build()
// );
// }
//
// public ReferencesBatcher referencesAutoBatcher(ReferencesBatcher.BatchRetriesConfig batchRetriesConfig) {
// return referencesAutoBatcher(
// batchRetriesConfig,
// ReferencesBatcher.AutoBatchConfig.defaultConfig().build()
// );
// }
//
// public ReferencesBatcher referencesAutoBatcher(ReferencesBatcher.AutoBatchConfig autoBatchConfig) {
// return referencesAutoBatcher(
// ReferencesBatcher.BatchRetriesConfig.defaultConfig().build(),
// autoBatchConfig
// );
// }
//
// public ReferencesBatcher referencesAutoBatcher(ReferencesBatcher.BatchRetriesConfig batchRetriesConfig,
// ReferencesBatcher.AutoBatchConfig autoBatchConfig) {
// return ReferencesBatcher.createAuto(httpClient, config, referencesPath, batchRetriesConfig, autoBatchConfig);
// }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package io.weaviate.client.v1.async.batch.api;

import io.weaviate.client.Config;
import io.weaviate.client.base.AsyncBaseClient;
import io.weaviate.client.base.AsyncClientResult;
import io.weaviate.client.base.Result;
import io.weaviate.client.v1.batch.model.BatchDeleteResponse;
import io.weaviate.client.v1.batch.util.ObjectsPath;
import io.weaviate.client.v1.filters.WhereFilter;
import java.util.concurrent.Future;
import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient;
import org.apache.hc.core5.concurrent.FutureCallback;

public class ObjectsBatchDeleter extends AsyncBaseClient<BatchDeleteResponse> implements AsyncClientResult<BatchDeleteResponse> {
private final ObjectsPath objectsPath;
private String className;
private String consistencyLevel;
private String tenant;
private WhereFilter where;
private String output;
private Boolean dryRun;

public ObjectsBatchDeleter(CloseableHttpAsyncClient client, Config config, ObjectsPath objectsPath) {
super(client, config);
this.objectsPath = objectsPath;
}

public ObjectsBatchDeleter withClassName(String className) {
this.className = className;
return this;
}

public ObjectsBatchDeleter withConsistencyLevel(String consistencyLevel) {
this.consistencyLevel = consistencyLevel;
return this;
}

public ObjectsBatchDeleter withTenant(String tenant) {
this.tenant = tenant;
return this;
}

public ObjectsBatchDeleter withWhere(WhereFilter where) {
this.where = where;
return this;
}

public ObjectsBatchDeleter withOutput(String output) {
this.output = output;
return this;
}

public ObjectsBatchDeleter withDryRun(Boolean dryRun) {
this.dryRun = dryRun;
return this;
}

@Override
public Future<Result<BatchDeleteResponse>> run() {
return run(null);
}

@Override
public Future<Result<BatchDeleteResponse>> run(FutureCallback<Result<BatchDeleteResponse>> callback) {
io.weaviate.client.v1.batch.api.ObjectsBatchDeleter.BatchDeleteMatch match = io.weaviate.client.v1.batch.api.ObjectsBatchDeleter.BatchDeleteMatch.builder()
.className(className)
.whereFilter(where)
.build();
io.weaviate.client.v1.batch.api.ObjectsBatchDeleter.BatchDelete batchDelete = io.weaviate.client.v1.batch.api.ObjectsBatchDeleter.BatchDelete.builder()
.dryRun(dryRun)
.output(output)
.match(match)
.build();
String path = objectsPath.buildDelete(ObjectsPath.Params.builder()
.consistencyLevel(consistencyLevel)
.tenant(tenant)
.build());
return sendDeleteRequest(path, batchDelete, BatchDeleteResponse.class, callback);
}
}
Loading

0 comments on commit 90a9549

Please sign in to comment.