Skip to content

Commit

Permalink
Merge pull request #325 from weaviate/async_classifications
Browse files Browse the repository at this point in the history
feature: async support for classifications package
  • Loading branch information
antas-marcin authored Nov 13, 2024
2 parents 0997183 + 0108749 commit 03a2a13
Show file tree
Hide file tree
Showing 10 changed files with 555 additions and 134 deletions.
8 changes: 6 additions & 2 deletions src/main/java/io/weaviate/client/base/AsyncClientResult.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
package io.weaviate.client.base;

import java.util.concurrent.Future;
import org.apache.hc.core5.concurrent.FutureCallback;

import java.util.concurrent.Future;

public interface AsyncClientResult<T> {
Future<Result<T>> run();
default Future<Result<T>> run() {
return run(null);
}

Future<Result<T>> run(FutureCallback<Result<T>> callback);
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import io.weaviate.client.base.http.async.AsyncHttpClient;
import io.weaviate.client.base.util.DbVersionProvider;
import io.weaviate.client.base.util.DbVersionSupport;
import io.weaviate.client.v1.async.classifications.Classifications;
import io.weaviate.client.v1.async.cluster.Cluster;
import io.weaviate.client.v1.async.data.Data;
import io.weaviate.client.v1.async.misc.Misc;
Expand Down Expand Up @@ -45,6 +46,10 @@ public Cluster cluster() {
return new Cluster(client, config);
}

public Classifications classifications() {
return new Classifications(client, config);
}

private DbVersionProvider initDbVersionProvider() {
DbVersionProvider.VersionGetter getter = () ->
Optional.ofNullable(this.getMeta())
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package io.weaviate.client.v1.async.classifications;

import io.weaviate.client.Config;
import io.weaviate.client.v1.async.classifications.api.Getter;
import io.weaviate.client.v1.async.classifications.api.Scheduler;
import lombok.RequiredArgsConstructor;
import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient;

@RequiredArgsConstructor
public class Classifications {

private final CloseableHttpAsyncClient client;
private final Config config;


public Scheduler scheduler() {
return new Scheduler(client, config, getter());
}

public Getter getter() {
return new Getter(client, config);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package io.weaviate.client.v1.async.classifications.api;

import io.weaviate.client.Config;
import io.weaviate.client.base.AsyncBaseClient;
import io.weaviate.client.base.AsyncClientResult;
import io.weaviate.client.base.Result;
import io.weaviate.client.base.util.UrlEncoder;
import io.weaviate.client.v1.classifications.model.Classification;
import org.apache.commons.lang3.StringUtils;
import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient;
import org.apache.hc.core5.concurrent.FutureCallback;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Future;

public class Getter extends AsyncBaseClient<Classification> implements AsyncClientResult<Classification> {

private String id;

public Getter(CloseableHttpAsyncClient client, Config config) {
super(client, config);
}

public Getter withID(String id) {
this.id = id;
return this;
}

@Override
public Future<Result<Classification>> run(FutureCallback<Result<Classification>> callback) {
if (StringUtils.isBlank(id)) {
return CompletableFuture.completedFuture(null);
}
String path = String.format("/classifications/%s", UrlEncoder.encodePathParam(id));
return sendGetRequest(path, Classification.class, callback);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
package io.weaviate.client.v1.async.classifications.api;

import io.weaviate.client.Config;
import io.weaviate.client.base.AsyncBaseClient;
import io.weaviate.client.base.AsyncClientResult;
import io.weaviate.client.base.Result;
import io.weaviate.client.base.http.async.ResponseParser;
import io.weaviate.client.v1.classifications.model.Classification;
import io.weaviate.client.v1.classifications.model.ClassificationFilters;
import io.weaviate.client.v1.filters.WhereFilter;
import org.apache.commons.lang3.ObjectUtils;
import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient;
import org.apache.hc.core5.concurrent.FutureCallback;
import org.apache.hc.core5.http.ContentType;
import org.apache.hc.core5.http.HttpResponse;
import org.apache.hc.core5.http.HttpStatus;

import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.Future;

public class Scheduler extends AsyncBaseClient<Classification> implements AsyncClientResult<Classification> {

private String classificationType;
private String className;
private String[] classifyProperties;
private String[] basedOnProperties;
private WhereFilter sourceWhereFilter;
private WhereFilter trainingSetWhereFilter;
private WhereFilter targetWhereFilter;
private boolean waitForCompletion;
private Object settings;

private final Getter getter;

public Scheduler(CloseableHttpAsyncClient client, Config config, Getter getter) {
super(client, config);
this.getter = getter;
}

public Scheduler withType(String classificationType) {
this.classificationType = classificationType;
return this;
}

public Scheduler withClassName(String className) {
this.className = className;
return this;
}

public Scheduler withClassifyProperties(String[] classifyProperties) {
this.classifyProperties = classifyProperties;
return this;
}

public Scheduler withBasedOnProperties(String[] basedOnProperties) {
this.basedOnProperties = basedOnProperties;
return this;
}

public Scheduler withSourceWhereFilter(WhereFilter whereFilter) {
this.sourceWhereFilter = whereFilter;
return this;
}

public Scheduler withTrainingSetWhereFilter(WhereFilter whereFilter) {
this.trainingSetWhereFilter = whereFilter;
return this;
}

public Scheduler withTargetWhereFilter(WhereFilter whereFilter) {
this.targetWhereFilter = whereFilter;
return this;
}

public Scheduler withSettings(Object settings) {
this.settings = settings;
return this;
}

public Scheduler withWaitForCompletion() {
this.waitForCompletion = true;
return this;
}

@Override
public Future<Result<Classification>> run(FutureCallback<Result<Classification>> callback) {
Classification config = Classification.builder()
.basedOnProperties(basedOnProperties)
.className(className)
.classifyProperties(classifyProperties)
.type(classificationType)
.settings(settings)
.filters(getClassificationFilters(sourceWhereFilter, targetWhereFilter, trainingSetWhereFilter))
.build();

if (!waitForCompletion) {
return sendPostRequest("/classifications", config, Classification.class, callback);
}

CompletableFuture<Result<Classification>> future = new CompletableFuture<>();
FutureCallback<Result<Classification>> internalCallback = new FutureCallback<Result<Classification>>() {
@Override
public void completed(Result<Classification> classificationResult) {
future.complete(classificationResult);
}

@Override
public void failed(Exception e) {
future.completeExceptionally(e);
}

@Override
public void cancelled() {
future.cancel(true);
if (callback != null) {
callback.cancelled(); // TODO:AL propagate cancel() call from future to completable future
}
}
};

int[] httpCode = new int[1];
sendPostRequest("/classifications", config, internalCallback, new ResponseParser<Classification>() {
@Override
public Result<Classification> parse(HttpResponse response, String body, ContentType contentType) {
httpCode[0] = response.getCode();
return new Result<>(serializer.toResponse(response.getCode(), body, Classification.class));
}
});

return future.thenCompose(classificationResult -> {
if (httpCode[0] != HttpStatus.SC_CREATED) {
return CompletableFuture.completedFuture(classificationResult);
}
return getByIdRecursively(classificationResult.getResult().getId());
})
.whenComplete((classificationResult, throwable) -> {
if (callback != null) {
if (throwable != null) {
callback.failed((Exception) throwable);
} else {
callback.completed(classificationResult);
}
}
});
}

private CompletableFuture<Result<Classification>> getById(String id) {
CompletableFuture<Result<Classification>> future = new CompletableFuture<>();
getter.withID(id).run(new FutureCallback<Result<Classification>>() {
@Override
public void completed(Result<Classification> classificationResult) {
future.complete(classificationResult);
}

@Override
public void failed(Exception e) {
future.completeExceptionally(e);
}

@Override
public void cancelled() {
}
});
return future;
}

private CompletableFuture<Result<Classification>> getByIdRecursively(String id) {
return getById(id).thenCompose(classificationResult -> {
boolean isRunning = Optional.ofNullable(classificationResult)
.map(Result::getResult)
.map(Classification::getStatus)
.filter(status -> status.equals("running"))
.isPresent();

if (isRunning) {
try {
Thread.sleep(2000);
return getByIdRecursively(id);
} catch (InterruptedException e) {
throw new CompletionException(e);
}
}
return CompletableFuture.completedFuture(classificationResult);
});
}

private ClassificationFilters getClassificationFilters(WhereFilter sourceWhere, WhereFilter targetWhere, WhereFilter trainingSetWhere) {
if (ObjectUtils.anyNotNull(sourceWhere, targetWhere, trainingSetWhere)) {
return ClassificationFilters.builder()
.sourceWhere(sourceWhere)
.targetWhere(targetWhere)
.trainingSetWhere(trainingSetWhere)
.build();
}
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,15 @@

import io.weaviate.client.Config;
import io.weaviate.client.v1.async.cluster.api.NodesStatusGetter;
import lombok.RequiredArgsConstructor;
import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient;

@RequiredArgsConstructor
public class Cluster {

private final CloseableHttpAsyncClient client;
private final Config config;

public Cluster(CloseableHttpAsyncClient client, Config config) {
this.client = client;
this.config = config;
}

public NodesStatusGetter nodesStatusGetter() {
return new NodesStatusGetter(client, config);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,6 @@ public NodesStatusGetter withOutput(String output) {
return this;
}

@Override
public Future<Result<NodesStatusResponse>> run() {
return run(null);
}

@Override
public Future<Result<NodesStatusResponse>> run(FutureCallback<Result<NodesStatusResponse>> callback) {
return sendGetRequest(path(), NodesStatusResponse.class, callback);
Expand Down
Loading

0 comments on commit 03a2a13

Please sign in to comment.