diff --git a/src/main/java/io/weaviate/client/base/AsyncBaseClient.java b/src/main/java/io/weaviate/client/base/AsyncBaseClient.java index 743d0460..bfb0b66d 100644 --- a/src/main/java/io/weaviate/client/base/AsyncBaseClient.java +++ b/src/main/java/io/weaviate/client/base/AsyncBaseClient.java @@ -12,7 +12,7 @@ import org.apache.hc.core5.http.HttpHeaders; public abstract class AsyncBaseClient { - private final CloseableHttpAsyncClient client; + protected final CloseableHttpAsyncClient client; private final Config config; private final Serializer serializer; @@ -71,12 +71,16 @@ protected Future> sendHeadRequest(String endpoint, FutureCallback> sendRequest(String endpoint, Object payload, String method, Class classOfT, FutureCallback> callback, ResponseParser parser) { + return client.execute(SimpleRequestProducer.create(getRequest(endpoint, payload, method)), new WeaviateResponseConsumer<>(classOfT, parser), callback); + } + + protected SimpleHttpRequest getRequest(String endpoint, Object payload, String method) { SimpleHttpRequest req = new SimpleHttpRequest(method, String.format("%s%s", config.getBaseURL(), endpoint)); req.addHeader(HttpHeaders.ACCEPT, "*/*"); req.addHeader(HttpHeaders.CONTENT_TYPE, "application/json"); if (payload != null) { req.setBody(serializer.toJsonString(payload), ContentType.APPLICATION_JSON); } - return client.execute(SimpleRequestProducer.create(req), new WeaviateResponseConsumer<>(classOfT, parser), callback); + return req; } } diff --git a/src/main/java/io/weaviate/client/base/AsyncBaseGraphQLClient.java b/src/main/java/io/weaviate/client/base/AsyncBaseGraphQLClient.java new file mode 100644 index 00000000..d6de2a7b --- /dev/null +++ b/src/main/java/io/weaviate/client/base/AsyncBaseGraphQLClient.java @@ -0,0 +1,20 @@ +package io.weaviate.client.base; + +import io.weaviate.client.Config; +import io.weaviate.client.base.http.async.WeaviateGraphQLTypedResponseConsumer; +import io.weaviate.client.v1.graphql.model.GraphQLTypedResponse; +import java.util.concurrent.Future; +import org.apache.hc.client5.http.async.methods.SimpleRequestProducer; +import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient; +import org.apache.hc.core5.concurrent.FutureCallback; + +public class AsyncBaseGraphQLClient extends AsyncBaseClient { + public AsyncBaseGraphQLClient(CloseableHttpAsyncClient client, Config config) { + super(client, config); + } + + protected Future>> sendGraphQLTypedRequest(Object payload, Class classOfC, + FutureCallback>> callback) { + return client.execute(SimpleRequestProducer.create(getRequest("/graphql", payload, "POST")), new WeaviateGraphQLTypedResponseConsumer<>(classOfC), callback); + } +} diff --git a/src/main/java/io/weaviate/client/base/BaseClient.java b/src/main/java/io/weaviate/client/base/BaseClient.java index cabc57c1..81cd6ed9 100644 --- a/src/main/java/io/weaviate/client/base/BaseClient.java +++ b/src/main/java/io/weaviate/client/base/BaseClient.java @@ -8,7 +8,7 @@ public abstract class BaseClient { private final HttpClient client; private final Config config; - private final Serializer serializer; + protected final Serializer serializer; public BaseClient(HttpClient client, Config config) { this.config = config; @@ -42,9 +42,7 @@ protected Response sendHeadRequest(String endpoint, Class classOfT) { private Response sendRequest(String endpoint, Object payload, String method, Class classOfT) { try { - String url = config.getBaseURL() + endpoint; - String json = toJsonString(payload); - HttpResponse response = this.sendHttpRequest(url, json, method); + HttpResponse response = this.sendHttpRequest(endpoint, payload, method); int statusCode = response.getStatusCode(); String responseBody = response.getBody(); @@ -61,7 +59,9 @@ private Response sendRequest(String endpoint, Object payload, String method, } } - private HttpResponse sendHttpRequest(String address, String json, String method) throws Exception { + protected HttpResponse sendHttpRequest(String endpoint, Object payload, String method) throws Exception { + String address = config.getBaseURL() + endpoint; + String json = toJsonString(payload); if (method.equals("POST")) { return client.sendPostRequest(address, json); } @@ -80,15 +80,15 @@ private HttpResponse sendHttpRequest(String address, String json, String method) return client.sendGetRequest(address); } - private C toResponse(String response, Class classOfT) { - return serializer.toObject(response, classOfT); + protected C toResponse(String response, Class classOfT) { + return serializer.toResponse(response, classOfT); } private String toJsonString(Object object) { return serializer.toJsonString(object); } - private WeaviateErrorResponse getWeaviateErrorResponse(Exception e) { + protected WeaviateErrorResponse getWeaviateErrorResponse(Exception e) { WeaviateErrorMessage error = WeaviateErrorMessage.builder().message(e.getMessage()).throwable(e).build(); return WeaviateErrorResponse.builder().error(Collections.singletonList(error)).build(); } diff --git a/src/main/java/io/weaviate/client/base/BaseGraphQLClient.java b/src/main/java/io/weaviate/client/base/BaseGraphQLClient.java new file mode 100644 index 00000000..15eda633 --- /dev/null +++ b/src/main/java/io/weaviate/client/base/BaseGraphQLClient.java @@ -0,0 +1,35 @@ +package io.weaviate.client.base; + +import io.weaviate.client.Config; +import io.weaviate.client.base.http.HttpClient; +import io.weaviate.client.base.http.HttpResponse; +import io.weaviate.client.v1.graphql.model.GraphQLTypedResponse; + +public abstract class BaseGraphQLClient extends BaseClient { + public BaseGraphQLClient(HttpClient client, Config config) { + super(client, config); + } + + private GraphQLTypedResponse toResponseTyped(String response, Class classOfC) { + return serializer.toGraphQLTypedResponse(response, classOfC); + } + + protected Response> sendGraphQLTypedRequest(Object payload, Class classOfC) { + try { + HttpResponse response = this.sendHttpRequest("/graphql", payload, "POST"); + int statusCode = response.getStatusCode(); + String responseBody = response.getBody(); + + if (statusCode < 399) { + GraphQLTypedResponse body = toResponseTyped(responseBody, classOfC); + return new Response<>(statusCode, body, null); + } + + WeaviateErrorResponse error = toResponse(responseBody, WeaviateErrorResponse.class); + return new Response<>(statusCode, null, error); + } catch (Exception e) { + WeaviateErrorResponse errors = getWeaviateErrorResponse(e); + return new Response<>(0, null, errors); + } + } +} diff --git a/src/main/java/io/weaviate/client/base/Serializer.java b/src/main/java/io/weaviate/client/base/Serializer.java index d6c8ae45..a1df177a 100644 --- a/src/main/java/io/weaviate/client/base/Serializer.java +++ b/src/main/java/io/weaviate/client/base/Serializer.java @@ -2,6 +2,11 @@ import com.google.gson.Gson; import com.google.gson.GsonBuilder; +import com.google.gson.reflect.TypeToken; +import io.weaviate.client.base.util.GroupHitDeserializer; +import io.weaviate.client.v1.graphql.model.GraphQLGetBaseObject; +import io.weaviate.client.v1.graphql.model.GraphQLTypedResponse; +import java.lang.reflect.Type; public class Serializer { private Gson gson; @@ -10,7 +15,20 @@ public Serializer() { this.gson = new GsonBuilder().disableHtmlEscaping().create(); } - public T toObject(String response, Class classOfT) { + public GraphQLTypedResponse toGraphQLTypedResponse(String response, Class classOfT) { + Gson gsonTyped = new GsonBuilder() + .disableHtmlEscaping() + .registerTypeAdapter(GraphQLGetBaseObject.Additional.Group.GroupHit.class, new GroupHitDeserializer()) + .create(); + return gsonTyped.fromJson(response, + TypeToken.getParameterized(GraphQLTypedResponse.class, classOfT).getType()); + } + + public C toResponse(String response, Type typeOfT) { + return gson.fromJson(response, typeOfT); + } + + public T toResponse(String response, Class classOfT) { return gson.fromJson(response, classOfT); } @@ -27,13 +45,28 @@ public Result toResult(int statusCode, String body, Class classOfT) { public Response toResponse(int statusCode, String body, Class classOfT) { if (statusCode < 399) { - T obj = toObject(body, classOfT); + T obj = toResponse(body, classOfT); return new Response<>(statusCode, obj, null); } return new Response<>(statusCode, null, toWeaviateError(body)); } + public Response> toGraphQLTypedResponse(int statusCode, String body, Class classOfC) { + if (statusCode < 399) { + GraphQLTypedResponse obj = toGraphQLTypedResponse(body, classOfC); + return new Response<>(statusCode, obj, null); + } + return new Response<>(statusCode, null, toWeaviateError(body)); + } + + public Result> toGraphQLTypedResult(int statusCode, String body, Class classOfC) { + if (statusCode < 399) { + return new Result<>(toGraphQLTypedResponse(statusCode, body, classOfC)); + } + return new Result<>(statusCode, null, toWeaviateError(body)); + } + public WeaviateErrorResponse toWeaviateError(String body) { - return toObject(body, WeaviateErrorResponse.class); + return toResponse(body, WeaviateErrorResponse.class); } } diff --git a/src/main/java/io/weaviate/client/base/http/async/WeaviateGraphQLTypedResponseConsumer.java b/src/main/java/io/weaviate/client/base/http/async/WeaviateGraphQLTypedResponseConsumer.java new file mode 100644 index 00000000..5de7a8f2 --- /dev/null +++ b/src/main/java/io/weaviate/client/base/http/async/WeaviateGraphQLTypedResponseConsumer.java @@ -0,0 +1,34 @@ +package io.weaviate.client.base.http.async; + +import io.weaviate.client.base.Result; +import io.weaviate.client.base.Serializer; +import io.weaviate.client.v1.graphql.model.GraphQLTypedResponse; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import org.apache.hc.core5.http.ContentType; +import org.apache.hc.core5.http.HttpException; +import org.apache.hc.core5.http.HttpResponse; +import org.apache.hc.core5.http.nio.entity.BasicAsyncEntityConsumer; +import org.apache.hc.core5.http.nio.support.AbstractAsyncResponseConsumer; +import org.apache.hc.core5.http.protocol.HttpContext; + +public class WeaviateGraphQLTypedResponseConsumer extends AbstractAsyncResponseConsumer>, byte[]> { + private final Serializer serializer; + private final Class classOfT; + + public WeaviateGraphQLTypedResponseConsumer(Class classOfT) { + super(new BasicAsyncEntityConsumer()); + this.serializer = new Serializer(); + this.classOfT = classOfT; + } + + @Override + protected Result> buildResult(HttpResponse response, byte[] entity, ContentType contentType) { + String body = (entity != null) ? new String(entity, StandardCharsets.UTF_8) : ""; + return serializer.toGraphQLTypedResult(response.getCode(), body, classOfT); + } + + @Override + public void informationResponse(HttpResponse response, HttpContext context) throws HttpException, IOException { + } +} diff --git a/src/main/java/io/weaviate/client/base/util/GroupHitDeserializer.java b/src/main/java/io/weaviate/client/base/util/GroupHitDeserializer.java new file mode 100644 index 00000000..270bef84 --- /dev/null +++ b/src/main/java/io/weaviate/client/base/util/GroupHitDeserializer.java @@ -0,0 +1,32 @@ +package io.weaviate.client.base.util; + +import com.google.gson.JsonDeserializationContext; +import com.google.gson.JsonDeserializer; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonParseException; +import com.google.gson.reflect.TypeToken; +import io.weaviate.client.v1.graphql.model.GraphQLGetBaseObject; +import java.lang.reflect.Type; +import java.util.Map; + +public class GroupHitDeserializer implements JsonDeserializer { + + @Override + public GraphQLGetBaseObject.Additional.Group.GroupHit deserialize(JsonElement json, Type typeOfT, + JsonDeserializationContext context) throws JsonParseException { + JsonObject jsonObject = json.getAsJsonObject(); + + GraphQLGetBaseObject.Additional.Group.GroupHit.AdditionalGroupHit additional = + context.deserialize(jsonObject.get("_additional"), GraphQLGetBaseObject.Additional.Group.GroupHit.AdditionalGroupHit.class); + + // Remove _additional from the JSON object + jsonObject.remove("_additional"); + + // Deserialize the rest into a Map + Type mapType = new TypeToken>() {}.getType(); + Map properties = context.deserialize(jsonObject, mapType); + + return new GraphQLGetBaseObject.Additional.Group.GroupHit(properties, additional); + } +} diff --git a/src/main/java/io/weaviate/client/v1/async/graphql/api/Get.java b/src/main/java/io/weaviate/client/v1/async/graphql/api/Get.java index f7daf26c..ac330197 100644 --- a/src/main/java/io/weaviate/client/v1/async/graphql/api/Get.java +++ b/src/main/java/io/weaviate/client/v1/async/graphql/api/Get.java @@ -1,23 +1,40 @@ package io.weaviate.client.v1.async.graphql.api; import io.weaviate.client.Config; -import io.weaviate.client.base.AsyncBaseClient; +import io.weaviate.client.base.AsyncBaseGraphQLClient; import io.weaviate.client.base.AsyncClientResult; import io.weaviate.client.base.Result; import io.weaviate.client.v1.filters.WhereFilter; +import io.weaviate.client.v1.graphql.model.GraphQLGetBaseObject; import io.weaviate.client.v1.graphql.model.GraphQLQuery; import io.weaviate.client.v1.graphql.model.GraphQLResponse; -import io.weaviate.client.v1.graphql.query.argument.*; +import io.weaviate.client.v1.graphql.model.GraphQLTypedResponse; +import io.weaviate.client.v1.graphql.query.argument.AskArgument; +import io.weaviate.client.v1.graphql.query.argument.Bm25Argument; +import io.weaviate.client.v1.graphql.query.argument.GroupArgument; +import io.weaviate.client.v1.graphql.query.argument.GroupByArgument; +import io.weaviate.client.v1.graphql.query.argument.HybridArgument; +import io.weaviate.client.v1.graphql.query.argument.NearAudioArgument; +import io.weaviate.client.v1.graphql.query.argument.NearDepthArgument; +import io.weaviate.client.v1.graphql.query.argument.NearImageArgument; +import io.weaviate.client.v1.graphql.query.argument.NearImuArgument; +import io.weaviate.client.v1.graphql.query.argument.NearObjectArgument; +import io.weaviate.client.v1.graphql.query.argument.NearTextArgument; +import io.weaviate.client.v1.graphql.query.argument.NearThermalArgument; +import io.weaviate.client.v1.graphql.query.argument.NearVectorArgument; +import io.weaviate.client.v1.graphql.query.argument.NearVideoArgument; +import io.weaviate.client.v1.graphql.query.argument.SortArgument; +import io.weaviate.client.v1.graphql.query.argument.SortArguments; +import io.weaviate.client.v1.graphql.query.argument.WhereArgument; import io.weaviate.client.v1.graphql.query.builder.GetBuilder; import io.weaviate.client.v1.graphql.query.fields.Field; import io.weaviate.client.v1.graphql.query.fields.Fields; import io.weaviate.client.v1.graphql.query.fields.GenerativeSearchBuilder; +import java.util.concurrent.Future; import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient; import org.apache.hc.core5.concurrent.FutureCallback; -import java.util.concurrent.Future; - -public class Get extends AsyncBaseClient implements AsyncClientResult { +public class Get extends AsyncBaseGraphQLClient implements AsyncClientResult { private final GetBuilder.GetBuilderBuilder getBuilder; public Get(CloseableHttpAsyncClient client, Config config) { @@ -161,13 +178,99 @@ public Get withAutocut(Integer autocut) { return this; } - @Override - public Future> run(FutureCallback> callback) { + private GraphQLQuery getQuery() { String getQuery = getBuilder.build() .buildQuery(); - GraphQLQuery query = GraphQLQuery.builder() + return GraphQLQuery.builder() .query(getQuery) .build(); - return sendPostRequest("/graphql", query, GraphQLResponse.class, callback); + } + + @Override + public Future> run(FutureCallback> callback) { + return sendPostRequest("/graphql", getQuery(), GraphQLResponse.class, callback); + } + + /** + * This method provides a better way of serializing a GraphQL response using one's defined classes. + * Example: + * In Weaviate we have defined collection named Soup with name and price properties. + * For client to be able to properly serialize GraphQL response to an Object with + * convenient methods accessing GraphQL settings one can create a class, example: + *
{@code
+   * import com.google.gson.annotations.SerializedName;
+   *
+   * public class Soups {
+   *   {@literal @}SerializedName(value = "Soup")
+   *   List soups;
+   *
+   *   public List getSoups() {
+   *     return soups;
+   *   }
+   *
+   *   public static class Soup extends GraphQLGetBaseObject {
+   *     String name;
+   *     Float price;
+   *
+   *     public String getName() {
+   *       return name;
+   *     }
+   *
+   *     public Float getPrice() {
+   *       return price;
+   *     }
+   *   }
+   * }
+   * }
+ * + * @param classOfC - class describing Weaviate object, example: Soups class + * @param - Class of C + * @return Result of GraphQLTypedResponse of a given class + * @see GraphQLGetBaseObject + */ + public Future>> run(final Class classOfC) { + return run(classOfC, null); + } + + /** + * This method provides a better way of serializing a GraphQL response using one's defined classes. + * Example: + * In Weaviate we have defined collection named Soup with name and price properties. + * For client to be able to properly serialize GraphQL response to an Object with + * convenient methods accessing GraphQL settings one can create a class, example: + *
{@code
+   * import com.google.gson.annotations.SerializedName;
+   *
+   * public class Soups {
+   *   {@literal @}SerializedName(value = "Soup")
+   *   List soups;
+   *
+   *   public List getSoups() {
+   *     return soups;
+   *   }
+   *
+   *   public static class Soup extends GraphQLGetBaseObject {
+   *     String name;
+   *     Float price;
+   *
+   *     public String getName() {
+   *       return name;
+   *     }
+   *
+   *     public Float getPrice() {
+   *       return price;
+   *     }
+   *   }
+   * }
+   * }
+ * + * @param classOfC - class describing Weaviate object, example: Soups class + * @param callback - Result of GraphQLTypedResponse of a given class callback + * @param - Class of C + * @return Result of GraphQLTypedResponse of a given class + * @see GraphQLGetBaseObject + */ + public Future>> run(final Class classOfC, FutureCallback>> callback) { + return sendGraphQLTypedRequest(getQuery(), classOfC, callback); } } diff --git a/src/main/java/io/weaviate/client/v1/auth/nimbus/BaseAuth.java b/src/main/java/io/weaviate/client/v1/auth/nimbus/BaseAuth.java index c84d326d..aea3b353 100644 --- a/src/main/java/io/weaviate/client/v1/auth/nimbus/BaseAuth.java +++ b/src/main/java/io/weaviate/client/v1/auth/nimbus/BaseAuth.java @@ -44,7 +44,7 @@ public AuthResponse getIdAndTokenEndpoint(Config config) throws AuthException { log(msg); throw new AuthException(msg); case 200: - OIDCConfig oidcConfig = serializer.toObject(response.getBody(), OIDCConfig.class); + OIDCConfig oidcConfig = serializer.toResponse(response.getBody(), OIDCConfig.class); HttpResponse resp = sendGetRequest(client, oidcConfig.getHref()); if (resp.getStatusCode() != 200) { String errorMessage = String.format("OIDC configuration url %s returned status code %s", oidcConfig.getHref(), resp.getStatusCode()); diff --git a/src/main/java/io/weaviate/client/v1/graphql/model/GraphQLGetBaseObject.java b/src/main/java/io/weaviate/client/v1/graphql/model/GraphQLGetBaseObject.java new file mode 100644 index 00000000..8016050d --- /dev/null +++ b/src/main/java/io/weaviate/client/v1/graphql/model/GraphQLGetBaseObject.java @@ -0,0 +1,72 @@ +package io.weaviate.client.v1.graphql.model; + +import com.google.gson.annotations.SerializedName; +import java.util.List; +import java.util.Map; +import lombok.AllArgsConstructor; +import lombok.Getter; + +@Getter +public class GraphQLGetBaseObject { + @SerializedName(value = "_additional") + Additional additional; + + @Getter + public static class Additional { + String id; + Float certainty; + Float distance; + String creationTimeUnix; + String lastUpdateTimeUnix; + String explainScore; + String score; + Float[] vector; + Map vectors; + Generate generate; + Group group; + + @Getter + public static class Generate { + String singleResult; + String groupedResult; + String error; + Debug debug; + + @Getter + public static class Debug { + String prompt; + } + } + + @Getter + public static class Group { + public String id; + public GroupedBy groupedBy; + public Integer count; + public Float maxDistance; + public Float minDistance; + public List hits; + + @Getter + public static class GroupedBy { + public String value; + public String[] path; + } + + @Getter + @AllArgsConstructor + public static class GroupHit { + @SerializedName("properties") + Map properties; + @SerializedName(value = "_additional") + AdditionalGroupHit additional; + + @Getter + public static class AdditionalGroupHit { + String id; + Float distance; + } + } + } + } +} diff --git a/src/main/java/io/weaviate/client/v1/graphql/model/GraphQLResponse.java b/src/main/java/io/weaviate/client/v1/graphql/model/GraphQLResponse.java index 68b173ef..5c8feab0 100644 --- a/src/main/java/io/weaviate/client/v1/graphql/model/GraphQLResponse.java +++ b/src/main/java/io/weaviate/client/v1/graphql/model/GraphQLResponse.java @@ -16,8 +16,8 @@ @ToString @EqualsAndHashCode @FieldDefaults(level = AccessLevel.PRIVATE) -public class GraphQLResponse { - Object data; +public class GraphQLResponse { + T data; GraphQLError[] errors; diff --git a/src/main/java/io/weaviate/client/v1/graphql/model/GraphQLTypedResponse.java b/src/main/java/io/weaviate/client/v1/graphql/model/GraphQLTypedResponse.java new file mode 100644 index 00000000..9af18310 --- /dev/null +++ b/src/main/java/io/weaviate/client/v1/graphql/model/GraphQLTypedResponse.java @@ -0,0 +1,32 @@ +package io.weaviate.client.v1.graphql.model; + +import com.google.gson.annotations.SerializedName; +import lombok.AccessLevel; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; + +@Getter +@ToString +@EqualsAndHashCode +@AllArgsConstructor +@FieldDefaults(level = AccessLevel.PRIVATE) +public class GraphQLTypedResponse { + Operation data; + GraphQLError[] errors; + + @Getter + @ToString + @EqualsAndHashCode + @AllArgsConstructor + @FieldDefaults(level = AccessLevel.PRIVATE) + public static class Operation { + @SerializedName(value = "Get", alternate = {"Aggregate", "Explore"}) + private T objects; + } +} + + diff --git a/src/main/java/io/weaviate/client/v1/graphql/query/Get.java b/src/main/java/io/weaviate/client/v1/graphql/query/Get.java index 3a6a0fc5..f181f15d 100644 --- a/src/main/java/io/weaviate/client/v1/graphql/query/Get.java +++ b/src/main/java/io/weaviate/client/v1/graphql/query/Get.java @@ -1,5 +1,16 @@ package io.weaviate.client.v1.graphql.query; +import io.weaviate.client.Config; +import io.weaviate.client.base.BaseGraphQLClient; +import io.weaviate.client.base.ClientResult; +import io.weaviate.client.base.Response; +import io.weaviate.client.base.Result; +import io.weaviate.client.base.http.HttpClient; +import io.weaviate.client.v1.filters.WhereFilter; +import io.weaviate.client.v1.graphql.model.GraphQLGetBaseObject; +import io.weaviate.client.v1.graphql.model.GraphQLQuery; +import io.weaviate.client.v1.graphql.model.GraphQLResponse; +import io.weaviate.client.v1.graphql.model.GraphQLTypedResponse; import io.weaviate.client.v1.graphql.query.argument.AskArgument; import io.weaviate.client.v1.graphql.query.argument.Bm25Argument; import io.weaviate.client.v1.graphql.query.argument.GroupArgument; @@ -21,17 +32,8 @@ import io.weaviate.client.v1.graphql.query.fields.Field; import io.weaviate.client.v1.graphql.query.fields.Fields; import io.weaviate.client.v1.graphql.query.fields.GenerativeSearchBuilder; -import io.weaviate.client.Config; -import io.weaviate.client.base.BaseClient; -import io.weaviate.client.base.ClientResult; -import io.weaviate.client.base.Response; -import io.weaviate.client.base.Result; -import io.weaviate.client.base.http.HttpClient; -import io.weaviate.client.v1.filters.WhereFilter; -import io.weaviate.client.v1.graphql.model.GraphQLQuery; -import io.weaviate.client.v1.graphql.model.GraphQLResponse; -public class Get extends BaseClient implements ClientResult { +public class Get extends BaseGraphQLClient implements ClientResult { private final GetBuilder.GetBuilderBuilder getBuilder; public Get(HttpClient httpClient, Config config) { @@ -176,4 +178,48 @@ public Result run() { Response resp = sendPostRequest("/graphql", query, GraphQLResponse.class); return new Result<>(resp); } + + /** + * This method provides a better way of serializing a GraphQL response using one's defined classes. + * Example: + * In Weaviate we have defined collection named Soup with name and price properties. + * For client to be able to properly serialize GraphQL response to an Object with + * convenient methods accessing GraphQL settings one can create a class, example: + *
{@code
+   * import com.google.gson.annotations.SerializedName;
+   *
+   * public class Soups {
+   *   {@literal @}SerializedName(value = "Soup")
+   *   List soups;
+   *
+   *   public List getSoups() {
+   *     return soups;
+   *   }
+   *
+   *   public static class Soup extends GraphQLGetBaseObject {
+   *     String name;
+   *     Float price;
+   *
+   *     public String getName() {
+   *       return name;
+   *     }
+   *
+   *     public Float getPrice() {
+   *       return price;
+   *     }
+   *   }
+   * }
+   * }
+ * + * @param classOfC - class describing Weaviate object, example: Soups class + * @param - Class of C + * @return Result of GraphQLTypedResponse of a given class + * @see GraphQLGetBaseObject + */ + public Result> run(Class classOfC) { + String getQuery = getBuilder.build().buildQuery(); + GraphQLQuery query = GraphQLQuery.builder().query(getQuery).build(); + Response> resp = sendGraphQLTypedRequest(query, classOfC); + return new Result<>(resp); + } } diff --git a/src/test/java/io/weaviate/client/base/SerializerTest.java b/src/test/java/io/weaviate/client/base/SerializerTest.java index 6afb1203..a7f2dd82 100644 --- a/src/test/java/io/weaviate/client/base/SerializerTest.java +++ b/src/test/java/io/weaviate/client/base/SerializerTest.java @@ -7,13 +7,13 @@ public class SerializerTest extends TestCase { @Test - public void testToObject() { + public void testToResponse() { // given Serializer s = new Serializer(); String description = "test äüëö"; String jsonString = "{\"description\":\""+description+"\"}"; // when - TestObj deserialized = s.toObject(jsonString, TestObj.class); + TestObj deserialized = s.toResponse(jsonString, TestObj.class); // then Assert.assertNotNull(deserialized); Assert.assertEquals(description, deserialized.getDescription()); @@ -37,7 +37,7 @@ public void testErrorResponse() { Serializer s = new Serializer(); String jsonString = "{\"error\":[{\"message\":\"get extend: unknown capability: featureProjection\"}]}"; // when - WeaviateErrorResponse deserialized = s.toObject(jsonString, WeaviateErrorResponse.class); + WeaviateErrorResponse deserialized = s.toResponse(jsonString, WeaviateErrorResponse.class); // then Assert.assertNotNull(deserialized); Assert.assertNull(deserialized.getMessage()); @@ -53,7 +53,7 @@ public void testErrorResponseWithNoError() { Serializer s = new Serializer(); String jsonString = "{\"code\":601,\"message\":\"id in body must be of type uuid: \\\"TODO_4\\\"\"}"; // when - WeaviateErrorResponse deserialized = s.toObject(jsonString, WeaviateErrorResponse.class); + WeaviateErrorResponse deserialized = s.toResponse(jsonString, WeaviateErrorResponse.class); // then Assert.assertNotNull(deserialized); Assert.assertNull(deserialized.getError()); diff --git a/src/test/java/io/weaviate/client/v1/graphql/model/GraphQLTypedResponseTest.java b/src/test/java/io/weaviate/client/v1/graphql/model/GraphQLTypedResponseTest.java new file mode 100644 index 00000000..483f8d23 --- /dev/null +++ b/src/test/java/io/weaviate/client/v1/graphql/model/GraphQLTypedResponseTest.java @@ -0,0 +1,86 @@ +package io.weaviate.client.v1.graphql.model; + +import com.google.gson.annotations.SerializedName; +import com.google.gson.reflect.TypeToken; +import io.weaviate.client.base.Serializer; +import java.io.IOException; +import java.lang.reflect.Type; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.List; +import lombok.Getter; +import static org.assertj.core.api.Assertions.assertThat; +import org.junit.Test; + +public class GraphQLTypedResponseTest { + + @Test + public void testGraphQLGetResponse() throws IOException { + // given + Serializer s = new Serializer(); + String json = new String(Files.readAllBytes(Paths.get("src/test/resources/json/graphql-response.json"))); + // when + Type responseType = TypeToken.getParameterized(GraphQLTypedResponse.class, Soups.class).getType(); + GraphQLTypedResponse resp = s.toResponse(json, responseType); + // + assertThat(resp).isNotNull() + .extracting(o -> o.getData().getObjects().getSoups()) + .extracting(o -> o.get(0)).isNotNull() + .extracting(Soups.Soup::getName).isEqualTo("JustSoup"); + } + + @Test + public void testGraphQLGetResponseSoups() throws IOException { + // given + Serializer s = new Serializer(); + String json = new String(Files.readAllBytes(Paths.get("src/test/resources/json/graphql-response.json"))); + // when + GraphQLTypedResponse resp = s.toGraphQLTypedResponse(json, Soups.class); + // + assertThat(resp).isNotNull() + .extracting(o -> o.getData().getObjects().getSoups()) + .extracting(o -> o.get(0)).isNotNull() + .extracting(Soups.Soup::getName).isEqualTo("JustSoup"); + } + + @Test + public void testGraphQLGetResponseSoups2() throws IOException { + // given + Serializer s = new Serializer(); + String json = new String(Files.readAllBytes(Paths.get("src/test/resources/json/graphql-group-by-response.json"))); + // when + GraphQLTypedResponse resp = s.toGraphQLTypedResponse(json, Passages.class); + // then + assertThat(resp).isNotNull() + .extracting(o -> o.getData().getObjects().getPassages()) + .extracting(o -> o.get(0)).isNotNull() + .extracting(GraphQLGetBaseObject::getAdditional).isNotNull() + .extracting(GraphQLGetBaseObject.Additional::getGroup).isNotNull() + .extracting(GraphQLGetBaseObject.Additional.Group::getHits).isNotNull() + .extracting(o -> o.get(0)).isNotNull() + .extracting(GraphQLGetBaseObject.Additional.Group.GroupHit::getProperties).isNotNull() + .extracting(o -> o.get("name")).isEqualTo("test-name"); + } +} + +@Getter +class Soups { + @SerializedName(value = "Soup") + List soups; + + @Getter + public static class Soup extends GraphQLGetBaseObject { + String name; + } +} + +@Getter +class Passages { + @SerializedName(value = "Passage") + List passages; + + @Getter + public static class Passage extends GraphQLGetBaseObject { + String name; + } +} diff --git a/src/test/java/io/weaviate/integration/client/async/graphql/ClientGraphQLTypedTest.java b/src/test/java/io/weaviate/integration/client/async/graphql/ClientGraphQLTypedTest.java new file mode 100644 index 00000000..81148a5f --- /dev/null +++ b/src/test/java/io/weaviate/integration/client/async/graphql/ClientGraphQLTypedTest.java @@ -0,0 +1,54 @@ +package io.weaviate.integration.client.async.graphql; + +import io.weaviate.client.Config; +import io.weaviate.client.WeaviateClient; +import io.weaviate.client.base.Result; +import io.weaviate.client.v1.async.WeaviateAsyncClient; +import io.weaviate.client.v1.graphql.model.GraphQLTypedResponse; +import io.weaviate.client.v1.graphql.query.fields.Field; +import io.weaviate.integration.client.WeaviateDockerCompose; +import io.weaviate.integration.client.WeaviateTestGenerics; +import io.weaviate.integration.tests.graphql.ClientGraphQLTypedTestSuite; +import java.util.concurrent.ExecutionException; +import java.util.function.Supplier; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Test; + +public class ClientGraphQLTypedTest { + private WeaviateClient client; + private final WeaviateTestGenerics testGenerics = new WeaviateTestGenerics(); + + @ClassRule + public static WeaviateDockerCompose compose = new WeaviateDockerCompose(); + + @Before + public void before() { + String httpHost = compose.getHttpHostAddress(); + Config config = new Config("http", httpHost); + + client = new WeaviateClient(config); + testGenerics.createTestSchemaAndData(client); + } + + @After + public void after() { + testGenerics.cleanupWeaviate(client); + } + + @Test + public void testGraphQLGet() { + Supplier>> supplyPizza =() -> { + try (WeaviateAsyncClient asyncClient = client.async()) { + return asyncClient.graphQL().get() + .withClassName("Pizza") + .withFields(Field.builder().name("name").build(), Field.builder().name("description").build()) + .run(ClientGraphQLTypedTestSuite.Pizzas.class).get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + }; + ClientGraphQLTypedTestSuite.testGraphQLGet(supplyPizza); + } +} diff --git a/src/test/java/io/weaviate/integration/client/graphql/AbstractClientGraphQLTest.java b/src/test/java/io/weaviate/integration/client/graphql/AbstractClientGraphQLTest.java index acdea371..a4ffad12 100644 --- a/src/test/java/io/weaviate/integration/client/graphql/AbstractClientGraphQLTest.java +++ b/src/test/java/io/weaviate/integration/client/graphql/AbstractClientGraphQLTest.java @@ -117,7 +117,7 @@ protected void assertIds(String className, Result gqlResult, St protected List getGroups(List> result) { Serializer serializer = new Serializer(); String jsonString = serializer.toJsonString(result); - AdditionalGroupByAdditional[] response = serializer.toObject(jsonString, AdditionalGroupByAdditional[].class); + AdditionalGroupByAdditional[] response = serializer.toResponse(jsonString, AdditionalGroupByAdditional[].class); Assertions.assertThat(response).isNotNull().hasSize(3); return Arrays.stream(response).map(AdditionalGroupByAdditional::get_additional).map(Additional::getGroup).collect(Collectors.toList()); } diff --git a/src/test/java/io/weaviate/integration/client/graphql/ClientGraphQLTypedTest.java b/src/test/java/io/weaviate/integration/client/graphql/ClientGraphQLTypedTest.java new file mode 100644 index 00000000..9835f258 --- /dev/null +++ b/src/test/java/io/weaviate/integration/client/graphql/ClientGraphQLTypedTest.java @@ -0,0 +1,46 @@ +package io.weaviate.integration.client.graphql; + +import io.weaviate.client.Config; +import io.weaviate.client.WeaviateClient; +import io.weaviate.client.base.Result; +import io.weaviate.client.v1.graphql.model.GraphQLTypedResponse; +import io.weaviate.client.v1.graphql.query.fields.Field; +import io.weaviate.integration.client.WeaviateDockerCompose; +import io.weaviate.integration.client.WeaviateTestGenerics; +import io.weaviate.integration.tests.graphql.ClientGraphQLTypedTestSuite; +import java.util.function.Supplier; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Test; + +public class ClientGraphQLTypedTest { + private WeaviateClient client; + private final WeaviateTestGenerics testGenerics = new WeaviateTestGenerics(); + + @ClassRule + public static WeaviateDockerCompose compose = new WeaviateDockerCompose(); + + @Before + public void before() { + String httpHost = compose.getHttpHostAddress(); + Config config = new Config("http", httpHost); + + client = new WeaviateClient(config); + testGenerics.createTestSchemaAndData(client); + } + + @After + public void after() { + testGenerics.cleanupWeaviate(client); + } + + @Test + public void testGraphQLGet() { + Supplier>> supplyPizza = () -> client.graphQL().get() + .withClassName("Pizza") + .withFields(Field.builder().name("name").build(), Field.builder().name("description").build()) + .run(ClientGraphQLTypedTestSuite.Pizzas.class); + ClientGraphQLTypedTestSuite.testGraphQLGet(supplyPizza); + } +} diff --git a/src/test/java/io/weaviate/integration/tests/graphql/ClientGraphQLTypedTestSuite.java b/src/test/java/io/weaviate/integration/tests/graphql/ClientGraphQLTypedTestSuite.java new file mode 100644 index 00000000..d85d3838 --- /dev/null +++ b/src/test/java/io/weaviate/integration/tests/graphql/ClientGraphQLTypedTestSuite.java @@ -0,0 +1,54 @@ +package io.weaviate.integration.tests.graphql; + +import com.google.gson.annotations.SerializedName; +import io.weaviate.client.base.Result; +import io.weaviate.client.v1.graphql.model.GraphQLGetBaseObject; +import io.weaviate.client.v1.graphql.model.GraphQLTypedResponse; +import java.util.List; +import java.util.function.Supplier; +import lombok.Getter; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +public class ClientGraphQLTypedTestSuite { + + @Getter + public static class Pizzas { + @SerializedName(value = "Pizza") + List pizzas; + + @Getter + public static class Pizza extends GraphQLGetBaseObject { + String name; + String description; + String bestBefore; + Float price; + } + } + + public static void testGraphQLGet(Supplier>> supplyPizza) { + // given + Result> result = supplyPizza.get(); + // then + assertNotNull(result); + assertFalse(result.hasErrors()); + GraphQLTypedResponse gqlResult = result.getResult(); + assertNotNull(gqlResult); + assertNotNull(gqlResult.getData()); + GraphQLTypedResponse.Operation resp = gqlResult.getData(); + assertNotNull(resp.getObjects()); + assertNotNull(resp.getObjects().getPizzas()); + List pizzas = resp.getObjects().getPizzas(); + assertTrue(pizzas.size() == 4); + String name = pizzas.get(0).getName(); + assertNotNull(name); + assertTrue(name.length() > 0); + String description = pizzas.get(0).getDescription(); + assertNotNull(description); + assertTrue(description.length() > 0); + assertNull(pizzas.get(0).getPrice()); + assertNull(pizzas.get(0).getBestBefore()); + } +} diff --git a/src/test/resources/json/graphql-group-by-response.json b/src/test/resources/json/graphql-group-by-response.json new file mode 100644 index 00000000..1fd71a2d --- /dev/null +++ b/src/test/resources/json/graphql-group-by-response.json @@ -0,0 +1,303 @@ +{ + "data": { + "Get": { + "Passage": [ + { + "_additional": { + "group": { + "count": 10, + "groupedBy": { + "path": [ + "ofDocument" + ], + "value": "weaviate://localhost/Document/00000000-0000-0000-0000-00000000000a" + }, + "hits": [ + { + "_additional": { + "distance": 1.1920929e-7, + "id": "00000000-0000-0000-0000-000000000001" + }, + "name": "test-name", + "ofDocument": [ + { + "_additional": { + "id": "00000000-0000-0000-0000-00000000000a" + } + } + ] + }, + { + "_additional": { + "distance": 0.002315104, + "id": "00000000-0000-0000-0000-000000000009" + }, + "name": "name09", + "ofDocument": [ + { + "_additional": { + "id": "00000000-0000-0000-0000-00000000000a" + } + } + ] + }, + { + "_additional": { + "distance": 0.0023562908, + "id": "00000000-0000-0000-0000-000000000007" + }, + "name": "name07", + "ofDocument": [ + { + "_additional": { + "id": "00000000-0000-0000-0000-00000000000a" + } + } + ] + }, + { + "_additional": { + "distance": 0.0025094151, + "id": "00000000-0000-0000-0000-000000000008" + }, + "name": "name08", + "ofDocument": [ + { + "_additional": { + "id": "00000000-0000-0000-0000-00000000000a" + } + } + ] + }, + { + "_additional": { + "distance": 0.0027094483, + "id": "00000000-0000-0000-0000-000000000006" + }, + "name": "name06", + "ofDocument": [ + { + "_additional": { + "id": "00000000-0000-0000-0000-00000000000a" + } + } + ] + }, + { + "_additional": { + "distance": 0.0027621984, + "id": "00000000-0000-0000-0000-000000000010" + }, + "name": "name10", + "ofDocument": [ + { + "_additional": { + "id": "00000000-0000-0000-0000-00000000000a" + } + } + ] + }, + { + "_additional": { + "distance": 0.0028537512, + "id": "00000000-0000-0000-0000-000000000005" + }, + "name": "name05", + "ofDocument": [ + { + "_additional": { + "id": "00000000-0000-0000-0000-00000000000a" + } + } + ] + }, + { + "_additional": { + "distance": 0.0033442974, + "id": "00000000-0000-0000-0000-000000000004" + }, + "name": "name04", + "ofDocument": [ + { + "_additional": { + "id": "00000000-0000-0000-0000-00000000000a" + } + } + ] + }, + { + "_additional": { + "distance": 0.0041819215, + "id": "00000000-0000-0000-0000-000000000003" + }, + "name": "name03", + "ofDocument": [ + { + "_additional": { + "id": "00000000-0000-0000-0000-00000000000a" + } + } + ] + }, + { + "_additional": { + "distance": 0.005713105, + "id": "00000000-0000-0000-0000-000000000002" + }, + "name": "name02", + "ofDocument": [ + { + "_additional": { + "id": "00000000-0000-0000-0000-00000000000a" + } + } + ] + } + ], + "id": 0, + "maxDistance": 0.005713105, + "minDistance": 1.1920929e-7 + } + }, + "ofDocument": null + }, + { + "_additional": { + "group": { + "count": 4, + "groupedBy": { + "path": [ + "ofDocument" + ], + "value": "weaviate://localhost/Document/00000000-0000-0000-0000-00000000000b" + }, + "hits": [ + { + "_additional": { + "distance": 0.0025349855, + "id": "00000000-0000-0000-0000-000000000011" + }, + "name": "name11", + "ofDocument": [ + { + "_additional": { + "id": "00000000-0000-0000-0000-00000000000b" + } + } + ] + }, + { + "_additional": { + "distance": 0.0028856993, + "id": "00000000-0000-0000-0000-000000000013" + }, + "name": "name13", + "ofDocument": [ + { + "_additional": { + "id": "00000000-0000-0000-0000-00000000000b" + } + } + ] + }, + { + "_additional": { + "distance": 0.0033005476, + "id": "00000000-0000-0000-0000-000000000012" + }, + "name": "name12", + "ofDocument": [ + { + "_additional": { + "id": "00000000-0000-0000-0000-00000000000b" + } + } + ] + }, + { + "_additional": { + "distance": 0.004168868, + "id": "00000000-0000-0000-0000-000000000014" + }, + "name": "name14", + "ofDocument": [ + { + "_additional": { + "id": "00000000-0000-0000-0000-00000000000b" + } + } + ] + } + ], + "id": 1, + "maxDistance": 0.004168868, + "minDistance": 0.0025349855 + } + }, + "ofDocument": null + }, + { + "_additional": { + "group": { + "count": 6, + "groupedBy": { + "path": [ + "ofDocument" + ], + "value": "" + }, + "hits": [ + { + "_additional": { + "distance": 0.0034632683, + "id": "00000000-0000-0000-0000-000000000016" + }, + "ofDocument": null + }, + { + "_additional": { + "distance": 0.0040759444, + "id": "00000000-0000-0000-0000-000000000017" + }, + "ofDocument": null + }, + { + "_additional": { + "distance": 0.0041413307, + "id": "00000000-0000-0000-0000-000000000015" + }, + "ofDocument": null + }, + { + "_additional": { + "distance": 0.004283905, + "id": "00000000-0000-0000-0000-000000000020" + }, + "ofDocument": null + }, + { + "_additional": { + "distance": 0.0045325756, + "id": "00000000-0000-0000-0000-000000000019" + }, + "ofDocument": null + }, + { + "_additional": { + "distance": 0.0049524903, + "id": "00000000-0000-0000-0000-000000000018" + }, + "ofDocument": null + } + ], + "id": 2, + "maxDistance": 0.0049524903, + "minDistance": 0.0034632683 + } + }, + "ofDocument": null + } + ] + } + } +} diff --git a/src/test/resources/json/graphql-response.json b/src/test/resources/json/graphql-response.json new file mode 100644 index 00000000..d33c90e0 --- /dev/null +++ b/src/test/resources/json/graphql-response.json @@ -0,0 +1,14 @@ +{ + "data": { + "Get": { + "Soup": [ + { + "_additional": { + "certainty": 0.9999998211860657 + }, + "name": "JustSoup" + } + ] + } + } +}