
Commit

format sql identifiers, rework diversified_search, rework cosine/distance
iQuxLE committed Jul 30, 2024
1 parent 8291f8a commit 2e5d816
Showing 1 changed file with 72 additions and 72 deletions.
144 changes: 72 additions & 72 deletions src/curate_gpt/store/duckdb_adapter.py
@@ -2,8 +2,6 @@
This is a DuckDB adapter for the Vector Similarity Search (VSS) extension
using the experimental persistence feature
"""
import itertools

import yaml
import logging
import os
@@ -28,6 +26,7 @@
from curate_gpt.store.db_adapter import DBAdapter, OBJECT, QUERY, PROJECTION, SEARCH_RESULT
from curate_gpt.store.duckdb_result import DuckDBSearchResult
from curate_gpt.store.metadata import CollectionMetadata
from curate_gpt.utils.vector_algorithms import mmr_diversified_search

logger = logging.getLogger(__name__)

@@ -62,7 +61,7 @@ class DuckDBAdapter(DBAdapter):

def __post_init__(self):
if not self.path:
self.path="./duck.db"
self.path = "./duck.db"
os.makedirs(os.path.dirname(self.path), exist_ok=True)
self.ef_construction = self._validate_ef_construction(self.ef_construction)
self.ef_search = self._validate_ef_search(self.ef_search)
@@ -101,8 +100,9 @@ def _create_table_if_not_exists(self, collection: str, vec_dimension: int, model
:param collection:
:return:
"""
safe_collection_name = f'"{collection}"'
create_table_sql = f"""
CREATE TABLE IF NOT EXISTS {collection} (
CREATE TABLE IF NOT EXISTS {safe_collection_name} (
id VARCHAR PRIMARY KEY,
metadata JSON,
embeddings FLOAT[{vec_dimension}],
@@ -117,14 +117,15 @@ def _create_table_if_not_exists(self, collection: str, vec_dimension: int, model
else:
metadata = CollectionMetadata(name=collection, model=self.default_model)
metadata_json = json.dumps(metadata.model_dump(exclude_none=True))
safe_collection_name = f'"{collection}"'
self.conn.execute(
f"""
INSERT INTO {collection} (id, metadata) VALUES ('__metadata__', ?)
INSERT INTO {safe_collection_name} (id, metadata) VALUES ('__metadata__', ?)
ON CONFLICT (id) DO NOTHING
""", [metadata_json]
)
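# each collection keeps its CollectionMetadata in a reserved '__metadata__' row;
# ON CONFLICT DO NOTHING leaves any previously stored metadata row untouched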

def create_index(self, collection: str):
def create_index(self, collection: str, metric: str = "cosine"):
"""
Create an index for the given collection
Parameters
@@ -136,18 +137,20 @@ def create_index(self, collection: str):
"""

metrics = ['l2sq', 'cosine', 'ip'] #l2sq,ip
for metric in metrics:
create_index_sql = f"""
CREATE INDEX IF NOT EXISTS idx_{collection}_embeddings_{metric} ON {collection}
USING HNSW (embeddings) WITH (
metric='{metric}',
ef_construction={self.ef_construction},
ef_search={self.ef_search},
M={self.M}
)
"""
self.conn.execute(create_index_sql)
# metrics = ['l2sq', 'cosine', 'ip'] #l2sq,ip
# for metric in metrics:
safe_collection_name = f'"{collection}"'
index_name = f"{collection}_index"
create_index_sql = f"""
CREATE INDEX IF NOT EXISTS "{index_name}" ON {safe_collection_name}
USING HNSW (embeddings) WITH (
metric='{metric}',
ef_construction={self.ef_construction},
ef_search={self.ef_search},
M={self.M}
)
"""
self.conn.execute(create_index_sql)
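# usage sketch (an assumption, not part of this diff): instead of looping over
# every metric, a single HNSW index is now built per call, e.g.
#   adapter.create_index("my_collection", metric="l2sq")  # or 'cosine', 'ip'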

def _embedding_function(self, texts: Union[str, List[str]], model: str = None) -> list:
"""
@@ -196,7 +199,8 @@ def update(self, objs: Union[OBJECT, Iterable[OBJECT]], **kwargs):
"""
collection = kwargs.get('collection')
ids = [self._id(o, self.id_field) for o in objs]
delete_sql = f"DELETE FROM {collection} WHERE id = ?"
safe_collection_name = f'"{collection}"'
delete_sql = f"DELETE FROM {safe_collection_name} WHERE id = ?"
self.conn.executemany(delete_sql, [(id_,) for id_ in ids])
self.insert(objs, **kwargs)
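# update is delete-then-insert, so embeddings are recomputed for updated objects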

@@ -211,7 +215,8 @@ def upsert(self, objs: Union[OBJECT, Iterable[OBJECT]], **kwargs):
ids = [self._id(o, self.id_field) for o in objs]
existing_ids = set()
for id_ in ids:
result = self.conn.execute(f"SELECT id FROM {collection} WHERE id = ?",[id_]).fetchall()
safe_collection_name = f'"{collection}"'
result = self.conn.execute(f"SELECT id FROM {safe_collection_name} WHERE id = ?", [id_]).fetchall()
if result:
existing_ids.add(id_)
objs_to_update = [o for o in objs if self._id(o, self.id_field) in existing_ids]
@@ -260,6 +265,7 @@ def _process_objects(
text_field = self.text_lookup
id_field = self.id_field
num_objs = len(objs) if isinstance(objs, list) else "?"
logger.info(f"Processing {len(num_objs)} objects")
cumulative_len = 0
sql_command = self._generate_sql_command(collection, method)
sql_command = sql_command.format(collection=collection)
@@ -275,6 +281,7 @@
metadatas = [self._dict(o) for o in next_objs]
ids = [self._id(o, id_field) for o in next_objs]
embeddings = self._embedding_function(docs, model)
logger.info(f"Processing object {next_objs[0]} of {num_objs}")
try:
self.conn.execute("BEGIN TRANSACTION;")
self.conn.executemany(sql_command, list(zip(ids, metadatas, embeddings, docs)))
@@ -285,15 +292,15 @@
raise
self.create_index(collection)

def _generate_sql_command(self, collection: str, method: str) -> str:
def _generate_sql_command(self, collection: str, method: str) -> str:
safe_collection_name = f'"{collection}"'
if method == "insert":
return f"""
INSERT INTO {collection} (id,metadata, embeddings, documents) VALUES (?, ?, ?, ?)
INSERT INTO {safe_collection_name} (id, metadata, embeddings, documents) VALUES (?, ?, ?, ?)
"""
else:
raise ValueError(f"Unknown method: {method}")


def remove_collection(self, collection: str = None, exists_ok=False, **kwargs):
"""
Remove the collection from the database
@@ -306,7 +313,9 @@ def remove_collection(self, collection: str = None, exists_ok=False, **kwargs):
if not exists_ok:
if collection not in self.list_collection_names():
raise ValueError(f"Collection {collection} does not exist")
self.conn.execute(f"DROP TABLE IF EXISTS {collection}")
# DuckDB requires that identifiers containing special characters ("-") be enclosed in double quotes.
safe_collection_name = f'"{collection}"'
self.conn.execute(f"DROP TABLE IF EXISTS {safe_collection_name}")

def search(self, text: str, where: QUERY = None, collection: str = None, limit: int = 10,
relevance_factor: float = None, **kwargs) -> Iterator[SEARCH_RESULT]:
@@ -342,10 +351,12 @@ def _search(self, text: str, where: QUERY = None, collection: str = None, limit:
yield from self._diversified_search(text, where, collection, limit, relevance_factor, **kwargs)
return
query_embedding = self._embedding_function(text, model)
print("embedding", query_embedding)
safe_collection_name = f'"{collection}"'
results = self.conn.execute(f"""
SELECT *, array_cosine_similarity(embeddings::FLOAT[{self.vec_dimension}],
{query_embedding}::FLOAT[{self.vec_dimension}]) as distance
FROM {collection}
SELECT *, (1 - array_cosine_similarity(embeddings::FLOAT[{self.vec_dimension}],
{query_embedding}::FLOAT[{self.vec_dimension}])) as distance
FROM {safe_collection_name}
{where_clause}
ORDER BY distance
LIMIT ?
Expand All @@ -372,7 +383,8 @@ def collection_metadata(self, collection_name: Optional[str] = None, include_der
:param kwargs:
:return:
"""
result = self.conn.execute(f"SELECT metadata FROM {collection_name} WHERE id = '__metadata__'").fetchone()
safe_collection_name = f'"{collection_name}"'
result = self.conn.execute(f"SELECT metadata FROM {safe_collection_name} WHERE id = '__metadata__'").fetchone()
if result:
metadata = json.loads(result[0])
return CollectionMetadata(**metadata)
@@ -394,9 +406,10 @@ def update_collection_metadata(self, collection: str, **kwargs):
else:
current_metadata = current_metadata.model_copy(update=kwargs)
metadata_json = json.dumps(current_metadata.model_dump(exclude_none=True))
safe_collection_name = f'"{collection}"'
self.conn.execute(
f"""
UPDATE {collection}
UPDATE {safe_collection_name}
SET metadata = ?
WHERE id = '__metadata__'
""", [metadata_json]
@@ -417,21 +430,15 @@ def set_collection_metadata(
raise ValueError("Collection name must be provided.")

metadata_json = json.dumps(metadata.dict(exclude_none=True))
safe_collection_name = f'"{collection_name}"'
self.conn.execute(
f"""
UPDATE {collection_name}
UPDATE {safe_collection_name}
SET metadata = ?
WHERE id = '__metadata__'
""", [metadata_json]
)

def _object_metadata(self, obj):
"""
Convert the object to a dictionary.
"""
obj = self._dict(obj)
return obj

def find(self, where: QUERY = None, projection: PROJECTION = None, collection: str = None,
include: Optional[List[str]] = None, limit: int = 10, **kwargs) -> Iterator[
DuckDBSearchResult]:
@@ -453,9 +460,10 @@ def find(self, where: QUERY = None, projection: PROJECTION = None, collection: s
collection = self._get_collection(collection)
where_clause = self._parse_where_clause(where) if where else ""
where_clause = f"WHERE {where_clause}" if where_clause else ""
safe_collection_name = f'"{collection}"'
query = f"""
SELECT id, metadata, embeddings, documents, NULL as distance
FROM {collection}
FROM {safe_collection_name}
{where_clause}
LIMIT ?
"""
@@ -484,9 +492,10 @@ def lookup(self, id: str, collection: str = None, **kwargs) -> OBJECT:
:param kwargs:
:return:
"""
safe_collection_name = f'"{collection}"'
result = self.conn.execute(f"""
SELECT *
FROM {collection}
FROM {safe_collection_name}
WHERE id = ?
""", [id]).fetchone()
if isinstance(result, tuple) and len(result) > 1:
@@ -502,9 +511,10 @@ def peek(self, collection: str = None, limit=5, **kwargs) -> Iterator[OBJECT]:
:param kwargs:
:return:
"""
safe_collection_name = f'"{collection}"'
results = self.conn.execute(f"""
SELECT id, metadata, embeddings, documents, NULL as distance
FROM {collection}
FROM {safe_collection_name}
LIMIT ?
""", [limit]).fetchall()

@@ -523,7 +533,8 @@ def dump(self, collection: str = None, to_file: Union[str, Path] = None, format:
raise ValueError("Collection name must be provided.")

collection_name = self._get_collection_name(collection)
query = f"SELECT id, embeddings, metadata, documents FROM {collection_name}"
safe_collection_name = f'"{collection_name}"'
query = f"SELECT id, embeddings, metadata, documents FROM {safe_collection_name}"
data = self.conn.execute(query).fetchall()
metadata = self.collection_metadata(collection_name).dict(exclude_none=True)

@@ -558,8 +569,7 @@ def dump_then_load(self, collection: str = None, target: DBAdapter = None,
"""
if collection is None:
raise ValueError("Collection name must be provided.")

if not isinstance(target, DuckDBVSSAdapter):
if not isinstance(target, DuckDBAdapter):
raise ValueError("Target must be a DuckDBVSSAdapter instance")
self.dump(collection=collection, to_file=temp_file, format=format)
with open(temp_file, 'r') as f:
Expand Down Expand Up @@ -600,43 +610,32 @@ def _diversified_search(self, text: str, where: QUERY = None, collection: str =
where_clause = f"WHERE {where_clause}"

query_embedding = self._embedding_function(text, model=cm.model)

safe_collection_name = f'"{collection}"'
results = self.conn.execute(f"""
SELECT *, array_cosine_similarity(embeddings::FLOAT[{self.vec_dimension}],
{query_embedding}::FLOAT[{self.vec_dimension}]) as distance
FROM {collection}
SELECT *, (1 - array_cosine_similarity(embeddings::FLOAT[{self.vec_dimension}],
{query_embedding}::FLOAT[{self.vec_dimension}])) as distance
FROM {safe_collection_name}
{where_clause}
ORDER BY distance
LIMIT ?
""", [limit * 10]).fetchall()

# the reserved '__metadata__' row always comes back with a NULL distance, so filter it out
results = [r for r in results if r[-1] is not None]
results = sorted(results, key=lambda x: x[-1])
distances = np.array([r[-1] for r in results])
parsed_results = list(self.parse_duckdb_result(results))
selected_indices = []
unselected_indices = list(range(len(parsed_results)))

for _ in range(min(limit, len(parsed_results))):
if not unselected_indices:
break

mmr_scores = relevance_factor * (1 - distances[unselected_indices])
if selected_indices:
selected_vecs = np.array([parsed_results[i].embeddings for i in selected_indices if
parsed_results[i].embeddings is not None])
if selected_vecs.size != 0:
similarities = np.dot(np.array([parsed_results[i].embeddings for i in unselected_indices]),
selected_vecs.T)
max_similarities = np.max(similarities, axis=1)
mmr_scores -= (1 - relevance_factor) * max_similarities

next_index = unselected_indices[np.argmax(mmr_scores)]
selected_indices.append(next_index)
unselected_indices.remove(next_index)

for idx in selected_indices:
parsed_results[idx].distance = float(distances[idx])
logger.debug(f"Parsed results: {parsed_results}")
document_vectors = [np.array(result.embeddings) for result in parsed_results if result.embeddings is not None]
query_vector = np.array(query_embedding)  # reuse the embedding computed above
if not document_vectors:
logger.info("The database might be empty. No diversified search results to return.")
return
reranked_indices = mmr_diversified_search(
query_vector=query_vector,
document_vectors=document_vectors,
relevance_factor=relevance_factor, top_n=limit
)
for idx in reranked_indices:
yield parsed_results[idx]
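# The imported mmr_diversified_search helper is not shown in this diff. The
# following is a minimal sketch of what it is assumed to do, mirroring the
# inline maximal-marginal-relevance loop removed above
# (score = relevance_factor * relevance - (1 - relevance_factor) * redundancy);
# the name and defaults here are illustrative only:
def _mmr_sketch(query_vector, document_vectors, relevance_factor=0.8, top_n=10):
    docs = np.array(document_vectors)
    relevance = docs @ np.asarray(query_vector)  # assumes normalized embeddings
    candidates, selected = list(range(len(docs))), []
    while candidates and len(selected) < top_n:
        if selected:
            # redundancy = highest similarity to anything already picked
            redundancy = np.max(docs[candidates] @ docs[selected].T, axis=1)
        else:
            redundancy = np.zeros(len(candidates))
        scores = (relevance_factor * relevance[candidates]
                  - (1 - relevance_factor) * redundancy)
        best = candidates[int(np.argmax(scores))]
        selected.append(best)
        candidates.remove(best)
    return selected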

def _is_openai(self, collection: str) -> bool:
@@ -646,7 +645,8 @@
:return:
"""
collection = self._get_collection(collection)
query = f"SELECT metadata FROM {collection} WHERE id = '__metadata__'"
safe_collection_name = f'"{collection}"'
query = f"SELECT metadata FROM {safe_collection_name} WHERE id = '__metadata__'"
result = self.conn.execute(query).fetchone()
if result:
metadata = json.loads(result[0])
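A hedged end-to-end sketch of the reworked API (constructor arguments and field names beyond those visible in this diff are assumptions):

from curate_gpt.store.duckdb_adapter import DuckDBAdapter

db = DuckDBAdapter(path="./duck.db")
db.insert([{"id": "doc1", "text": "an example document"}], collection="my-collection")
db.create_index("my-collection", metric="cosine")  # "-" is safe: identifiers are quoted
# a relevance_factor below 1.0 is assumed to route through _diversified_search
for result in db.search("example query", collection="my-collection", limit=5, relevance_factor=0.8):
    print(result)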
