Skip to content

Commit

Permalink
Add support for Group
Browse files Browse the repository at this point in the history
  • Loading branch information
antas-marcin committed Nov 26, 2024
1 parent 8f7ae87 commit d18a048
Show file tree
Hide file tree
Showing 5 changed files with 392 additions and 23 deletions.
8 changes: 7 additions & 1 deletion src/main/java/io/weaviate/client/base/Serializer.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.reflect.TypeToken;
import io.weaviate.client.base.util.GroupHitDeserializer;
import io.weaviate.client.v1.graphql.model.GraphQLGetBaseObject;
import io.weaviate.client.v1.graphql.model.GraphQLTypedResponse;
import java.lang.reflect.Type;

Expand All @@ -14,7 +16,11 @@ public Serializer() {
}

public <C> GraphQLTypedResponse<C> toGraphQLTypedResponse(String response, Class<C> classOfT) {
return gson.fromJson(response,
Gson gsonTyped = new GsonBuilder()
.disableHtmlEscaping()
.registerTypeAdapter(GraphQLGetBaseObject.Additional.Group.GroupHit.class, new GroupHitDeserializer())
.create();
return gsonTyped.fromJson(response,
TypeToken.getParameterized(GraphQLTypedResponse.class, classOfT).getType());
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package io.weaviate.client.base.util;

import com.google.gson.JsonDeserializationContext;
import com.google.gson.JsonDeserializer;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParseException;
import com.google.gson.reflect.TypeToken;
import io.weaviate.client.v1.graphql.model.GraphQLGetBaseObject;
import java.lang.reflect.Type;
import java.util.Map;

public class GroupHitDeserializer implements JsonDeserializer<GraphQLGetBaseObject.Additional.Group.GroupHit> {

@Override
public GraphQLGetBaseObject.Additional.Group.GroupHit deserialize(JsonElement json, Type typeOfT,
JsonDeserializationContext context) throws JsonParseException {
JsonObject jsonObject = json.getAsJsonObject();

GraphQLGetBaseObject.Additional.Group.GroupHit.AdditionalGroupHit additional =
context.deserialize(jsonObject.get("_additional"), GraphQLGetBaseObject.Additional.Group.GroupHit.AdditionalGroupHit.class);

// Remove _additional from the JSON object
jsonObject.remove("_additional");

// Deserialize the rest into a Map
Type mapType = new TypeToken<Map<String, Object>>() {}.getType();
Map<String, Object> properties = context.deserialize(jsonObject, mapType);

return new GraphQLGetBaseObject.Additional.Group.GroupHit(properties, additional);
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package io.weaviate.client.v1.graphql.model;

import com.google.gson.annotations.SerializedName;
import java.util.List;
import java.util.Map;
import lombok.AllArgsConstructor;
import lombok.Getter;

@Getter
Expand All @@ -21,17 +23,50 @@ public static class Additional {
Float[] vector;
Map<String, Float[]> vectors;
Generate generate;
Group group;

@Getter
public static class Generate {
String singleResult;
String groupedResult;
String error;
Debug debug;

@Getter
public static class Debug {
String prompt;
}
}

@Getter
public static class Group {
public String id;
public GroupedBy groupedBy;
public Integer count;
public Float maxDistance;
public Float minDistance;
public List<GroupHit> hits;

@Getter
public static class GroupedBy {
public String value;
public String[] path;
}

@Getter
@AllArgsConstructor
public static class GroupHit {
@SerializedName("properties")
Map<String, Object> properties;
@SerializedName(value = "_additional")
AdditionalGroupHit additional;

@Getter
public static class AdditionalGroupHit {
String id;
Float distance;
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,19 @@ public void testGraphQLGetResponseSoups() throws IOException {
public void testGraphQLGetResponseSoups2() throws IOException {
// given
Serializer s = new Serializer();
String json = new String(Files.readAllBytes(Paths.get("src/test/resources/json/graphql-response.json")));
String json = new String(Files.readAllBytes(Paths.get("src/test/resources/json/graphql-group-by-response.json")));
// when
GraphQLTypedResponse<Soups2> resp = s.toGraphQLTypedResponse(json, Soups2.class);
GraphQLTypedResponse<Passages> resp = s.toGraphQLTypedResponse(json, Passages.class);
// then
assertThat(resp).isNotNull()
.extracting(o -> o.getData().getObjects().getSoups())
.extracting(o -> o.getData().getObjects().getPassages())
.extracting(o -> o.get(0)).isNotNull()
.extracting(Soups2.Soup::getName).isEqualTo("JustSoup");
.extracting(GraphQLGetBaseObject::getAdditional).isNotNull()
.extracting(GraphQLGetBaseObject.Additional::getGroup).isNotNull()
.extracting(GraphQLGetBaseObject.Additional.Group::getHits).isNotNull()
.extracting(o -> o.get(0)).isNotNull()
.extracting(GraphQLGetBaseObject.Additional.Group.GroupHit::getProperties).isNotNull()
.extracting(o -> o.get("name")).isEqualTo("test-name");
}
}

Expand All @@ -69,25 +74,13 @@ public static class Soup extends GraphQLGetBaseObject {
}
}

@Getter
class Passages {
@SerializedName(value = "Passage")
List<Passage> passages;

class Soups2 {
@SerializedName(value = "Soup")
List<Soup> soups;

public List<Soup> getSoups() {
return soups;
}

public static class Soup extends GraphQLGetBaseObject {
@Getter
public static class Passage extends GraphQLGetBaseObject {
String name;
Float price;

public String getName() {
return name;
}

public Float getPrice() {
return price;
}
}
}
Loading

0 comments on commit d18a048

Please sign in to comment.