From 79fbd6383b168cf152694e7f9aaff3d0ba8d9bc1 Mon Sep 17 00:00:00 2001 From: Leon Yee <43097991+unkn-wn@users.noreply.github.com> Date: Mon, 26 Jun 2023 14:22:14 -0700 Subject: [PATCH] unit tests for lance (#3) * untested lance implrementation * Update lancedb.py * 2nd untested implementation * Update lancedb.py * added open_table workaround * Update requirements.txt * unit tests for lance --- superagi/vector_store/lancedb.py | 7 +- .../vector_store/test_lancedb.py | 106 ++++++++++++++++++ 2 files changed, 110 insertions(+), 3 deletions(-) create mode 100644 tests/integration_tests/vector_store/test_lancedb.py diff --git a/superagi/vector_store/lancedb.py b/superagi/vector_store/lancedb.py index af2ce2f12..7112d162c 100644 --- a/superagi/vector_store/lancedb.py +++ b/superagi/vector_store/lancedb.py @@ -87,7 +87,7 @@ def get_matching_text(self, query: str, top_k: int = 5, **kwargs: Any) -> List[D Returns: The list of documents most similar to the query """ - namespace = kwargs.get("namespace", self.namespace) + namespace = kwargs.get("namespace", "None") for table in self.db.table_names(): if table == namespace: @@ -96,17 +96,18 @@ def get_matching_text(self, query: str, top_k: int = 5, **kwargs: Any) -> List[D try: tbl except: - raise ValueError(namespace + " Table not found in LanceDB.") + raise ValueError(namespace + " Table not found in LanceDB. Please call this function with a valid table name.") embed_text = self.embedding_model.get_embedding(query) res = tbl.search(embed_text).limit(top_k).to_df() + print(res) documents = [] for i in range(len(res)): meta = {} for col in res: - if col != 'vector' and col != 'id': + if col != 'vector' and col != 'id' and col != 'score': meta[col] = res[col][i] documents.append( diff --git a/tests/integration_tests/vector_store/test_lancedb.py b/tests/integration_tests/vector_store/test_lancedb.py new file mode 100644 index 000000000..73719add5 --- /dev/null +++ b/tests/integration_tests/vector_store/test_lancedb.py @@ -0,0 +1,106 @@ +import numpy as np +import shutil +import pytest + +import lancedb +from superagi.vector_store.lancedb import LanceDB +from superagi.vector_store.document import Document +from superagi.vector_store.embedding.openai import OpenAiEmbedding + + +@pytest.fixture +def client(): + db = lancedb.connect(".test_lancedb") + yield db + shutil.rmtree(".test_lancedb") + + +@pytest.fixture +def mock_openai_embedding(monkeypatch): + monkeypatch.setattr( + OpenAiEmbedding, + "get_embedding", + lambda self, text: np.random.random(3).tolist(), + ) + + +@pytest.fixture +def store(client, mock_openai_embedding): + yield LanceDB(client, OpenAiEmbedding(api_key="test_api_key"), "text") + + +@pytest.fixture +def dataset(): + book_titles = [ + "The Great Gatsby", + "To Kill a Mockingbird", + "1984", + "Pride and Prejudice", + "The Catcher in the Rye", + ] + + documents = [] + for i, title in enumerate(book_titles): + author = f"Author {i}" + description = f"A summary of {title}" + text_content = f"This is the text for {title}" + metadata = {"author": author, "description": description} + document = Document(text_content=text_content, metadata=metadata) + + documents.append(document) + + return documents + + +@pytest.fixture +def dataset_no_metadata(): + book_titles = [ + "The Lord of the Rings", + "The Hobbit", + "The Chronicles of Narnia", + ] + + documents = [] + for title in book_titles: + text_content = f"This is the text for {title}" + document = Document(text_content=text_content) + documents.append(document) + + return documents + + +@pytest.mark.parametrize( + "data, results, table_name", + [ + ("dataset", (5, 2), "test_table"), + ("dataset_no_metadata", (3, 0), "test_table_no_metadata"), + ], +) +def test_add_texts(store, client, data, results, table_name, request): + dataset = request.getfixturevalue(data) + count, meta_count = results + ids = store.add_documents(dataset, table_name=table_name) + assert len(ids) == count + + tbl = client.open_table(table_name) + assert len(tbl.to_pandas().columns) - 3 == meta_count + # Subtracting 3 because of the id, vector, and text columns. The rest + # should be metadata columns. + + +@pytest.mark.parametrize( + "data, search_text, table_name, index", + [ + ("dataset", "The Great Gatsby", "test_table", 0), + ("dataset", "1984", "test_table2", 2), + ("dataset_no_metadata", "The Hobbit", "test_table_no_metadata", 1), + ], +) +def test_get_matching_text(store, data, search_text, table_name, index, request): + print("SEARCHING FOR " + search_text) + dataset = request.getfixturevalue(data) + store.add_documents(dataset, table_name=table_name) + results = store.get_matching_text(search_text, top_k=2, namespace=table_name) + print(results[0]) + assert len(results) == 2 + assert results[0] == dataset[index]