Skip to content

Commit

Permalink
unit tests for lance (#3)
Browse files Browse the repository at this point in the history
* untested lance implementation

* Update lancedb.py

* 2nd untested implementation

* Update lancedb.py

* added open_table workaround

* Update requirements.txt

* unit tests for lance
  • Loading branch information
unkn-wn authored Jun 26, 2023
1 parent 48c13eb commit 79fbd63
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 3 deletions.
7 changes: 4 additions & 3 deletions superagi/vector_store/lancedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def get_matching_text(self, query: str, top_k: int = 5, **kwargs: Any) -> List[D
Returns:
The list of documents most similar to the query
"""
namespace = kwargs.get("namespace", self.namespace)
namespace = kwargs.get("namespace", "None")

for table in self.db.table_names():
if table == namespace:
Expand All @@ -96,17 +96,18 @@ def get_matching_text(self, query: str, top_k: int = 5, **kwargs: Any) -> List[D
try:
tbl
except:
raise ValueError(namespace + " Table not found in LanceDB.")
raise ValueError(namespace + " Table not found in LanceDB. Please call this function with a valid table name.")

embed_text = self.embedding_model.get_embedding(query)
res = tbl.search(embed_text).limit(top_k).to_df()
print(res)

documents = []

for i in range(len(res)):
meta = {}
for col in res:
if col != 'vector' and col != 'id':
if col != 'vector' and col != 'id' and col != 'score':
meta[col] = res[col][i]

documents.append(
Expand Down
106 changes: 106 additions & 0 deletions tests/integration_tests/vector_store/test_lancedb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import numpy as np
import shutil
import pytest

import lancedb
from superagi.vector_store.lancedb import LanceDB
from superagi.vector_store.document import Document
from superagi.vector_store.embedding.openai import OpenAiEmbedding


@pytest.fixture
def client():
    """Yield a throwaway on-disk LanceDB connection; remove its directory afterwards.

    Teardown runs after the test via the code following ``yield``.
    """
    db = lancedb.connect(".test_lancedb")
    yield db
    # ignore_errors: the directory may not exist if the test failed before
    # any table was written, and cleanup must not mask the real failure.
    shutil.rmtree(".test_lancedb", ignore_errors=True)


@pytest.fixture
def mock_openai_embedding(monkeypatch):
    """Replace OpenAiEmbedding.get_embedding with a fake that returns a random
    3-dimensional vector, so no real OpenAI API call is made during tests."""

    def fake_get_embedding(self, text):
        return np.random.random(3).tolist()

    monkeypatch.setattr(OpenAiEmbedding, "get_embedding", fake_get_embedding)


@pytest.fixture
def store(client, mock_openai_embedding):
    """A LanceDB vector store backed by the temp database and mocked embedder."""
    embedder = OpenAiEmbedding(api_key="test_api_key")
    return LanceDB(client, embedder, "text")


@pytest.fixture
def dataset():
    """Five classic book titles as Documents carrying author/description metadata."""
    titles = [
        "The Great Gatsby",
        "To Kill a Mockingbird",
        "1984",
        "Pride and Prejudice",
        "The Catcher in the Rye",
    ]

    return [
        Document(
            text_content=f"This is the text for {title}",
            metadata={
                "author": f"Author {idx}",
                "description": f"A summary of {title}",
            },
        )
        for idx, title in enumerate(titles)
    ]


@pytest.fixture
def dataset_no_metadata():
    """Three book titles as plain Documents with no metadata attached."""
    titles = [
        "The Lord of the Rings",
        "The Hobbit",
        "The Chronicles of Narnia",
    ]

    return [Document(text_content=f"This is the text for {title}") for title in titles]


@pytest.mark.parametrize(
    "data, results, table_name",
    [
        ("dataset", (5, 2), "test_table"),
        ("dataset_no_metadata", (3, 0), "test_table_no_metadata"),
    ],
)
def test_add_texts(store, client, data, results, table_name, request):
    """Adding documents returns one id per document and creates one table
    column per metadata key (beyond the fixed id/vector/text columns)."""
    docs = request.getfixturevalue(data)
    expected_count, expected_meta_count = results

    inserted_ids = store.add_documents(docs, table_name=table_name)
    assert len(inserted_ids) == expected_count

    table = client.open_table(table_name)
    # id, vector, and text are always present; everything else is metadata.
    assert len(table.to_pandas().columns) - 3 == expected_meta_count


@pytest.mark.parametrize(
    "data, search_text, table_name, index",
    [
        ("dataset", "The Great Gatsby", "test_table", 0),
        ("dataset", "1984", "test_table2", 2),
        ("dataset_no_metadata", "The Hobbit", "test_table_no_metadata", 1),
    ],
)
def test_get_matching_text(store, data, search_text, table_name, index, request):
    """Similarity search returns ``top_k`` results and the best match is the
    expected source document.

    NOTE(review): the embedder is mocked with random vectors, so which document
    ranks first is not deterministic in principle — confirm the intended
    behavior if this test proves flaky.
    """
    dataset = request.getfixturevalue(data)
    store.add_documents(dataset, table_name=table_name)

    results = store.get_matching_text(search_text, top_k=2, namespace=table_name)

    assert len(results) == 2
    assert results[0] == dataset[index]

0 comments on commit 79fbd63

Please sign in to comment.