Merge pull request #61 from monarch-initiative/fix-Issue-type
Fixes for retrieving GitHub issues
caufieldjh authored Aug 20, 2024
2 parents a364507 + 372afda commit cf1277b
Showing 36 changed files with 202 additions and 172 deletions.
47 changes: 33 additions & 14 deletions src/curate_gpt/cli.py
@@ -107,7 +107,8 @@ def dump(
     "--docstore_database_type",
     default="chromadb",
     show_default=True,
-    help="Docstore database type.")
+    help="Docstore database type.",
+)
 
 model_option = click.option(
     "-m", "--model", help="Model to use for generation or embedding, e.g. gpt-4."
@@ -414,7 +415,7 @@ def search(query, path, collection, show_documents, database_type, **kwargs):
     db = get_store(database_type, path)
     results = db.search(query, collection=collection, **kwargs)
     i = 0
-    for obj, distance, meta in results:
+    for obj, distance, _meta in results:
         i += 1
         print(f"## {i} DISTANCE: {distance}")
         print(yaml.dump(obj, sort_keys=False))
@@ -485,6 +486,7 @@ def all_by_all(
     if other_path is None:
         other_path = path
     results = match_collections(db, collection, other_collection, other_db)
+
     def _obj(obj: Dict, is_left=False) -> Any:
         if ids_only:
             obj = {"id": obj["id"]}
@@ -525,7 +527,7 @@ def _obj(obj: Dict, is_left=False) -> Any:
 def matches(id, path, collection, database_type):
     """Find matches for an ID.
-    curategpt matches "Continuant" -p duckdb/objects.duckdb -c objects_a -D duckdb
+    curategpt matches "Continuant" -p duckdb/objects.duckdb -c objects_a -D duckdb
     """
     db = get_store(database_type, path)
@@ -535,7 +537,7 @@ def matches(id, path, collection, database_type):
     print(obj)
     results = db.matches(obj, collection=collection)
     i = 0
-    for obj, distance, meta in results:
+    for obj, distance, _meta in results:
         i += 1
         print(f"## ID:- {obj['id']}")
         print(f"## DISTANCE- {distance}")
@@ -972,7 +974,7 @@ def complete(
         extractor.schema_proxy = schema_manager
     dac = DragonAgent(knowledge_source=db, extractor=extractor)
     if docstore_path or docstore_collection:
-        dac.document_adapter = get_store(docstore_database_type,docstore_path)
+        dac.document_adapter = get_store(docstore_database_type, docstore_path)
         dac.document_adapter_collection = docstore_collection
     if ":" in query:
         query = yaml.safe_load(query)
@@ -1064,7 +1066,7 @@ def update(
         extractor.schema_proxy = schema_manager
     dac = DragonAgent(knowledge_source=db, extractor=extractor)
     if docstore_path or docstore_collection:
-        dac.document_adapter = get_store(docstore_database_type,docstore_path)
+        dac.document_adapter = get_store(docstore_database_type, docstore_path)
         dac.document_adapter_collection = docstore_collection
     for obj, _s, _meta in db.find(where_q, collection=collection):
         # click.echo(f"{obj}")
@@ -1262,7 +1264,7 @@ def complete_multiple(
         extractor.schema_proxy = schema_manager
     dac = DragonAgent(knowledge_source=db, extractor=extractor)
     if docstore_path or docstore_collection:
-        dac.document_adapter = get_store(docstore_database_type,docstore_path)
+        dac.document_adapter = get_store(docstore_database_type, docstore_path)
         dac.document_adapter_collection = docstore_collection
     with open(input_file) as f:
         queries = [l.strip() for l in f.readlines()]
@@ -1454,7 +1456,7 @@ def complete_all(
         extractor.schema_proxy = schema_manager
     dae = DragonAgent(knowledge_source=db, extractor=extractor)
     if docstore_path or docstore_collection:
-        dae.document_adapter = get_store(docstore_database_type,docstore_path)
+        dae.document_adapter = get_store(docstore_database_type, docstore_path)
         dae.document_adapter_collection = docstore_collection
     object_ids = None
     if id_file:
@@ -1541,7 +1543,7 @@ def generate_evaluate(
     -------
     curategpt -v generate-evaluate -c cdr_training -T cdr_test -F statements -m gpt-4
     """
-    db = get_store(database_type,path)
+    db = get_store(database_type, path)
     if schema:
         schema_manager = SchemaProxy(schema)
     else:
@@ -1556,7 +1558,7 @@ def generate_evaluate(
         extractor.schema_proxy = schema_manager
     rage = DragonAgent(knowledge_source=db, extractor=extractor)
     if docstore_path or docstore_collection:
-        rage.document_adapter = get_store(docstore_database_type,docstore_path)
+        rage.document_adapter = get_store(docstore_database_type, docstore_path)
         rage.document_adapter_collection = docstore_collection
     hold_back_fields = hold_back_fields.split(",")
     mask_fields = mask_fields.split(",") if mask_fields else []
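The same docstore wiring appears in complete, update, complete_multiple, complete_all, and generate_evaluate; schematically (DragonAgent and get_store come from the hunks above, the literal arguments are placeholders):

    agent = DragonAgent(knowledge_source=db, extractor=extractor)
    if docstore_path or docstore_collection:
        # Optional secondary store of supporting documents for retrieval.
        agent.document_adapter = get_store("chromadb", "./docstore")  # placeholder args
        agent.document_adapter_collection = "docs"  # placeholder name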
@@ -1884,7 +1886,17 @@ def apply_patch(input_file, patch, primary_key):
     help="jsonpath expression to select objects from the input file.",
 )
 @click.argument("query")
-def citeseek(query, path, collection, model, show_references, _continue, select, conversation_id, database_type):
+def citeseek(
+    query,
+    path,
+    collection,
+    model,
+    show_references,
+    _continue,
+    select,
+    conversation_id,
+    database_type,
+):
     """Find citations for an object or statement.
 
     You can pass in a statement directly as an argument
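A hypothetical invocation, following the -p/-c/-m option pattern the other commands in this file use (the statement, path, and collection names here are made up):

    curategpt citeseek "intercalated discs connect cardiac myocytes" -p stagedb -c main -m gpt-4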
@@ -2063,7 +2075,7 @@ def list_collections(database_type, path, peek: bool, minimal: bool, derived: bo
         # making sure if o[id] finds nothing we get the full obj
         r = list(db.peek(cn))
         for o, _, _ in r:
-            if 'id' in o:
+            if "id" in o:
                 print(f" - {o['id']}")
             else:
                 print(f" - {o}")
@@ -2202,7 +2214,14 @@ def copy_collection(path, collection, target_path, database_type, **kwargs):
 )
 @path_option
 def split_collection(
-    path, collection, derived_collection_base, output_path, model, test_id_file, database_type, **kwargs
+    path,
+    collection,
+    derived_collection_base,
+    output_path,
+    model,
+    test_id_file,
+    database_type,
+    **kwargs,
 ):
     """
     Split a collection into test/train/validation.
@@ -2224,7 +2243,7 @@ def split_collection(
         )
         logging.info(f"First 10: {kwargs['testing_identifiers'][:10]}")
     sc = stratify_collection(db, collection, **kwargs)
-    output_db = get_store(database_type ,output_path)
+    output_db = get_store(database_type, output_path)
     if not derived_collection_base:
         derived_collection_base = collection
     for sn in ["training", "testing", "validation"]:
4 changes: 2 additions & 2 deletions src/curate_gpt/store/chromadb_adapter.py
@@ -18,10 +18,10 @@
 from linkml_runtime.utils.yamlutils import YAMLRoot
 from oaklib.utilities.iterator_utils import chunk
 from pydantic import BaseModel
-from curate_gpt.store.metadata import CollectionMetadata
 
-from curate_gpt.store.vocab import OBJECT, QUERY, PROJECTION, SEARCH_RESULT
 from curate_gpt.store.db_adapter import DBAdapter
+from curate_gpt.store.metadata import CollectionMetadata
+from curate_gpt.store.vocab import OBJECT, PROJECTION, QUERY, SEARCH_RESULT
 from curate_gpt.utils.vector_algorithms import mmr_diversified_search
 
 logger = logging.getLogger(__name__)
17 changes: 13 additions & 4 deletions src/curate_gpt/store/db_adapter.py
@@ -5,7 +5,7 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from pathlib import Path
-from typing import ClassVar, Dict, Iterable, Iterator, List, Optional, TextIO, Tuple, Union
+from typing import ClassVar, Dict, Iterable, Iterator, List, Optional, TextIO, Union
 
 import pandas as pd
 import yaml
@@ -14,8 +14,17 @@
 
 from curate_gpt.store.metadata import CollectionMetadata
 from curate_gpt.store.schema_proxy import SchemaProxy
-from curate_gpt.store.vocab import OBJECT, SEARCH_RESULT, QUERY, PROJECTION, FILE_LIKE, EMBEDDINGS, DOCUMENTS, \
-    METADATAS, DEFAULT_COLLECTION
+from curate_gpt.store.vocab import (
+    DEFAULT_COLLECTION,
+    DOCUMENTS,
+    EMBEDDINGS,
+    FILE_LIKE,
+    METADATAS,
+    OBJECT,
+    PROJECTION,
+    QUERY,
+    SEARCH_RESULT,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -468,4 +477,4 @@ def dump_then_load(self, collection: str = None, target: "DBAdapter" = None):
         :param target:
         :return:
         """
-        raise NotImplementedError
+        raise NotImplementedError
4 changes: 2 additions & 2 deletions src/curate_gpt/store/db_metadata.py
@@ -1,7 +1,7 @@
 from pathlib import Path
 
-from pydantic import BaseModel
 import yaml
+from pydantic import BaseModel
 
 
 class DBSettings(BaseModel):
@@ -16,7 +16,7 @@ class DBSettings(BaseModel):
 
     ef_construction: int = 128
     """
-    Construction parameter for hnsw index. 
+    Construction parameter for hnsw index.
     Higher values are more accurate but slower.
     """
 
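For orientation, a minimal sketch of the model those db_metadata.py imports feed (DBSettings and the ef_construction field appear in the hunk above; the yaml round-trip and the pydantic-v2 model_dump call are assumptions):

    import yaml
    from pydantic import BaseModel

    class DBSettings(BaseModel):
        # Construction parameter for hnsw index.
        # Higher values are more accurate but slower.
        ef_construction: int = 128

    settings = DBSettings(ef_construction=256)
    print(yaml.safe_dump(settings.model_dump()))  # .model_dump() is pydantic v2; use .dict() on v1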