diff --git a/src/curate_gpt/store/duckdb_adapter.py b/src/curate_gpt/store/duckdb_adapter.py
index 128f809..d2035d2 100644
--- a/src/curate_gpt/store/duckdb_adapter.py
+++ b/src/curate_gpt/store/duckdb_adapter.py
@@ -2,8 +2,6 @@
 This is a DuckDB adapter for the Vector Similarity Search (VSS) extension
 using the experimental persistence feature
 """
-import itertools
-
 import yaml
 import logging
 import os
@@ -28,6 +26,7 @@
 from curate_gpt.store.db_adapter import DBAdapter, OBJECT, QUERY, PROJECTION, SEARCH_RESULT
 from curate_gpt.store.duckdb_result import DuckDBSearchResult
 from curate_gpt.store.metadata import CollectionMetadata
+from curate_gpt.utils.vector_algorithms import mmr_diversified_search
 
 logger = logging.getLogger(__name__)
 
@@ -62,7 +61,7 @@ class DuckDBAdapter(DBAdapter):
     def __post_init__(self):
         if not self.path:
-            self.path="./duck.db"
+            self.path = "./duck.db"
         os.makedirs(os.path.dirname(self.path), exist_ok=True)
         self.ef_construction = self._validate_ef_construction(self.ef_construction)
         self.ef_search = self._validate_ef_search(self.ef_search)
@@ -101,8 +100,9 @@ def _create_table_if_not_exists(self, collection: str, vec_dimension: int, model
         :param collection:
         :return:
         """
+        safe_collection_name = f'"{collection}"'
         create_table_sql = f"""
-            CREATE TABLE IF NOT EXISTS {collection} (
+            CREATE TABLE IF NOT EXISTS {safe_collection_name} (
                 id VARCHAR PRIMARY KEY,
                 metadata JSON,
                 embeddings FLOAT[{vec_dimension}],
@@ -117,14 +117,15 @@
         else:
             metadata = CollectionMetadata(name=collection, model=self.default_model)
         metadata_json = json.dumps(metadata.model_dump(exclude_none=True))
+        safe_collection_name = f'"{collection}"'
         self.conn.execute(
             f"""
-            INSERT INTO {collection} (id, metadata) VALUES ('__metadata__', ?)
+            INSERT INTO {safe_collection_name} (id, metadata) VALUES ('__metadata__', ?)
             ON CONFLICT (id) DO NOTHING
             """, [metadata_json]
         )
 
-    def create_index(self, collection: str):
+    def create_index(self, collection: str, metric: str = "cosine"):
         """
         Create an index for the given collection
         Parameters
@@ -136,18 +137,20 @@ def create_index(self, collection: str):
         """
-        metrics = ['l2sq', 'cosine', 'ip'] #l2sq,ip
-        for metric in metrics:
-            create_index_sql = f"""
-            CREATE INDEX IF NOT EXISTS idx_{collection}_embeddings_{metric} ON {collection}
-            USING HNSW (embeddings) WITH (
-                metric='{metric}',
-                ef_construction={self.ef_construction},
-                ef_search={self.ef_search},
-                M={self.M}
-            )
-            """
-            self.conn.execute(create_index_sql)
+        # metrics = ['l2sq', 'cosine', 'ip'] #l2sq,ip
+        # for metric in metrics:
+        safe_collection_name = f'"{collection}"'
+        index_name = f"{collection}_index"
+        create_index_sql = f"""
+        CREATE INDEX IF NOT EXISTS "{index_name}" ON {safe_collection_name}
+        USING HNSW (embeddings) WITH (
+            metric='{metric}',
+            ef_construction={self.ef_construction},
+            ef_search={self.ef_search},
+            M={self.M}
+        )
+        """
+        self.conn.execute(create_index_sql)
 
     def _embedding_function(self, texts: Union[str, List[str]], model: str = None) -> list:
         """
@@ -196,7 +199,8 @@ def update(self, objs: Union[OBJECT, Iterable[OBJECT]], **kwargs):
         """
         collection = kwargs.get('collection')
         ids = [self._id(o, self.id_field) for o in objs]
-        delete_sql = f"DELETE FROM {collection} WHERE id = ?"
+        safe_collection_name = f'"{collection}"'
+        delete_sql = f"DELETE FROM {safe_collection_name} WHERE id = ?"
         self.conn.executemany(delete_sql, [(id_,) for id_ in ids])
         self.insert(objs, **kwargs)
@@ -211,7 +215,8 @@ def upsert(self, objs: Union[OBJECT, Iterable[OBJECT]], **kwargs):
         ids = [self._id(o, self.id_field) for o in objs]
         existing_ids = set()
         for id_ in ids:
-            result = self.conn.execute(f"SELECT id FROM {collection} WHERE id = ?",[id_]).fetchall()
+            safe_collection_name = f'"{collection}"'
+            result = self.conn.execute(f"SELECT id FROM {safe_collection_name} WHERE id = ?", [id_]).fetchall()
             if result:
                 existing_ids.add(id_)
         objs_to_update = [o for o in objs if self._id(o, self.id_field) in existing_ids]
@@ -260,6 +265,7 @@ def _process_objects(
         text_field = self.text_lookup
         id_field = self.id_field
         num_objs = len(objs) if isinstance(objs, list) else "?"
+        logger.info(f"Processing {num_objs} objects")
         cumulative_len = 0
         sql_command = self._generate_sql_command(collection, method)
         sql_command = sql_command.format(collection=collection)
@@ -275,6 +281,7 @@ def _process_objects(
             metadatas = [self._dict(o) for o in next_objs]
             ids = [self._id(o, id_field) for o in next_objs]
             embeddings = self._embedding_function(docs, model)
+            logger.info(f"Inserting batch of {len(next_objs)} objects (total: {num_objs})")
             try:
                 self.conn.execute("BEGIN TRANSACTION;")
                 self.conn.executemany(sql_command, list(zip(ids, metadatas, embeddings, docs)))
@@ -285,15 +292,15 @@
                 raise
         self.create_index(collection)
 
-    def _generate_sql_command(self, collection: str, method: str) -> str:
+    def _generate_sql_command(self, collection: str, method: str) -> str:
+        safe_collection_name = f'"{collection}"'
         if method == "insert":
             return f"""
-            INSERT INTO {collection} (id,metadata, embeddings, documents) VALUES (?, ?, ?, ?)
+            INSERT INTO {safe_collection_name} (id, metadata, embeddings, documents) VALUES (?, ?, ?, ?)
             """
         else:
             raise ValueError(f"Unknown method: {method}")
-
     def remove_collection(self, collection: str = None, exists_ok=False, **kwargs):
         """
         Remove the collection from the database
@@ -306,7 +313,9 @@ def remove_collection(self, collection: str = None, exists_ok=False, **kwargs):
         if not exists_ok:
             if collection not in self.list_collection_names():
                 raise ValueError(f"Collection {collection} does not exist")
-        self.conn.execute(f"DROP TABLE IF EXISTS {collection}")
+        # DuckDB requires identifiers containing special characters (e.g. "-") to be enclosed in double quotes
+        safe_collection_name = f'"{collection}"'
+        self.conn.execute(f"DROP TABLE IF EXISTS {safe_collection_name}")
 
     def search(self, text: str, where: QUERY = None, collection: str = None, limit: int = 10,
                relevance_factor: float = None, **kwargs) -> Iterator[SEARCH_RESULT]:
@@ -342,10 +351,12 @@ def _search(self, text: str, where: QUERY = None, collection: str = None, limit:
             yield from self._diversified_search(text, where, collection, limit, relevance_factor, **kwargs)
             return
         query_embedding = self._embedding_function(text, model)
+        logger.debug(f"Query embedding: {query_embedding}")
+        safe_collection_name = f'"{collection}"'
        results = self.conn.execute(f"""
-            SELECT *, array_cosine_similarity(embeddings::FLOAT[{self.vec_dimension}],
-            {query_embedding}::FLOAT[{self.vec_dimension}]) as distance
-            FROM {collection}
+            SELECT *, (1 - array_cosine_similarity(embeddings::FLOAT[{self.vec_dimension}],
+            {query_embedding}::FLOAT[{self.vec_dimension}])) as distance
+            FROM {safe_collection_name}
             {where_clause}
             ORDER BY distance
             LIMIT ?
@@ -372,7 +383,8 @@ def collection_metadata(self, collection_name: Optional[str] = None, include_der
         :param kwargs:
         :return:
         """
-        result = self.conn.execute(f"SELECT metadata FROM {collection_name} WHERE id = '__metadata__'").fetchone()
+        safe_collection_name = f'"{collection_name}"'
+        result = self.conn.execute(f"SELECT metadata FROM {safe_collection_name} WHERE id = '__metadata__'").fetchone()
         if result:
             metadata = json.loads(result[0])
             return CollectionMetadata(**metadata)
@@ -394,9 +406,10 @@ def update_collection_metadata(self, collection: str, **kwargs):
         else:
             current_metadata = current_metadata.model_copy(update=kwargs)
         metadata_json = json.dumps(current_metadata.model_dump(exclude_none=True))
+        safe_collection_name = f'"{collection}"'
         self.conn.execute(
             f"""
-            UPDATE {collection}
+            UPDATE {safe_collection_name}
             SET metadata = ?
             WHERE id = '__metadata__'
             """, [metadata_json]
         )
@@ -417,21 +430,15 @@ def set_collection_metadata(
             raise ValueError("Collection name must be provided.")
         metadata_json = json.dumps(metadata.dict(exclude_none=True))
+        safe_collection_name = f'"{collection_name}"'
         self.conn.execute(
             f"""
-            UPDATE {collection_name}
+            UPDATE {safe_collection_name}
             SET metadata = ?
             WHERE id = '__metadata__'
             """, [metadata_json]
         )
 
-    def _object_metadata(self, obj):
-        """
-        Convert the object to a dictionary.
-        """
-        obj = self._dict(obj)
-        return obj
-
     def find(self, where: QUERY = None, projection: PROJECTION = None, collection: str = None,
              include: Optional[List[str]] = None, limit: int = 10, **kwargs) -> Iterator[
         DuckDBSearchResult]:
@@ -453,9 +460,10 @@ def find(self, where: QUERY = None, projection: PROJECTION = None, collection: s
         collection = self._get_collection(collection)
         where_clause = self._parse_where_clause(where) if where else ""
         where_clause = f"WHERE {where_clause}" if where_clause else ""
+        safe_collection_name = f'"{collection}"'
         query = f"""
             SELECT id, metadata, embeddings, documents, NULL as distance
-            FROM {collection}
+            FROM {safe_collection_name}
             {where_clause}
             LIMIT ?
         """
@@ -484,9 +492,10 @@ def lookup(self, id: str, collection: str = None, **kwargs) -> OBJECT:
         :param kwargs:
         :return:
         """
+        safe_collection_name = f'"{collection}"'
         result = self.conn.execute(f"""
             SELECT *
-            FROM {collection}
+            FROM {safe_collection_name}
             WHERE id = ?
         """, [id]).fetchone()
         if isinstance(result, tuple) and len(result) > 1:
@@ -502,9 +511,10 @@ def peek(self, collection: str = None, limit=5, **kwargs) -> Iterator[OBJECT]:
         :param kwargs:
         :return:
         """
+        safe_collection_name = f'"{collection}"'
         results = self.conn.execute(f"""
             SELECT id, metadata, embeddings, documents, NULL as distance
-            FROM {collection}
+            FROM {safe_collection_name}
             LIMIT ?
""", [limit]).fetchall() @@ -523,7 +533,8 @@ def dump(self, collection: str = None, to_file: Union[str, Path] = None, format: raise ValueError("Collection name must be provided.") collection_name = self._get_collection_name(collection) - query = f"SELECT id, embeddings, metadata, documents FROM {collection_name}" + safe_collection_name = f'"{collection_name}"' + query = f"SELECT id, embeddings, metadata, documents FROM {safe_collection_name}" data = self.conn.execute(query).fetchall() metadata = self.collection_metadata(collection_name).dict(exclude_none=True) @@ -558,8 +569,7 @@ def dump_then_load(self, collection: str = None, target: DBAdapter = None, """ if collection is None: raise ValueError("Collection name must be provided.") - - if not isinstance(target, DuckDBVSSAdapter): + if not isinstance(target, DuckDBAdapter): raise ValueError("Target must be a DuckDBVSSAdapter instance") self.dump(collection=collection, to_file=temp_file, format=format) with open(temp_file, 'r') as f: @@ -600,43 +610,32 @@ def _diversified_search(self, text: str, where: QUERY = None, collection: str = where_clause = f"WHERE {where_clause}" query_embedding = self._embedding_function(text, model=cm.model) - + safe_collection_name = f'"{collection}"' results = self.conn.execute(f""" - SELECT *, array_cosine_similarity(embeddings::FLOAT[{self.vec_dimension}], - {query_embedding}::FLOAT[{self.vec_dimension}]) as distance - FROM {collection} + SELECT *, (1 - array_cosine_similarity(embeddings::FLOAT[{self.vec_dimension}], + {query_embedding}::FLOAT[{self.vec_dimension}])) as distance + FROM {safe_collection_name} {where_clause} ORDER BY distance LIMIT ? """, [limit * 10]).fetchall() + # first row currently always with distance None as id = '__metadata__' results = [r for r in results if r[-1] is not None] results = sorted(results, key=lambda x: x[-1]) - distances = np.array([r[-1] for r in results]) parsed_results = list(self.parse_duckdb_result(results)) - selected_indices = [] - unselected_indices = list(range(len(parsed_results))) - - for _ in range(min(limit, len(parsed_results))): - if not unselected_indices: - break - - mmr_scores = relevance_factor * (1 - distances[unselected_indices]) - if selected_indices: - selected_vecs = np.array([parsed_results[i].embeddings for i in selected_indices if - parsed_results[i].embeddings is not None]) - if selected_vecs.size != 0: - similarities = np.dot(np.array([parsed_results[i].embeddings for i in unselected_indices]), - selected_vecs.T) - max_similarities = np.max(similarities, axis=1) - mmr_scores -= (1 - relevance_factor) * max_similarities - - next_index = unselected_indices[np.argmax(mmr_scores)] - selected_indices.append(next_index) - unselected_indices.remove(next_index) - - for idx in selected_indices: - parsed_results[idx].distance = float(distances[idx]) + print(parsed_results) + document_vectors = [np.array(result.embeddings) for result in parsed_results if result.embeddings is not None] + query_vector = np.array(self._embedding_function(text, model=cm.model)) + if not document_vectors: + logger.info("The database might be empty. 
+            return
+        reranked_indices = mmr_diversified_search(
+            query_vector=query_vector,
+            document_vectors=document_vectors,
+            relevance_factor=relevance_factor,
+            top_n=limit,
+        )
+        for idx in reranked_indices:
             yield parsed_results[idx]
 
@@ -646,7 +645,8 @@ def _is_openai(self, collection: str) -> bool:
         """
         ...
         :return:
         """
         collection = self._get_collection(collection)
-        query = f"SELECT metadata FROM {collection} WHERE id = '__metadata__'"
+        safe_collection_name = f'"{collection}"'
+        query = f"SELECT metadata FROM {safe_collection_name} WHERE id = '__metadata__'"
         result = self.conn.execute(query).fetchone()
         if result:
             metadata = json.loads(result[0])
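A note on the `safe_collection_name` pattern that recurs throughout this patch: DuckDB only accepts special characters such as hyphens in identifiers when the identifier is double-quoted, so an unquoted collection name would break every statement above. A minimal standalone check (the table name `ont-hp` and the inserted value are hypothetical, chosen only to illustrate the quoting):

```python
import duckdb

conn = duckdb.connect()  # in-memory database
# "ont-hp" is a hypothetical collection name; the hyphen makes quoting mandatory
conn.execute('CREATE TABLE "ont-hp" (id VARCHAR PRIMARY KEY)')
conn.execute('INSERT INTO "ont-hp" VALUES (?)', ["HP:0000001"])
print(conn.execute('SELECT id FROM "ont-hp"').fetchall())  # [('HP:0000001',)]
```

Note that `f'"{collection}"'` guards syntax, not injection: a collection name that itself contains a double quote would still break the statement.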
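For readers unfamiliar with the VSS extension that `create_index` targets, the sketch below shows roughly equivalent standalone DuckDB calls, assuming an in-memory database and an illustrative 3-dimensional table (the index parameters mirror the ones used in the patch):

```python
import duckdb

conn = duckdb.connect()  # in-memory database; table and values are illustrative
conn.execute("INSTALL vss")
conn.execute("LOAD vss")
# for an on-disk database, the experimental persistence mentioned in the
# module docstring must be opted into explicitly:
# conn.execute("SET hnsw_enable_experimental_persistence = true")
conn.execute("CREATE TABLE items (id VARCHAR, embeddings FLOAT[3])")
conn.execute(
    """
    CREATE INDEX IF NOT EXISTS "items_index" ON items
    USING HNSW (embeddings)
    WITH (metric='cosine', ef_construction=128, ef_search=64, M=16)
    """
)
```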
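The switch from raw `array_cosine_similarity` to `1 - array_cosine_similarity(...)` matters because both queries sort ascending: with raw similarity in the `distance` column, `ORDER BY distance LIMIT ?` returned the least similar rows first. Spelled out in plain numpy, the convention the column now stores is:

```python
import numpy as np

def cosine_distance(a, b):
    """1 - cosine similarity: 0.0 = same direction, 2.0 = opposite."""
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    return 1.0 - (a @ b) / (np.linalg.norm(a) * np.linalg.norm(b))

print(cosine_distance([1, 0], [1, 0]))   # 0.0 -> best match, sorted first
print(cosine_distance([1, 0], [0, 1]))   # 1.0 -> orthogonal
print(cosine_distance([1, 0], [-1, 0]))  # 2.0 -> worst match, sorted last
```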
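The inline MMR loop deleted from `_diversified_search` now lives in `curate_gpt.utils.vector_algorithms.mmr_diversified_search`, which is not shown in this patch. A minimal sketch of what that utility presumably implements, reconstructed from the deleted loop (the keyword arguments match the call site above; everything else is an assumption):

```python
import numpy as np

def mmr_diversified_search(query_vector, document_vectors, relevance_factor=0.8, top_n=10):
    """Greedy Maximal Marginal Relevance: trade off relevance to the query
    against similarity to documents that have already been selected."""
    docs = np.array(document_vectors)            # shape (n_docs, dim)
    relevance = docs @ np.asarray(query_vector)  # dot-product relevance scores
    selected, candidates = [], list(range(len(docs)))
    while candidates and len(selected) < top_n:
        if selected:
            # penalty: maximum similarity to any already-selected document
            penalty = (docs[candidates] @ docs[selected].T).max(axis=1)
        else:
            penalty = np.zeros(len(candidates))
        scores = relevance_factor * relevance[candidates] - (1 - relevance_factor) * penalty
        best = candidates[int(np.argmax(scores))]
        selected.append(best)
        candidates.remove(best)
    return selected
```

With `relevance_factor=1.0` this degenerates to plain nearest-neighbour ranking; lower values trade relevance for diversity among the returned results.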