Skip to content

Commit

Permalink
Merge pull request #330 from weaviate/support-for-generic-response
Browse files Browse the repository at this point in the history
Support for GraphQL response with custom generic classes
  • Loading branch information
antas-marcin authored Nov 26, 2024
2 parents 7d41375 + d18a048 commit e536dc8
Show file tree
Hide file tree
Showing 21 changed files with 1,008 additions and 40 deletions.
8 changes: 6 additions & 2 deletions src/main/java/io/weaviate/client/base/AsyncBaseClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import org.apache.hc.core5.http.HttpHeaders;

public abstract class AsyncBaseClient<T> {
private final CloseableHttpAsyncClient client;
protected final CloseableHttpAsyncClient client;
private final Config config;
private final Serializer serializer;

Expand Down Expand Up @@ -71,12 +71,16 @@ protected Future<Result<T>> sendHeadRequest(String endpoint, FutureCallback<Resu
}

private Future<Result<T>> sendRequest(String endpoint, Object payload, String method, Class<T> classOfT, FutureCallback<Result<T>> callback, ResponseParser<T> parser) {
return client.execute(SimpleRequestProducer.create(getRequest(endpoint, payload, method)), new WeaviateResponseConsumer<>(classOfT, parser), callback);
}

protected SimpleHttpRequest getRequest(String endpoint, Object payload, String method) {
SimpleHttpRequest req = new SimpleHttpRequest(method, String.format("%s%s", config.getBaseURL(), endpoint));
req.addHeader(HttpHeaders.ACCEPT, "*/*");
req.addHeader(HttpHeaders.CONTENT_TYPE, "application/json");
if (payload != null) {
req.setBody(serializer.toJsonString(payload), ContentType.APPLICATION_JSON);
}
return client.execute(SimpleRequestProducer.create(req), new WeaviateResponseConsumer<>(classOfT, parser), callback);
return req;
}
}
20 changes: 20 additions & 0 deletions src/main/java/io/weaviate/client/base/AsyncBaseGraphQLClient.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package io.weaviate.client.base;

import io.weaviate.client.Config;
import io.weaviate.client.base.http.async.WeaviateGraphQLTypedResponseConsumer;
import io.weaviate.client.v1.graphql.model.GraphQLTypedResponse;
import java.util.concurrent.Future;
import org.apache.hc.client5.http.async.methods.SimpleRequestProducer;
import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient;
import org.apache.hc.core5.concurrent.FutureCallback;

public class AsyncBaseGraphQLClient<T> extends AsyncBaseClient<T> {
public AsyncBaseGraphQLClient(CloseableHttpAsyncClient client, Config config) {
super(client, config);
}

protected <C> Future<Result<GraphQLTypedResponse<C>>> sendGraphQLTypedRequest(Object payload, Class<C> classOfC,
FutureCallback<Result<GraphQLTypedResponse<C>>> callback) {
return client.execute(SimpleRequestProducer.create(getRequest("/graphql", payload, "POST")), new WeaviateGraphQLTypedResponseConsumer<>(classOfC), callback);
}
}
16 changes: 8 additions & 8 deletions src/main/java/io/weaviate/client/base/BaseClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
public abstract class BaseClient<T> {
private final HttpClient client;
private final Config config;
private final Serializer serializer;
protected final Serializer serializer;

public BaseClient(HttpClient client, Config config) {
this.config = config;
Expand Down Expand Up @@ -42,9 +42,7 @@ protected Response<T> sendHeadRequest(String endpoint, Class<T> classOfT) {

private Response<T> sendRequest(String endpoint, Object payload, String method, Class<T> classOfT) {
try {
String url = config.getBaseURL() + endpoint;
String json = toJsonString(payload);
HttpResponse response = this.sendHttpRequest(url, json, method);
HttpResponse response = this.sendHttpRequest(endpoint, payload, method);
int statusCode = response.getStatusCode();
String responseBody = response.getBody();

Expand All @@ -61,7 +59,9 @@ private Response<T> sendRequest(String endpoint, Object payload, String method,
}
}

private HttpResponse sendHttpRequest(String address, String json, String method) throws Exception {
protected HttpResponse sendHttpRequest(String endpoint, Object payload, String method) throws Exception {
String address = config.getBaseURL() + endpoint;
String json = toJsonString(payload);
if (method.equals("POST")) {
return client.sendPostRequest(address, json);
}
Expand All @@ -80,15 +80,15 @@ private HttpResponse sendHttpRequest(String address, String json, String method)
return client.sendGetRequest(address);
}

private <C> C toResponse(String response, Class<C> classOfT) {
return serializer.toObject(response, classOfT);
protected <C> C toResponse(String response, Class<C> classOfT) {
return serializer.toResponse(response, classOfT);
}

private String toJsonString(Object object) {
return serializer.toJsonString(object);
}

private WeaviateErrorResponse getWeaviateErrorResponse(Exception e) {
protected WeaviateErrorResponse getWeaviateErrorResponse(Exception e) {
WeaviateErrorMessage error = WeaviateErrorMessage.builder().message(e.getMessage()).throwable(e).build();
return WeaviateErrorResponse.builder().error(Collections.singletonList(error)).build();
}
Expand Down
35 changes: 35 additions & 0 deletions src/main/java/io/weaviate/client/base/BaseGraphQLClient.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package io.weaviate.client.base;

import io.weaviate.client.Config;
import io.weaviate.client.base.http.HttpClient;
import io.weaviate.client.base.http.HttpResponse;
import io.weaviate.client.v1.graphql.model.GraphQLTypedResponse;

public abstract class BaseGraphQLClient<T> extends BaseClient<T> {
public BaseGraphQLClient(HttpClient client, Config config) {
super(client, config);
}

private <C> GraphQLTypedResponse<C> toResponseTyped(String response, Class<C> classOfC) {
return serializer.toGraphQLTypedResponse(response, classOfC);
}

protected <C> Response<GraphQLTypedResponse<C>> sendGraphQLTypedRequest(Object payload, Class<C> classOfC) {
try {
HttpResponse response = this.sendHttpRequest("/graphql", payload, "POST");
int statusCode = response.getStatusCode();
String responseBody = response.getBody();

if (statusCode < 399) {
GraphQLTypedResponse<C> body = toResponseTyped(responseBody, classOfC);
return new Response<>(statusCode, body, null);
}

WeaviateErrorResponse error = toResponse(responseBody, WeaviateErrorResponse.class);
return new Response<>(statusCode, null, error);
} catch (Exception e) {
WeaviateErrorResponse errors = getWeaviateErrorResponse(e);
return new Response<>(0, null, errors);
}
}
}
39 changes: 36 additions & 3 deletions src/main/java/io/weaviate/client/base/Serializer.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.reflect.TypeToken;
import io.weaviate.client.base.util.GroupHitDeserializer;
import io.weaviate.client.v1.graphql.model.GraphQLGetBaseObject;
import io.weaviate.client.v1.graphql.model.GraphQLTypedResponse;
import java.lang.reflect.Type;

public class Serializer {
private Gson gson;
Expand All @@ -10,7 +15,20 @@ public Serializer() {
this.gson = new GsonBuilder().disableHtmlEscaping().create();
}

public <T> T toObject(String response, Class<T> classOfT) {
public <C> GraphQLTypedResponse<C> toGraphQLTypedResponse(String response, Class<C> classOfT) {
Gson gsonTyped = new GsonBuilder()
.disableHtmlEscaping()
.registerTypeAdapter(GraphQLGetBaseObject.Additional.Group.GroupHit.class, new GroupHitDeserializer())
.create();
return gsonTyped.fromJson(response,
TypeToken.getParameterized(GraphQLTypedResponse.class, classOfT).getType());
}

public <C> C toResponse(String response, Type typeOfT) {
return gson.fromJson(response, typeOfT);
}

public <T> T toResponse(String response, Class<T> classOfT) {
return gson.fromJson(response, classOfT);
}

Expand All @@ -27,13 +45,28 @@ public <T> Result<T> toResult(int statusCode, String body, Class<T> classOfT) {

public <T> Response<T> toResponse(int statusCode, String body, Class<T> classOfT) {
if (statusCode < 399) {
T obj = toObject(body, classOfT);
T obj = toResponse(body, classOfT);
return new Response<>(statusCode, obj, null);
}
return new Response<>(statusCode, null, toWeaviateError(body));
}

public <C> Response<GraphQLTypedResponse<C>> toGraphQLTypedResponse(int statusCode, String body, Class<C> classOfC) {
if (statusCode < 399) {
GraphQLTypedResponse<C> obj = toGraphQLTypedResponse(body, classOfC);
return new Response<>(statusCode, obj, null);
}
return new Response<>(statusCode, null, toWeaviateError(body));
}

public <C> Result<GraphQLTypedResponse<C>> toGraphQLTypedResult(int statusCode, String body, Class<C> classOfC) {
if (statusCode < 399) {
return new Result<>(toGraphQLTypedResponse(statusCode, body, classOfC));
}
return new Result<>(statusCode, null, toWeaviateError(body));
}

public WeaviateErrorResponse toWeaviateError(String body) {
return toObject(body, WeaviateErrorResponse.class);
return toResponse(body, WeaviateErrorResponse.class);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package io.weaviate.client.base.http.async;

import io.weaviate.client.base.Result;
import io.weaviate.client.base.Serializer;
import io.weaviate.client.v1.graphql.model.GraphQLTypedResponse;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import org.apache.hc.core5.http.ContentType;
import org.apache.hc.core5.http.HttpException;
import org.apache.hc.core5.http.HttpResponse;
import org.apache.hc.core5.http.nio.entity.BasicAsyncEntityConsumer;
import org.apache.hc.core5.http.nio.support.AbstractAsyncResponseConsumer;
import org.apache.hc.core5.http.protocol.HttpContext;

public class WeaviateGraphQLTypedResponseConsumer<C> extends AbstractAsyncResponseConsumer<Result<GraphQLTypedResponse<C>>, byte[]> {
private final Serializer serializer;
private final Class<C> classOfT;

public WeaviateGraphQLTypedResponseConsumer(Class<C> classOfT) {
super(new BasicAsyncEntityConsumer());
this.serializer = new Serializer();
this.classOfT = classOfT;
}

@Override
protected Result<GraphQLTypedResponse<C>> buildResult(HttpResponse response, byte[] entity, ContentType contentType) {
String body = (entity != null) ? new String(entity, StandardCharsets.UTF_8) : "";
return serializer.toGraphQLTypedResult(response.getCode(), body, classOfT);
}

@Override
public void informationResponse(HttpResponse response, HttpContext context) throws HttpException, IOException {
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package io.weaviate.client.base.util;

import com.google.gson.JsonDeserializationContext;
import com.google.gson.JsonDeserializer;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParseException;
import com.google.gson.reflect.TypeToken;
import io.weaviate.client.v1.graphql.model.GraphQLGetBaseObject;
import java.lang.reflect.Type;
import java.util.Map;

public class GroupHitDeserializer implements JsonDeserializer<GraphQLGetBaseObject.Additional.Group.GroupHit> {

@Override
public GraphQLGetBaseObject.Additional.Group.GroupHit deserialize(JsonElement json, Type typeOfT,
JsonDeserializationContext context) throws JsonParseException {
JsonObject jsonObject = json.getAsJsonObject();

GraphQLGetBaseObject.Additional.Group.GroupHit.AdditionalGroupHit additional =
context.deserialize(jsonObject.get("_additional"), GraphQLGetBaseObject.Additional.Group.GroupHit.AdditionalGroupHit.class);

// Remove _additional from the JSON object
jsonObject.remove("_additional");

// Deserialize the rest into a Map
Type mapType = new TypeToken<Map<String, Object>>() {}.getType();
Map<String, Object> properties = context.deserialize(jsonObject, mapType);

return new GraphQLGetBaseObject.Additional.Group.GroupHit(properties, additional);
}
}
Loading

0 comments on commit e536dc8

Please sign in to comment.