From 004e9f04d9e6e9d2f28c765f0e13e8df268128dc Mon Sep 17 00:00:00 2001 From: iQuxLE Date: Wed, 31 Jul 2024 16:49:34 +0100 Subject: [PATCH 1/2] add logic for default file path and use get_store() in view index command --- src/curate_gpt/cli.py | 341 +++++++++++++++++++++++------------------- 1 file changed, 184 insertions(+), 157 deletions(-) diff --git a/src/curate_gpt/cli.py b/src/curate_gpt/cli.py index 5664400..d00c05e 100644 --- a/src/curate_gpt/cli.py +++ b/src/curate_gpt/cli.py @@ -3,8 +3,10 @@ import gzip import json import logging +import os import sys import tempfile +import time from pathlib import Path from typing import Any, Dict, List, Union, Optional @@ -48,7 +50,8 @@ ] -def dump(obj: Union[str, AnnotatedObject, Dict], format="yaml", old_object: Optional[Dict] = None, primary_key: Optional[str] = None) -> None: +def dump(obj: Union[str, AnnotatedObject, Dict], format="yaml", old_object: Optional[Dict] = None, + primary_key: Optional[str] = None) -> None: """ Dump an object to stdout. @@ -100,9 +103,9 @@ def dump(obj: Union[str, AnnotatedObject, Dict], format="yaml", old_object: Opti "-m", "--model", help="Model to use for generation or embedding, e.g. gpt-4." ) extract_format_option = click.option("--extract-format", - "-X", - default="json", - show_default=True, help="Format to use for extraction.") + "-X", + default="json", + show_default=True, help="Format to use for extraction.") schema_option = click.option("-s", "--schema", help="Path to schema.") collection_option = click.option("-c", "--collection", help="Collection within the database.") output_format_option = click.option( @@ -234,28 +237,28 @@ def main(verbose: int, quiet: bool): help="jsonpath to use to subselect from each JSON document.", ) @click.option("--remove-field", - multiple=True, - help="Field to remove recursively from each object.") + multiple=True, + help="Field to remove recursively from each object.") @batch_size_option @encoding_option @click.argument("files", nargs=-1) def index( - files, - path, - append: bool, - text_field, - collection, - model, - object_type, - description, - batch_size, - glob, - view, - select, - collect, - encoding, - remove_field, - **kwargs, + files, + path, + append: bool, + text_field, + collection, + model, + object_type, + description, + batch_size, + glob, + view, + select, + collect, + encoding, + remove_field, + **kwargs, ): """ Index files. 
@@ -425,16 +428,16 @@ def search(query, path, collection, show_documents, **kwargs): ) @output_format_option def all_by_all( - path, - collection, - other_collection, - other_path, - threshold, - ids_only, - output_format, - left_field, - right_field, - **kwargs, + path, + collection, + other_collection, + other_path, + threshold, + ids_only, + output_format, + left_field, + right_field, + **kwargs, ): """Match two collections.""" db = ChromaDBAdapter(path) @@ -543,17 +546,17 @@ def matches(id, path, collection): ) @click.argument("texts", nargs=-1) def annotate( - texts, - path, - model, - collection, - input_file, - split_sentences, - category, - prefix, - identifier_field, - label_field, - **kwargs, + texts, + path, + model, + collection, + input_file, + split_sentences, + category, + prefix, + identifier_field, + label_field, + **kwargs, ): """Concept recognition.""" db = ChromaDBAdapter(path) @@ -621,17 +624,17 @@ def annotate( @output_format_option @click.argument("text", nargs=-1) def extract( - text, - input, - path, - docstore_path, - docstore_collection, - conversation, - rule: List[str], - model, - schema, - output_format, - **kwargs, + text, + input, + path, + docstore_path, + docstore_collection, + conversation, + rule: List[str], + model, + schema, + output_format, + **kwargs, ): """Extract a structured object from text. @@ -711,17 +714,17 @@ def extract( ) @click.argument("ids", nargs=-1) def extract_from_pubmed( - ids, - pubmed_id_file, - output_directory, - path, - docstore_path, - docstore_collection, - conversation, - rule: List[str], - model, - schema, - **kwargs, + ids, + pubmed_id_file, + output_directory, + path, + docstore_path, + docstore_collection, + conversation, + rule: List[str], + model, + schema, + **kwargs, ): """Extract structured knowledge from a publication using its PubMed ID. @@ -854,18 +857,18 @@ def bootstrap_data(config, schema, model): @output_format_option @click.argument("query") def complete( - query, - path, - docstore_path, - docstore_collection, - conversation, - rule: List[str], - model, - query_property, - schema, - extract_format, - output_format, - **kwargs, + query, + path, + docstore_path, + docstore_collection, + conversation, + rule: List[str], + model, + query_property, + schema, + extract_format, + output_format, + **kwargs, ): """ Generate an entry from a query using object completion. @@ -948,20 +951,20 @@ def complete( @click.option("--primary-key", help="Primary key for patch output.") @click.argument("where", nargs=-1) def update( - where, - path, - collection, - docstore_path, - docstore_collection, - conversation, - rule: List[str], - model, - query_property, - schema, - extract_format, - output_format, - primary_key, - **kwargs, + where, + path, + collection, + docstore_path, + docstore_collection, + conversation, + rule: List[str], + model, + query_property, + schema, + extract_format, + output_format, + primary_key, + **kwargs, ): """ Update an entry from a database using object completion. 
@@ -1039,20 +1042,20 @@ def update( @click.option("--primary-key", help="Primary key for patch output.") @click.argument("where", nargs=-1) def review( - where, - path, - collection, - docstore_path, - docstore_collection, - conversation, - rule: List[str], - model, - query_property, - schema, - extract_format, - output_format, - primary_key, - **kwargs, + where, + path, + collection, + docstore_path, + docstore_collection, + conversation, + rule: List[str], + model, + query_property, + schema, + extract_format, + output_format, + primary_key, + **kwargs, ): """ Review entries. @@ -1095,7 +1098,8 @@ def review( dac.document_adapter_collection = docstore_collection for obj, _s, _meta in db.find(where_q, collection=collection): logging.debug(f"Updating {obj}") - ao = dac.review(obj, rules=rule, collection=collection, context_property=query_property, primary_key=primary_key, **filtered_kwargs) + ao = dac.review(obj, rules=rule, collection=collection, context_property=query_property, + primary_key=primary_key, **filtered_kwargs) if output_format == "yaml": print("---") dump(ao.object, format=output_format, old_object=obj, primary_key=primary_key) @@ -1137,18 +1141,18 @@ def review( @output_format_option @click.argument("input_file") def complete_multiple( - input_file, - path, - docstore_path, - docstore_collection, - conversation, - rule: List[str], - model, - query_property, - schema, - output_format, - extract_format, - **kwargs, + input_file, + path, + docstore_path, + docstore_collection, + conversation, + rule: List[str], + model, + query_property, + schema, + output_format, + extract_format, + **kwargs, ): """ Generate an entry from a query using object completion for multiple objects. @@ -1230,17 +1234,17 @@ def complete_multiple( ) @schema_option def complete_all( - path, - collection, - docstore_path, - docstore_collection, - conversation, - rule: List[str], - model, - field_to_predict, - schema, - id_file, - **kwargs, + path, + collection, + docstore_path, + docstore_collection, + conversation, + rule: List[str], + model, + field_to_predict, + schema, + id_file, + **kwargs, ): """ Generate missing values for all objects @@ -1329,17 +1333,17 @@ def complete_all( ) @schema_option def generate_evaluate( - path, - docstore_path, - docstore_collection, - model, - schema, - test_collection, - num_tests, - hold_back_fields, - mask_fields, - rule: List[str], - **kwargs, + path, + docstore_path, + docstore_collection, + model, + schema, + test_collection, + num_tests, + hold_back_fields, + mask_fields, + rule: List[str], + **kwargs, ): """ Evaluate generate using a test set. @@ -1426,17 +1430,17 @@ def generate_evaluate( @generate_background_option @click.argument("tasks", nargs=-1) def evaluate( - tasks, - working_directory, - path, - model, - generate_background, - num_testing, - hold_back_fields, - mask_fields, - rule: List[str], - collection, - **kwargs, + tasks, + working_directory, + path, + model, + generate_background, + num_testing, + hold_back_fields, + mask_fields, + rule: List[str], + collection, + **kwargs, ): """ Evaluate given a task configuration. @@ -1654,6 +1658,7 @@ def apply_patch(input_file, patch, primary_key): print("---") print(yaml.dump(obj, sort_keys=False)) + @main.command() @collection_option @path_option @@ -1928,8 +1933,8 @@ def copy_collection(path, collection, target_path, **kwargs): @click.option( "--derived-collection-base", help=( - "Base name for derived collections. Will be suffixed with _train, _test, _val." 
- "If not provided, will use the same name as the original collection." + "Base name for derived collections. Will be suffixed with _train, _test, _val." + "If not provided, will use the same name as the original collection." ), ) @model_option @@ -1974,7 +1979,7 @@ def copy_collection(path, collection, target_path, **kwargs): ) @path_option def split_collection( - path, collection, derived_collection_base, output_path, model, test_id_file, **kwargs + path, collection, derived_collection_base, output_path, model, test_id_file, **kwargs ): """ Split a collection into test/train/validation. @@ -2020,6 +2025,7 @@ def set_collection_metadata(path, collection, metadata_yaml): def ontology(): "Use the ontology model" + @ontology.command(name="index") @path_option @collection_option @@ -2046,6 +2052,13 @@ def index_ontology_command(ont, path, collection, append, model, index_fields, b curategpt ontology index -p stagedb/duck.db -c ont-hp sqlite:obo:hp -D duckdb """ + + s = time.time() + + if os.path.isdir(path): + path = os.path.join(path, "duck.duckdb") + click.echo("You have to provide a path to a file : Defaulting to" + path) + oak_adapter = get_adapter(ont) view = OntologyWrapper(oak_adapter=oak_adapter) if branches: @@ -2054,6 +2067,7 @@ def index_ontology_command(ont, path, collection, append, model, index_fields, b db.text_lookup = view.text_field if index_fields: fields = index_fields.split(",") + # print(f"Indexing fields: {fields}") def _text_lookup(obj: Dict): @@ -2063,8 +2077,11 @@ def _text_lookup(obj: Dict): db.text_lookup = _text_lookup if not append: db.remove_collection(collection, exists_ok=True) + click.echo(f"Indexing {len(list(view.objects()))} objects") db.insert(view.objects(), collection=collection, model=model) db.update_collection_metadata(collection, object_type="OntologyClass") + e = time.time() + click.echo(f"Indexed {len(list(view.objects()))} in {e - s} seconds") @main.group() @@ -2225,13 +2242,23 @@ def view_search(query, view, model, init_with, limit, **kwargs): @model_option @init_with_option @append_option -def view_index(view, path, append, collection, model, init_with, batch_size, **kwargs): - """Populate an index from a view.""" +@database_type_option +def view_index(view, path, append, collection, model, init_with, batch_size, database_type, **kwargs): + """Populate an index from a view. 
+    curategpt -v index -p stagedb --batch-size 10 -V hpoa -c hpoa -m openai: (uses ChromaDB by default)
+    curategpt -v index -p stagedb/hpoa.duckdb --batch-size 10 -V hpoa -c hpoa -m openai: -D duckdb
+
+    """
+    if os.path.isdir(path):
+        path = os.path.join(path, "duck.duckdb")
+        click.echo("Path must point to a database file; defaulting to " + path)
+
     if init_with:
         for k, v in yaml.safe_load(init_with).items():
             kwargs[k] = v
     wrapper: BaseWrapper = get_wrapper(view, **kwargs)
-    store = ChromaDBAdapter(path)
+    store = get_store(database_type, path)
+
     if not append:
         if collection in store.list_collection_names():
             store.remove_collection(collection)

From c229918a8e85d1625cf520eb1bef66e128ef65f2 Mon Sep 17 00:00:00 2001
From: iQuxLE
Date: Wed, 31 Jul 2024 16:50:39 +0100
Subject: [PATCH 2/2] handling concurrency if there is already a process
 running on the same duckdb file

---
 src/curate_gpt/store/duckdb_adapter.py | 98 ++++++++++++++++++++++----
 1 file changed, 85 insertions(+), 13 deletions(-)

diff --git a/src/curate_gpt/store/duckdb_adapter.py b/src/curate_gpt/store/duckdb_adapter.py
index e16b127..b3e1f0a 100644
--- a/src/curate_gpt/store/duckdb_adapter.py
+++ b/src/curate_gpt/store/duckdb_adapter.py
@@ -2,14 +2,16 @@
 This is a DuckDB adapter for the Vector Similarity Search (VSS) extension using the experimental persistence feature
 """
+import psutil
 import yaml
 import logging
 import os
 import time
+import re
 import numpy as np
 
 from dataclasses import dataclass, field
-from typing import ClassVar, Iterable, Iterator, Optional, Union, Callable, Mapping, List, Dict, Any, Tuple
+from typing import ClassVar, Iterable, Iterator, Optional, Union, Callable, Mapping, List, Dict, Any
 import duckdb
 import json
@@ -67,7 +69,19 @@ def __post_init__(self):
         self.ef_search = self._validate_ef_search(self.ef_search)
         self.M = self._validate_m(self.M)
         logger.info(f"Using DuckDB at {self.path}")
-        self.conn = duckdb.connect(self.path, read_only=False)
+        # handle another process already holding the lock on the same DuckDB file
+        try:
+            self.conn = duckdb.connect(self.path, read_only=False)
+        except duckdb.IOException as e:
+            match = re.search(r'PID (\d+)', str(e))
+            if match:
+                pid = int(match.group(1))
+                logger.info(f"Got {e}. Attempting to terminate the process with PID {pid}")
+                self.kill_process(pid)
+                self.conn = duckdb.connect(self.path, read_only=False)
+            else:
+                logger.error(f"Could not recover from {e}: no PID found in the error message.")
+                raise
         self.conn.execute("INSTALL vss;")
         self.conn.execute("LOAD vss;")
         self.conn.execute("SET hnsw_enable_experimental_persistence=true;")
@@ -286,18 +300,11 @@ def _process_objects(
                 self.conn.execute("COMMIT;")
             except Exception as e:
                 self.conn.execute("ROLLBACK;")
-                logger.error(f"Trransaction failed: {e}")
+                logger.error(f"Transaction failed: {e}")
                 raise
-        self.create_index(collection)
+            finally:
+                self.create_index(collection)
 
-    def _generate_sql_command(self, collection: str, method: str) -> str:
-        safe_collection_name = f'"{collection}"'
-        if method == "insert":
-            return f"""
-                INSERT INTO {safe_collection_name} (id,metadata, embeddings, documents) VALUES (?, ?, ?, ?)
-            """
-        else:
-            raise ValueError(f"Unknown method: {method}")
-
 
     def remove_collection(self, collection: str = None, exists_ok=False, **kwargs):
         """
         Remove a collection
@@ -636,6 +643,38 @@ def _diversified_search(self, text: str, where: QUERY = None, collection: str =
         for idx in reranked_indices:
             yield parsed_results[idx]
 
+    @staticmethod
+    def kill_process(pid):
+        """
+        Terminate the process with the given PID.
+
+        Sends SIGTERM first and falls back to SIGKILL if the process
+        does not exit within the timeout.
+        """
+        process = None
+        try:
+            process = psutil.Process(pid)
+            process.terminate()  # Sends SIGTERM
+            process.wait(timeout=5)
+        except psutil.NoSuchProcess:
+            logger.info("Process already terminated.")
+        except psutil.TimeoutExpired:
+            if process is not None:
+                logger.warning("Process did not terminate in time; forcing kill.")
+                process.kill()  # Sends SIGKILL as a last resort
+        except Exception as e:
+            logger.error(f"Failed to terminate process: {e}")
+
+    @staticmethod
+    def _generate_sql_command(collection: str, method: str) -> str:
+        safe_collection_name = f'"{collection}"'
+        if method == "insert":
+            return f"""
+                INSERT INTO {safe_collection_name} (id,metadata, embeddings, documents) VALUES (?, ?, ?, ?)
+            """
+        else:
+            raise ValueError(f"Unknown method: {method}")
+
     def _is_openai(self, collection: str) -> bool:
         """
         Check if the collection uses a OpenAI Embedding model
@@ -719,7 +758,7 @@ def parse_duckdb_result(results) -> Iterator[DuckDBSearchResult]:
                 metadata=json.loads(obj[1]),
                 embeddings=obj[2],
                 documents=obj[3],
-                cosim=obj[4]
+                distance=obj[4]
             )
 
     @staticmethod
@@ -780,18 +819,51 @@ def _get_embedding_dimension(self, model_name: str) -> int:
 
     @staticmethod
     def _validate_ef_construction(value: int) -> int:
+        """
+        Validate ef_construction: the number of candidate vertices considered while building the index.
+        A higher value yields a more accurate index but increases the time it takes to build it.
+        Parameters
+        ----------
+        value : int, must be between 10 and 200
+
+        Returns
+        -------
+        int : the validated value
+        """
         if not (10 <= value <= 200):
             raise ValueError("ef_construction must be between 10 and 200")
         return value
 
     @staticmethod
     def _validate_ef_search(value: int) -> int:
+        """
+        Validate ef_search: the number of candidate vertices considered during the search phase.
+        A higher value yields more accurate results but increases the time it takes to perform a search.
+        Parameters
+        ----------
+        value : int, must be between 10 and 200
+
+        Returns
+        -------
+        int : the validated value
+        """
         if not (10 <= value <= 200):
             raise ValueError("ef_search must be between 10 and 200")
         return value
 
     @staticmethod
     def _validate_m(value: int) -> int:
+        """
+        Validate M: the maximum number of neighbors kept for each vertex in the graph.
+        A higher value yields a more accurate index but increases the time it takes to build it.
+        Parameters
+        ----------
+        value : int, must be between 5 and 48
+
+        Returns
+        -------
+        int : the validated value
+        """
         if not (5 <= value <= 48):
             raise ValueError("M must be between 5 and 48")
         return value