Skip to content

Commit

Permalink
Merge pull request #264 from weaviate/add-e2e-test-for-named-vectors
Browse files Browse the repository at this point in the history
Add test for multiple vectors class
  • Loading branch information
antas-marcin authored Mar 5, 2024
2 parents 7e7fea6 + 1765604 commit 69d0366
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public void testAuthAzure() throws AuthException {
Result<Meta> meta = client.misc().metaGetter().run();
assertNotNull(meta);
assertNull(meta.getError());
assertEquals("http://[::]:8081", meta.getResult().getHostname());
assertEquals("http://[::]:8080", meta.getResult().getHostname());
assertEquals(EXPECTED_WEAVIATE_VERSION, meta.getResult().getVersion());
} else {
System.out.println("Skipping Azure Client Credentials test, missing AZURE_CLIENT_SECRET");
Expand All @@ -59,7 +59,7 @@ public void testAuthAzureHardcodedScope() throws AuthException {
Result<Meta> meta = client.misc().metaGetter().run();
assertNotNull(meta);
assertNull(meta.getError());
assertEquals("http://[::]:8081", meta.getResult().getHostname());
assertEquals("http://[::]:8080", meta.getResult().getHostname());
assertEquals(EXPECTED_WEAVIATE_VERSION, meta.getResult().getVersion());
} else {
System.out.println("Skipping Azure Client Credentials test, missing AZURE_CLIENT_SECRET");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public void testAuthOkta() throws AuthException {
Result<Meta> meta = client.misc().metaGetter().run();
assertNotNull(meta);
assertNull(meta.getError());
assertEquals("http://[::]:8082", meta.getResult().getHostname());
assertEquals("http://[::]:8080", meta.getResult().getHostname());
assertEquals(EXPECTED_WEAVIATE_VERSION, meta.getResult().getVersion());
} else {
System.out.println("Skipping Okta Client Credentials test, missing OKTA_CLIENT_SECRET");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public void testAuthOktaNoScope() throws AuthException {
Result<Meta> meta = client.misc().metaGetter().run();
assertNotNull(meta);
assertNull(meta.getError());
assertEquals("http://[::]:8083", meta.getResult().getHostname());
assertEquals("http://[::]:8080", meta.getResult().getHostname());
assertEquals(EXPECTED_WEAVIATE_VERSION, meta.getResult().getVersion());
} else {
System.out.println("Skipping Okta test, missing OKTA_DUMMY_CI_PW");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public void testAuthWCS() throws AuthException {
Result<Meta> meta = client.misc().metaGetter().run();
assertNotNull(meta);
assertNull(meta.getError());
assertEquals("http://[::]:8085", meta.getResult().getHostname());
assertEquals("http://[::]:8080", meta.getResult().getHostname());
assertEquals(EXPECTED_WEAVIATE_VERSION, meta.getResult().getVersion());
} else {
System.out.println("Skipping WCS test, missing WCS_DUMMY_CI_PW");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ protected AccessTokenProvider getTokenProvider(Config config, BaseAuth.AuthRespo
Result<Meta> meta = client.misc().metaGetter().run();
assertThat(meta).isNotNull();
assertThat(meta.getError()).isNull();
assertThat(meta.getResult().getHostname()).isEqualTo("http://[::]:8082");
assertThat(meta.getResult().getHostname()).isEqualTo("http://[::]:8080");
assertThat(meta.getResult().getVersion()).isEqualTo(WeaviateVersion.EXPECTED_WEAVIATE_VERSION);
Thread.sleep(3000l);
// get the access token after refresh
Expand All @@ -70,7 +70,7 @@ protected AccessTokenProvider getTokenProvider(Config config, BaseAuth.AuthRespo
meta = client.misc().metaGetter().run();
assertThat(meta).isNotNull();
assertThat(meta.getError()).isNull();
assertThat(meta.getResult().getHostname()).isEqualTo("http://[::]:8082");
assertThat(meta.getResult().getHostname()).isEqualTo("http://[::]:8080");
assertThat(meta.getResult().getVersion()).isEqualTo(WeaviateVersion.EXPECTED_WEAVIATE_VERSION);
} else {
System.out.println("Skipping Okta Client Credentials refresh token test, missing OKTA_CLIENT_SECRET");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ protected AccessTokenProvider getTokenProvider(Config config, BaseAuth.AuthRespo
Result<Meta> meta = client.misc().metaGetter().run();
assertThat(meta).isNotNull();
assertThat(meta.getError()).isNull();
assertThat(meta.getResult().getHostname()).isEqualTo("http://[::]:8085");
assertThat(meta.getResult().getHostname()).isEqualTo("http://[::]:8080");
assertThat(meta.getResult().getVersion()).isEqualTo(WeaviateVersion.EXPECTED_WEAVIATE_VERSION);
Thread.sleep(3000l);
// get the access token after refresh
Expand All @@ -70,7 +70,7 @@ protected AccessTokenProvider getTokenProvider(Config config, BaseAuth.AuthRespo
meta = client.misc().metaGetter().run();
assertThat(meta).isNotNull();
assertThat(meta.getError()).isNull();
assertThat(meta.getResult().getHostname()).isEqualTo("http://[::]:8085");
assertThat(meta.getResult().getHostname()).isEqualTo("http://[::]:8080");
assertThat(meta.getResult().getVersion()).isEqualTo(WeaviateVersion.EXPECTED_WEAVIATE_VERSION);
} else {
System.out.println("Skipping WCS Refresh Token test, missing WCS_DUMMY_CI_PW");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
package io.weaviate.integration.client.batch;

import io.weaviate.client.Config;
import io.weaviate.client.WeaviateClient;
import io.weaviate.client.base.Result;
import io.weaviate.client.v1.batch.model.ObjectGetResponse;
import io.weaviate.client.v1.data.model.WeaviateObject;
import io.weaviate.client.v1.misc.model.BQConfig;
import io.weaviate.client.v1.misc.model.VectorIndexConfig;
import io.weaviate.client.v1.schema.model.DataType;
import io.weaviate.client.v1.schema.model.Property;
import io.weaviate.client.v1.schema.model.WeaviateClass;
import java.io.File;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.InstanceOfAssertFactories.ARRAY;
import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Test;
import org.testcontainers.containers.ComposeContainer;
import org.testcontainers.containers.wait.strategy.Wait;

public class ClientBatchGrpcCreateNamedVectorsTest {
private static String host;
private static Integer port;
private static String grpcHost;
private static Integer grpcPort;

@ClassRule
public static ComposeContainer compose = new ComposeContainer(
new File("src/test/resources/docker-compose-test.yaml")
).withExposedService("weaviate-1", 8080, Wait.forListeningPorts(8080))
.withExposedService("weaviate-1", 50051, Wait.forListeningPorts(50051))
.withTailChildContainers(true);

@Before
public void before() {
host = compose.getServiceHost("weaviate-1", 8080);
port = compose.getServicePort("weaviate-1", 8080);
grpcHost = compose.getServiceHost("weaviate-1", 50051);
grpcPort = compose.getServicePort("weaviate-1", 50051);
}

@Test
public void shouldCreateObjectsWithNamedVectors() {
WeaviateClient client = createClient();
String className = "NamedVectors";
List<Property> properties = Arrays.asList(
Property.builder()
.name("name")
.dataType(Collections.singletonList(DataType.TEXT))
.build(),
Property.builder()
.name("title")
.dataType(Collections.singletonList(DataType.TEXT))
.build());
Map<String, Object> none = new HashMap<>();
none.put("none", new Object());
Map<String, Object> text2vecContextionary = new HashMap<>();
Map<String, Object> text2vecContextionarySettings = new HashMap<>();
text2vecContextionarySettings.put("vectorizeClassName", false);
text2vecContextionarySettings.put("properties", new String[]{"title"});
text2vecContextionary.put("text2vec-contextionary", text2vecContextionarySettings);
Map<String, WeaviateClass.VectorConfig> vectorConfig = new HashMap<>();
vectorConfig.put("hnswVector", WeaviateClass.VectorConfig.builder()
.vectorIndexType("hnsw")
.vectorizer(none)
.build());
vectorConfig.put("c11y", WeaviateClass.VectorConfig.builder()
.vectorIndexType("flat")
.vectorizer(text2vecContextionary)
.vectorIndexConfig(VectorIndexConfig.builder()
.bq(BQConfig.builder().enabled(true).build())
.build())
.build());
Result<Boolean> createResult = client.schema().classCreator()
.withClass(WeaviateClass.builder()
.className(className)
.properties(properties)
.vectorConfig(vectorConfig)
.build()
)
.run();
assertThat(createResult).isNotNull()
.returns(false, Result::hasErrors)
.returns(true, Result::getResult);

// create object
String id = "00000000-0000-0000-0000-000000000001";
Map<String, Object> props = new HashMap<>();
props.put("name", "some name");
props.put("title", "The Lord of the Rings");
Float[] vector = new Float[]{0.11f, 0.22f, 0.33f, 0.123f, -0.900009f, -0.0000000001f};
Map<String, Float[]> vectors = new HashMap<>();
vectors.put("hnswVector", vector);
WeaviateObject obj = WeaviateObject.builder()
.id(id)
.className(className)
.properties(props)
.vectors(vectors)
.build();
Result<ObjectGetResponse[]> result = client.batch().objectsBatcher()
.withObjects(obj)
.run();
assertThat(result).isNotNull()
.returns(false, Result::hasErrors)
.extracting(Result::getResult).asInstanceOf(ARRAY)
.hasSize(1);

// fetch that object
Result<List<WeaviateObject>> resultObj = client.data().objectsGetter()
.withID(obj.getId())
.withClassName(obj.getClassName())
.withVector()
.run();
assertThat(resultObj).isNotNull()
.returns(false, Result::hasErrors)
.extracting(Result::getResult).isNotNull()
.extracting(r -> r.get(0)).isNotNull()
.satisfies(o -> {
assertThat(o.getId()).isEqualTo(obj.getId());
assertThat(o.getVectors()).isNotEmpty()
.containsOnlyKeys("hnswVector", "c11y")
.satisfies(vecs -> {
assertThat(vecs.get("hnswVector")).isNotNull().isEqualTo(vector);
assertThat(vecs.get("c11y")).isNotEmpty();
});
assertThat(o.getProperties()).isNotNull()
.extracting(Map::size).isEqualTo(obj.getProperties().size());
obj.getProperties().keySet().forEach(propName -> {
assertThat(o.getProperties().get(propName)).isNotNull();
});
});

// clean up
Result<Boolean> delete = client.schema().classDeleter().withClassName(className).run();
assertThat(delete).isNotNull()
.returns(false, Result::hasErrors)
.returns(true, Result::getResult);
}

private WeaviateClient createClient() {
Config config = new Config("http", host + ":" + port);
config.setGRPCSecured(false);
config.setGRPCHost(grpcHost + ":" + grpcPort);
return new WeaviateClient(config);
}
}

0 comments on commit 69d0366

Please sign in to comment.