forked from TransformerOptimus/SuperAGI
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* untested lance implementation * Update lancedb.py * 2nd untested implementation * Update lancedb.py * added open_table workaround * Update requirements.txt * unit tests for lance
- Loading branch information
Showing
2 changed files
with
110 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import numpy as np | ||
import shutil | ||
import pytest | ||
|
||
import lancedb | ||
from superagi.vector_store.lancedb import LanceDB | ||
from superagi.vector_store.document import Document | ||
from superagi.vector_store.embedding.openai import OpenAiEmbedding | ||
|
||
|
||
@pytest.fixture
def client():
    """Yield a LanceDB connection backed by a throwaway local directory.

    The directory is removed in a ``finally`` block so cleanup also runs
    when the consuming test raises (pytest throws the test's exception
    into the generator at ``yield``; without ``finally`` the temp
    ``.test_lancedb`` directory would leak on failure).
    """
    db = lancedb.connect(".test_lancedb")
    try:
        yield db
    finally:
        shutil.rmtree(".test_lancedb")
|
||
|
||
@pytest.fixture
def mock_openai_embedding(monkeypatch):
    """Replace OpenAiEmbedding.get_embedding with a fake 3-d vector source.

    Keeps the tests offline: no OpenAI API calls, just random floats of a
    fixed dimensionality.
    """

    def _fake_get_embedding(self, text):
        # Dimension 3 matches what the test tables expect.
        return np.random.random(3).tolist()

    monkeypatch.setattr(OpenAiEmbedding, "get_embedding", _fake_get_embedding)
|
||
|
||
@pytest.fixture
def store(client, mock_openai_embedding):
    """Build a LanceDB vector store on the temp client with mocked embeddings."""
    vector_store = LanceDB(client, OpenAiEmbedding(api_key="test_api_key"), "text")
    yield vector_store
|
||
|
||
@pytest.fixture
def dataset():
    """Return five book Documents, each with author/description metadata."""
    titles = [
        "The Great Gatsby",
        "To Kill a Mockingbird",
        "1984",
        "Pride and Prejudice",
        "The Catcher in the Rye",
    ]

    return [
        Document(
            text_content=f"This is the text for {title}",
            metadata={
                "author": f"Author {i}",
                "description": f"A summary of {title}",
            },
        )
        for i, title in enumerate(titles)
    ]
|
||
|
||
@pytest.fixture
def dataset_no_metadata():
    """Return three book Documents that carry no metadata at all."""
    titles = [
        "The Lord of the Rings",
        "The Hobbit",
        "The Chronicles of Narnia",
    ]

    return [
        Document(text_content=f"This is the text for {title}")
        for title in titles
    ]
|
||
|
||
@pytest.mark.parametrize(
    "data, results, table_name",
    [
        ("dataset", (5, 2), "test_table"),
        ("dataset_no_metadata", (3, 0), "test_table_no_metadata"),
    ],
)
def test_add_texts(store, client, data, results, table_name, request):
    """Adding documents yields one id per document and persists metadata columns."""
    documents = request.getfixturevalue(data)
    expected_ids, expected_meta_columns = results

    returned_ids = store.add_documents(documents, table_name=table_name)
    assert len(returned_ids) == expected_ids

    table = client.open_table(table_name)
    # id, vector and text are fixed columns; whatever remains must be metadata.
    metadata_columns = len(table.to_pandas().columns) - 3
    assert metadata_columns == expected_meta_columns
|
||
|
||
@pytest.mark.parametrize(
    "data, search_text, table_name, index",
    [
        ("dataset", "The Great Gatsby", "test_table", 0),
        ("dataset", "1984", "test_table2", 2),
        ("dataset_no_metadata", "The Hobbit", "test_table_no_metadata", 1),
    ],
)
def test_get_matching_text(store, data, search_text, table_name, index, request):
    """Similarity search returns ``top_k`` hits with the expected best match first.

    ``index`` is the position in the fixture dataset of the document that
    should rank first for ``search_text``. Debug ``print`` calls from the
    original version were removed — they only polluted test output.
    """
    documents = request.getfixturevalue(data)
    store.add_documents(documents, table_name=table_name)

    results = store.get_matching_text(search_text, top_k=2, namespace=table_name)
    assert len(results) == 2
    assert results[0] == documents[index]