
Commit 4f1e172: Lint tests
caufieldjh committed Jul 31, 2024
1 parent: 0fe6ab7 | commit: 4f1e172
Showing 3 changed files with 41 additions and 19 deletions.
2 changes: 1 addition & 1 deletion tests/__init__.py
@@ -1,4 +1,5 @@
"""Tests for curate-gpt."""

import os
from pathlib import Path

@@ -8,4 +9,3 @@
 OUTPUT_DIR = this_directory / "output"
 OUTPUT_CHROMA_DB_PATH = OUTPUT_DIR / "db"
 OUTPUT_DUCKDB_PATH = os.path.join(OUTPUT_DIR, "duckdbvss.db")
-
8 changes: 6 additions & 2 deletions tests/cli/test_store_cli.py
@@ -7,10 +7,14 @@
 
 def test_store_management(runner):
     # test index ontology with duckdb
-    result = runner.invoke(main, ["ontology", "index", ONT_DB, "-m", "all-MiniLM-L6-v2", "-c", "oai", "-D", "duckdb"])
+    result = runner.invoke(
+        main, ["ontology", "index", ONT_DB, "-m", "all-MiniLM-L6-v2", "-c", "oai", "-D", "duckdb"]
+    )
     assert result.exit_code == 0
     # test index ontology with chromadb
-    result = runner.invoke(main, ["ontology", "index", ONT_DB, "-D", "chromadb", "-m", "all-MiniLM-L6-v2", "-c", "oai"])
+    result = runner.invoke(
+        main, ["ontology", "index", ONT_DB, "-D", "chromadb", "-m", "all-MiniLM-L6-v2", "-c", "oai"]
+    )
     assert result.exit_code == 0
     result = runner.invoke(main, ["ontology", "index", ONT_DB, "-c", "default"])
     assert result.exit_code == 0
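For reference, the CLI tests above drive the `ontology index` command through Click's CliRunner, switching the vector store backend with the `-D` option. A minimal standalone sketch of the same pattern follows; the `curategpt.cli` import path and the ontology file path are assumptions for illustration, not taken from this diff:

from click.testing import CliRunner

from curategpt.cli import main  # assumed CLI entry point; adjust to the package's actual module


def test_index_ontology_with_duckdb_backend():
    """Sketch: index an ontology into a DuckDB-backed store via the CLI."""
    runner = CliRunner()
    ont_db = "tests/input/go-nucleus.db"  # hypothetical path; the real tests use an ONT_DB constant
    result = runner.invoke(
        main,
        ["ontology", "index", ont_db, "-m", "all-MiniLM-L6-v2", "-c", "oai", "-D", "duckdb"],
    )
    # Exit code 0 means the command parsed its options and ran without raising.
    assert result.exit_code == 0, result.output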
50 changes: 34 additions & 16 deletions tests/store/test_duckdb_adapter.py
@@ -19,7 +19,13 @@
 
 def terms_to_objects(terms: list[str]) -> list[Dict]:
     return [
-        {"id": f"ID:{i}", "text": t, "wordlen": len(t), "nested": {"wordlen": len(t)}, "vec": [float(i)] * 1536}
+        {
+            "id": f"ID:{i}",
+            "text": t,
+            "wordlen": len(t),
+            "nested": {"wordlen": len(t)},
+            "vec": [float(i)] * 1536,
+        }
         for i, t in enumerate(terms)
     ]
 
@@ -66,14 +72,14 @@ def test_store(simple_schema_manager, example_texts):
     db2 = DuckDBAdapter(str(OUTPUT_DUCKDB_PATH))
     assert db2.collection_metadata(collection).description == "test collection"
     assert db.list_collection_names() == ["test_collection"]
-    results = list(db.search("fox", collection=collection, include=['metadata']))
+    results = list(db.search("fox", collection=collection, include=["metadata"]))
     db.update(objs, collection=collection)
     assert db.collection_metadata(collection).description == "test collection"
     long_words = list(db.find(where={"wordlen": {"$gt": 12}}, collection=collection))
     assert len(long_words) == 2
     db.remove_collection(collection)
     db.insert(objs, collection=collection)
-    results2 = list(db.search("fox", collection=collection, include=['metadata']))
+    results2 = list(db.search("fox", collection=collection, include=["metadata"]))
 
     def _id(id, _meta, _emb, _doc, _dist):
         return id
@@ -85,6 +91,7 @@ def _id(id, _meta, _emb, _doc, _dist):
     results2 = list(db.find({}, limit=10000000, collection=collection))
     assert len(results2) > limit
 
+
 def test_the_embedding_function(simple_schema_manager, example_texts):
     db = DuckDBAdapter(OUTPUT_DUCKDB_PATH)
     db.conn.execute("DROP TABLE IF EXISTS test_collection")
@@ -150,25 +157,25 @@ def test_ontology_matches(ontology_db):
         "id": first_obj.id,
         "metadata": {"updated_key": "updated_value"},
         "embeddings": [0.1] * len(first_obj.embeddings),
-        "documents": "Updated document text"
+        "documents": "Updated document text",
     }
     ontology_db.update([updated_obj], collection="test_collection")
     # verify update
     result = ontology_db.lookup(first_obj.id, collection="test_collection")
-    assert result['metadata'] == {"updated_key": "updated_value"}
+    assert result["metadata"] == {"updated_key": "updated_value"}
     assert result == updated_obj
     updated_results = list(ontology_db.matches(updated_obj))
     assert len(updated_results) == 10
     for res in updated_results:
         if res.id == first_obj.id:
             assert res.metadata == updated_obj
-            assert res.metadata['metadata'] == {"updated_key": "updated_value"}
+            assert res.metadata["metadata"] == {"updated_key": "updated_value"}
     # test upsert
     new_obj = {
         "id": "new_id",
         "metadata": {"new_key": "new_value"},
         "embeddings": [0.5] * len(first_obj.embeddings),
-        "documents": "New document text"
+        "documents": "New document text",
     }
     ontology_db.upsert([new_obj], collection="test_collection")
     # verify upsert
@@ -186,19 +193,21 @@ def test_ontology_matches(ontology_db):
 )
 def test_where_queries(loaded_ontology_db, where, num_expected, limit, include):
     db = loaded_ontology_db
-    results = list(db.find(where=where, limit=limit, collection="other_collection", include=include))
+    results = list(
+        db.find(where=where, limit=limit, collection="other_collection", include=include)
+    )
     assert len(results) == num_expected
     for res in results:
         if include:
-            if 'id' in include:
+            if "id" in include:
                 assert res.id is not None
-            if 'metadata' in include:
+            if "metadata" in include:
                 assert res.metadata is not None
-            if 'embeddings' in include:
+            if "embeddings" in include:
                 assert res.embeddings is not None
-            if 'documents' in include:
+            if "documents" in include:
                 assert res.documents is not None
-            if 'distance' in include:
+            if "distance" in include:
                 assert res.distance is not None
         else:
             assert res.id is not None
@@ -214,7 +223,10 @@ def test_load_in_batches(ontology_db):
     sliced_gen = list(itertools.islice(view.objects(), 3))
     ontology_db.insert(sliced_gen, batch_size=10, collection="other_collection")
     objs = list(
-        ontology_db.find(where={"original_id": {"$eq": "BFO:0000002"}}, collection="other_collection", limit=2000))
+        ontology_db.find(
+            where={"original_id": {"$eq": "BFO:0000002"}}, collection="other_collection", limit=2000
+        )
+    )
     assert len(objs) == 1
 
 
@@ -231,7 +243,10 @@ def combo_db(example_combo_texts) -> DuckDBAdapter:
 def test_diversified_search(combo_db):
     relevance_factor = 0.5
     results = combo_db.search(
-        "pineapple helicopter 5", collection="test_collection", relevance_factor=relevance_factor, limit=20
+        "pineapple helicopter 5",
+        collection="test_collection",
+        relevance_factor=relevance_factor,
+        limit=20,
     )
     for i, res in enumerate(results):
         obj, distance, id_field, doc = res.metadata, res.distance, res.id, res.documents
@@ -245,6 +260,9 @@ def test_diversified_search(combo_db):
 def test_diversified_search_on_empty_db(empty_db):
     relevance_factor = 0.5
     results = empty_db.search(
-        "pineapple helicopter 5", collection="test_collection", relevance_factor=relevance_factor, limit=20
+        "pineapple helicopter 5",
+        collection="test_collection",
+        relevance_factor=relevance_factor,
+        limit=20,
     )
     assert len(list(results)) == 0
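Taken together, these tests outline the DuckDBAdapter surface they exercise: insert objects into a named collection, run a vector search with an include list, and filter on attributes with Mongo-style operators such as $gt and $eq. A minimal usage sketch under the same assumptions (the module path below is a guess, and the result fields .id, .metadata, .distance, and .documents simply mirror what the tests read):

from curategpt.store.duckdb_adapter import DuckDBAdapter  # assumed module path

# Objects shaped like those produced by terms_to_objects() above (embeddings omitted here;
# the adapter is assumed to compute them with its configured embedding function).
objs = [
    {"id": f"ID:{i}", "text": t, "wordlen": len(t), "nested": {"wordlen": len(t)}}
    for i, t in enumerate(["fox", "hippopotamus", "triceratops"])
]

db = DuckDBAdapter("output/duckdbvss.db")  # illustrative path, mirroring OUTPUT_DUCKDB_PATH
db.insert(objs, collection="test_collection")

# Vector search, asking for metadata alongside ids and distances.
for res in db.search("fox", collection="test_collection", include=["metadata"]):
    print(res.id, res.distance, res.metadata)

# Attribute filtering with the operators used in test_store and test_load_in_batches.
long_words = list(db.find(where={"wordlen": {"$gt": 12}}, collection="test_collection"))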
