Skip to content

Commit

Permalink
Merge pull request #235 from weaviate/grpc-cross-refs-support
Browse files Browse the repository at this point in the history
Add support for cross references in gRPC Batch API
  • Loading branch information
antas-marcin authored Nov 14, 2023
2 parents d202a4a + 369c9e6 commit 1ee55d1
Show file tree
Hide file tree
Showing 13 changed files with 460 additions and 59 deletions.
38 changes: 38 additions & 0 deletions src/main/java/io/weaviate/client/base/util/CrossReference.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package io.weaviate.client.base.util;

import lombok.AccessLevel;
import lombok.Getter;
import lombok.ToString;
import lombok.experimental.FieldDefaults;
import org.apache.commons.lang3.StringUtils;

@ToString
@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public class CrossReference {
String peerName;
String className;
String targetID;
boolean local;

public CrossReference(String peerName, String className, String targetID) {
this.local = peerName != null && peerName.equals("localhost");
this.peerName = peerName;
this.className = className;
this.targetID = targetID;
}

public static CrossReference fromBeacon(String beacon) {
if (StringUtils.isNotBlank(beacon) && beacon.startsWith("weaviate://")) {
String path = beacon.replaceFirst("weaviate://", "");
String[] parts = path.split("/");
if (parts.length == 3) {
return new CrossReference(parts[0], parts[1], parts[2]);
}
if (parts.length == 2) {
return new CrossReference(parts[0], "", parts[1]);
}
}
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import com.google.protobuf.Struct;
import com.google.protobuf.Value;
import io.weaviate.client.v1.data.model.WeaviateObject;
import io.weaviate.client.base.util.CrossReference;
import io.weaviate.client.grpc.protocol.v1.WeaviateProtoBase;
import io.weaviate.client.grpc.protocol.v1.WeaviateProtoBatch;
import io.weaviate.client.v1.data.model.WeaviateObject;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
Expand All @@ -13,8 +14,6 @@
import java.util.stream.Collectors;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Setter;
import lombok.ToString;
import lombok.experimental.FieldDefaults;

Expand Down Expand Up @@ -51,31 +50,37 @@ private static class Properties {
List<WeaviateProtoBase.BooleanArrayProperties> booleanArrayProperties;
List<WeaviateProtoBase.ObjectProperties> objectProperties;
List<WeaviateProtoBase.ObjectArrayProperties> objectArrayProperties;
List<WeaviateProtoBatch.BatchObject.SingleTargetRefProps> singleTargetRefProps;
List<WeaviateProtoBatch.BatchObject.MultiTargetRefProps> multiTargetRefProps;
}

private static WeaviateProtoBatch.BatchObject.Properties buildProperties(Map<String, Object> properties) {
WeaviateProtoBatch.BatchObject.Properties.Builder builder = WeaviateProtoBatch.BatchObject.Properties.newBuilder();

Properties props = extractProperties(properties);
Properties props = extractProperties(properties, true);
builder.setNonRefProperties(Struct.newBuilder().putAllFields(props.nonRefProperties).build());
props.numberArrayProperties.forEach(builder::addNumberArrayProperties);
props.intArrayProperties.forEach(builder::addIntArrayProperties);
props.textArrayProperties.forEach(builder::addTextArrayProperties);
props.booleanArrayProperties.forEach(builder::addBooleanArrayProperties);
props.objectProperties.forEach(builder::addObjectProperties);
props.objectArrayProperties.forEach(builder::addObjectArrayProperties);
props.singleTargetRefProps.forEach(builder::addSingleTargetRefProps);
props.multiTargetRefProps.forEach(builder::addMultiTargetRefProps);

return builder.build();
}

private static Properties extractProperties(Map<String, Object> properties) {
private static Properties extractProperties(Map<String, Object> properties, boolean rootLevel) {
Map<String, Value> nonRefProperties = new HashMap<>();
List<WeaviateProtoBase.NumberArrayProperties> numberArrayProperties = new ArrayList<>();
List<WeaviateProtoBase.IntArrayProperties> intArrayProperties = new ArrayList<>();
List<WeaviateProtoBase.TextArrayProperties> textArrayProperties = new ArrayList<>();
List<WeaviateProtoBase.BooleanArrayProperties> booleanArrayProperties = new ArrayList<>();
List<WeaviateProtoBase.ObjectProperties> objectProperties = new ArrayList<>();
List<WeaviateProtoBase.ObjectArrayProperties> objectArrayProperties = new ArrayList<>();
List<WeaviateProtoBatch.BatchObject.SingleTargetRefProps> singleTargetRefProps = new ArrayList<>();
List<WeaviateProtoBatch.BatchObject.MultiTargetRefProps> multiTargetRefProps = new ArrayList<>();
// extract properties
for (Map.Entry<String, Object> e : properties.entrySet()) {
String propName = e.getKey();
Expand Down Expand Up @@ -146,7 +151,7 @@ private static Properties extractProperties(Map<String, Object> properties) {
continue;
}
if (propValue instanceof Map) {
Properties extractedProperties = extractProperties((Map<String, Object>) propValue);
Properties extractedProperties = extractProperties((Map<String, Object>) propValue, false);
WeaviateProtoBase.ObjectPropertiesValue.Builder objectPropertiesValue = WeaviateProtoBase.ObjectPropertiesValue.newBuilder();
objectPropertiesValue.setNonRefProperties(Struct.newBuilder().putAllFields(extractedProperties.nonRefProperties).build());
extractedProperties.numberArrayProperties.forEach(objectPropertiesValue::addNumberArrayProperties);
Expand All @@ -163,30 +168,90 @@ private static Properties extractProperties(Map<String, Object> properties) {
continue;
}
if (propValue instanceof List) {
List<WeaviateProtoBase.ObjectPropertiesValue> objectPropertiesValues = new ArrayList<>();
for (Object propValueObject : (List) propValue) {
if (propValueObject instanceof Map) {
Properties extractedProperties = extractProperties((Map<String, Object>) propValueObject);
WeaviateProtoBase.ObjectPropertiesValue.Builder objectPropertiesValue = WeaviateProtoBase.ObjectPropertiesValue.newBuilder();
objectPropertiesValue.setNonRefProperties(Struct.newBuilder().putAllFields(extractedProperties.nonRefProperties).build());
extractedProperties.numberArrayProperties.forEach(objectPropertiesValue::addNumberArrayProperties);
extractedProperties.intArrayProperties.forEach(objectPropertiesValue::addIntArrayProperties);
extractedProperties.textArrayProperties.forEach(objectPropertiesValue::addTextArrayProperties);
extractedProperties.booleanArrayProperties.forEach(objectPropertiesValue::addBooleanArrayProperties);
extractedProperties.objectProperties.forEach(objectPropertiesValue::addObjectProperties);
extractedProperties.objectArrayProperties.forEach(objectPropertiesValue::addObjectArrayProperties);

objectPropertiesValues.add(objectPropertiesValue.build());
if (isCrossReference((List<?>) propValue, rootLevel)) {
// it's a cross reference
List<String> beacons = extractBeacons((List<?>) propValue);
List<CrossReference> crossReferences = beacons.stream()
.map(CrossReference::fromBeacon)
.collect(Collectors.toList());

Map<String, List<String>> crefs = new HashMap<>();
for (CrossReference cref : crossReferences) {
List<String> uuids = crefs.get(cref.getClassName());
if (uuids == null) {
uuids = new ArrayList<>();
}
uuids.add(cref.getTargetID());
crefs.put(cref.getClassName(), uuids);
}
}

WeaviateProtoBase.ObjectArrayProperties objectArrayProps = WeaviateProtoBase.ObjectArrayProperties.newBuilder()
.setPropName(propName).addAllValues(objectPropertiesValues).build();
if (crefs.size() == 1) {
for (Map.Entry<String, List<String>> crefEntry : crefs.entrySet()) {
WeaviateProtoBatch.BatchObject.SingleTargetRefProps singleTargetCrossRefs = WeaviateProtoBatch.BatchObject.SingleTargetRefProps.newBuilder()
.setPropName(propName).addAllUuids(crefEntry.getValue()).build();
singleTargetRefProps.add(singleTargetCrossRefs);
}
}
if (crefs.size() > 1) {
for (Map.Entry<String, List<String>> crefEntry : crefs.entrySet()) {
WeaviateProtoBatch.BatchObject.MultiTargetRefProps multiTargetCrossRefs = WeaviateProtoBatch.BatchObject.MultiTargetRefProps.newBuilder()
.setPropName(propName).addAllUuids(crefEntry.getValue()).setTargetCollection(crefEntry.getKey()).build();
multiTargetRefProps.add(multiTargetCrossRefs);
}
}
} else {
// it's an object
List<WeaviateProtoBase.ObjectPropertiesValue> objectPropertiesValues = new ArrayList<>();
for (Object propValueObject : (List) propValue) {
if (propValueObject instanceof Map) {
Properties extractedProperties = extractProperties((Map<String, Object>) propValueObject, false);
WeaviateProtoBase.ObjectPropertiesValue.Builder objectPropertiesValue = WeaviateProtoBase.ObjectPropertiesValue.newBuilder();
objectPropertiesValue.setNonRefProperties(Struct.newBuilder().putAllFields(extractedProperties.nonRefProperties).build());
extractedProperties.numberArrayProperties.forEach(objectPropertiesValue::addNumberArrayProperties);
extractedProperties.intArrayProperties.forEach(objectPropertiesValue::addIntArrayProperties);
extractedProperties.textArrayProperties.forEach(objectPropertiesValue::addTextArrayProperties);
extractedProperties.booleanArrayProperties.forEach(objectPropertiesValue::addBooleanArrayProperties);
extractedProperties.objectProperties.forEach(objectPropertiesValue::addObjectProperties);
extractedProperties.objectArrayProperties.forEach(objectPropertiesValue::addObjectArrayProperties);

objectPropertiesValues.add(objectPropertiesValue.build());
}
}

objectArrayProperties.add(objectArrayProps);
WeaviateProtoBase.ObjectArrayProperties objectArrayProps = WeaviateProtoBase.ObjectArrayProperties.newBuilder()
.setPropName(propName).addAllValues(objectPropertiesValues).build();

objectArrayProperties.add(objectArrayProps);
}
}
}
return new Properties(nonRefProperties, numberArrayProperties, intArrayProperties, textArrayProperties,
booleanArrayProperties, objectProperties, objectArrayProperties);
booleanArrayProperties, objectProperties, objectArrayProperties, singleTargetRefProps, multiTargetRefProps);
}

private static boolean isCrossReference(List<?> propValue, boolean rootLevel) {
if (rootLevel) {
for (Object element : propValue) {
if (element instanceof Map) {
Map<?, ?> valueMap = ((Map<?, ?>) element);
if (valueMap.size() > 1 || (valueMap.size() == 1 && (valueMap.get("beacon") == null || !(valueMap.get("beacon") instanceof String)))) {
return false;
}
}
}
return true;
}
return false;
}

private static List<String> extractBeacons(List<?> propValue) {
List<String> beacons = new ArrayList<>();
for (Object element : propValue) {
if (element instanceof Map) {
Map<?, ?> valueMap = ((Map<?, ?>) element);
beacons.add((String) valueMap.get("beacon"));
}
}
return beacons;
}
}
47 changes: 47 additions & 0 deletions src/test/java/io/weaviate/client/base/util/CrossReferenceTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package io.weaviate.client.base.util;

import org.junit.Test;
import static org.assertj.core.api.Assertions.assertThat;

public class CrossReferenceTest {

@Test
public void testParseBeaconWithClass() {
// given
String beacon = "weaviate://localhost/RefClass/f81bfe5e-16ba-4615-a516-46c2ae2e5a80";
// when
CrossReference crossRef = CrossReference.fromBeacon(beacon);
// then
assertThat(crossRef).isNotNull().satisfies(cf -> {
assertThat(cf.isLocal()).isTrue();
assertThat(cf.getPeerName()).isEqualTo("localhost");
assertThat(cf.getClassName()).isEqualTo("RefClass");
assertThat(cf.getTargetID()).isEqualTo("f81bfe5e-16ba-4615-a516-46c2ae2e5a80");
});
}

@Test
public void testParseBeaconWithoutClass() {
// given
String beacon = "weaviate://localhost/f81bfe5e-16ba-4615-a516-46c2ae2e5a80";
// when
CrossReference crossRef = CrossReference.fromBeacon(beacon);
// then
assertThat(crossRef).isNotNull().satisfies(cf -> {
assertThat(cf.isLocal()).isTrue();
assertThat(cf.getPeerName()).isEqualTo("localhost");
assertThat(cf.getClassName()).isEqualTo("");
assertThat(cf.getTargetID()).isEqualTo("f81bfe5e-16ba-4615-a516-46c2ae2e5a80");
});
}

@Test
public void testParseBeaconEmpty() {
// given
String beacon = "";
// when
CrossReference crossRef = CrossReference.fromBeacon(beacon);
// then
assertThat(crossRef).isNull();
}
}
Loading

0 comments on commit 1ee55d1

Please sign in to comment.