diff --git a/src/curate_gpt/agents/concept_recognition_agent.py b/src/curate_gpt/agents/concept_recognition_agent.py
index 47d0ee6..8fffc7e 100644
--- a/src/curate_gpt/agents/concept_recognition_agent.py
+++ b/src/curate_gpt/agents/concept_recognition_agent.py
@@ -19,6 +19,7 @@ class Span(BaseModel):
     """An individual span of text containing a single concept."""
+
     text: str
 
     start: Optional[int] = None
@@ -33,6 +34,7 @@ class Span(BaseModel):
     is_suspect: bool = False
     """Potential hallucination due to ID/label mismatch."""
 
+
 class GroundingResult(BaseModel):
     """Result of grounding text."""
 
@@ -48,6 +50,7 @@ class GroundingResult(BaseModel):
 
 class AnnotationMethod(str, Enum):
     """Strategy or algorithm used for CR."""
+
     INLINE = "inline"
     """LLM creates an annotated document"""
 
@@ -153,7 +156,7 @@ def parse_annotations(text, marker_char: str = None) -> List[CONCEPT]:
     :return:
     """
     # First Pass: Extract text within [ ... ]
-    pattern1 = r'\[([^\]]+)\]'
+    pattern1 = r"\[([^\]]+)\]"
     matches = re.findall(pattern1, text)
 
     # Second Pass: Parse the last token of each match
@@ -163,7 +166,7 @@ def parse_annotations(text, marker_char: str = None) -> List[CONCEPT]:
         if marker_char:
             toks = match.split(marker_char)
             if len(toks) > 1:
-                annotation = ' '.join(toks[:-1]).strip()
+                annotation = " ".join(toks[:-1]).strip()
                 id = toks[-1].strip()
             else:
                 annotation = match
@@ -171,7 +174,7 @@ def parse_annotations(text, marker_char: str = None) -> List[CONCEPT]:
         else:
             words = match.split()
             if len(words) > 1 and ":" in words[-1]:
-                annotation = ' '.join(words[:-1])
+                annotation = " ".join(words[:-1])
                 id = words[-1]
             else:
                 annotation = match
@@ -198,14 +201,19 @@ def parse_spans(text: str, concept_dict: Dict[str, str] = None) -> List[Span]:
         concept_label = row[1].strip('"')
         mention_text = ",".join(row[2:])
         verified_concept_label = concept_dict.get(concept_id, None)
-        spans.append(Span(text=mention_text, concept_id=concept_id, concept_label=verified_concept_label,
-                          is_suspect=verified_concept_label != concept_label))
+        spans.append(
+            Span(
+                text=mention_text,
+                concept_id=concept_id,
+                concept_label=verified_concept_label,
+                is_suspect=verified_concept_label != concept_label,
+            )
+        )
     return spans
 
 
 @dataclass
 class ConceptRecognitionAgent(BaseAgent):
-
     identifier_field: str = None
     """Field to use as identifier for objects."""
 
@@ -225,16 +233,17 @@ def ground_concept(
         text: str,
         collection: str = None,
         categories: Optional[List[str]] = None,
-        include_category_in_search = True,
+        include_category_in_search=True,
         context: str = None,
         **kwargs,
     ) -> GroundingResult:
-
         system_prompt = GROUND_PROMPT
         query = text
         if include_category_in_search and categories:
             query += " Categories: " + ", ".join(categories)
-        concept_pairs, concept_prompt = self._label_id_pairs_prompt_section(query, collection, **kwargs)
+        concept_pairs, concept_prompt = self._label_id_pairs_prompt_section(
+            query, collection, **kwargs
+        )
         concept_dict = {c[0]: c[1] for c in concept_pairs}
         system_prompt += concept_prompt
         model = self.extractor.model
@@ -262,9 +271,14 @@ def ground_concept(
                 concept_label = concept_dict[concept_id]
             else:
                 concept_label = None
-            span = Span(text=text, concept_id=concept_id, concept_label=concept_label, is_suspect=provided_concept_label != concept_label)
+            span = Span(
+                text=text,
+                concept_id=concept_id,
+                concept_label=concept_label,
+                is_suspect=provided_concept_label != concept_label,
+            )
             spans.append(span)
-        #spans = parse_spans(response.text(), concept_dict)
+        # spans = parse_spans(response.text(), concept_dict)
         ann = GroundingResult(input_text=text, annotated_text=response.text(), spans=spans)
         return ann
 
@@ -296,12 +310,16 @@ def annotate_two_pass(
         system_prompt = "Your job is to parse the supplied text, identifying instances of concepts "
         if len(categories) == 1:
             system_prompt += f" that represent some kind of {categories[0]}. "
-            system_prompt += ("Mark up the concepts in square brackets, "
-                              "preserving the original text inside the brackets. ")
+            system_prompt += (
+                "Mark up the concepts in square brackets, "
+                "preserving the original text inside the brackets. "
+            )
         else:
             system_prompt += " that represent one of the following categories: "
             system_prompt += ", ".join(categories)
-            system_prompt += "Mark up the concepts in square brackets, with the category after the pipe symbol, "
+            system_prompt += (
+                "Mark up the concepts in square brackets, with the category after the pipe symbol, "
+            )
             system_prompt += "Using the syntax [ORIGINAL TEXT | CATEGORY]."
         logger.debug(f"Prompting with: {text}")
         model = self.extractor.model
@@ -310,12 +328,24 @@ def annotate_two_pass(
         anns = parse_annotations(marked_up_text, "|")
         spans = []
         for term, category in anns:
-            concept = self.ground_concept(term, collection, categories=[category] if category else None, context=text, **kwargs)
+            concept = self.ground_concept(
+                term,
+                collection,
+                categories=[category] if category else None,
+                context=text,
+                **kwargs,
+            )
             if not concept.spans:
                 logger.debug(f"Unable to ground concept {term} in category {category}")
                 continue
             main_span = concept.spans[0]
-            spans.append(Span(text=term, concept_id=main_span.concept_id, concept_label=main_span.concept_label))
+            spans.append(
+                Span(
+                    text=term,
+                    concept_id=main_span.concept_id,
+                    concept_label=main_span.concept_label,
+                )
+            )
         return AnnotatedText(
             input_text=text,
             annotated_text=marked_up_text,
@@ -329,9 +359,10 @@ def annotate_inline(
         categories: List[str] = None,
         **kwargs,
     ) -> AnnotatedText:
-
         system_prompt = ANNOTATE_PROMPT
-        concept_pairs, concepts_prompt = self._label_id_pairs_prompt_section(text, collection, **kwargs)
+        concept_pairs, concepts_prompt = self._label_id_pairs_prompt_section(
+            text, collection, **kwargs
+        )
         concept_dict = {c[0]: c[1] for c in concept_pairs}
         system_prompt += concepts_prompt
         model = self.extractor.model
@@ -339,13 +370,11 @@ def annotate_inline(
         response = model.prompt(text, system=system_prompt)
         anns = parse_annotations(response.text())
         logger.info(f"Anns: {anns}")
-        spans = [Span(text=ann[0], concept_id=ann[1], concept_label=concept_dict.get(ann[1], None)) for ann in anns]
-        return AnnotatedText(
-            input_text=text,
-            spans=spans,
-            annotated_text=response.text()
-        )
-
+        spans = [
+            Span(text=ann[0], concept_id=ann[1], concept_label=concept_dict.get(ann[1], None))
+            for ann in anns
+        ]
+        return AnnotatedText(input_text=text, spans=spans, annotated_text=response.text())
 
     def annotate_concept_list(
         self,
@@ -354,18 +383,28 @@ def annotate_concept_list(
         categories: List[str] = None,
         **kwargs,
     ) -> AnnotatedText:
-
         system_prompt = MENTION_PROMPT
-        concept_pairs, concepts_prompt = self._label_id_pairs_prompt_section(text, collection, **kwargs)
+        concept_pairs, concepts_prompt = self._label_id_pairs_prompt_section(
+            text, collection, **kwargs
+        )
         concept_dict = {c[0]: c[1] for c in concept_pairs}
         system_prompt += concepts_prompt
         model = self.extractor.model
         logger.debug(f"Prompting with: {text}")
         response = model.prompt(text, system=system_prompt)
         spans = parse_spans(response.text(), concept_dict)
-        return AnnotatedText(input_text=text, summary=response.text(), spans=spans, prompt=system_prompt)
+        return AnnotatedText(
+            input_text=text, summary=response.text(), spans=spans, prompt=system_prompt
+        )
 
-    def _label_id_pairs_prompt_section(self, text:str, collection: str, prolog: str = None, relevance_factor: float = None, **kwargs) -> Tuple[List[CONCEPT], str]:
+    def _label_id_pairs_prompt_section(
+        self,
+        text: str,
+        collection: str,
+        prolog: str = None,
+        relevance_factor: float = None,
+        **kwargs,
+    ) -> Tuple[List[CONCEPT], str]:
         prompt = prolog
         if not prompt:
             prompt = "Here are the candidate concepts, as label // ConceptID pairs:\n"
@@ -395,4 +434,4 @@ def _label_id_pairs_prompt_section(self, text:str, collection: str, prolog: str
                 raise ValueError(f"Object {obj} has no label field {label_field}")
             prompt += f"{label} // {id} \n"
             concept_pairs.append((id, label))
-        return concept_pairs, prompt
\ No newline at end of file
+        return concept_pairs, prompt
diff --git a/src/curate_gpt/agents/mapping_agent.py b/src/curate_gpt/agents/mapping_agent.py
index 71b200b..a2ac1bc 100644
--- a/src/curate_gpt/agents/mapping_agent.py
+++ b/src/curate_gpt/agents/mapping_agent.py
@@ -5,7 +5,7 @@
 from dataclasses import dataclass
 from enum import Enum
 from random import shuffle
-from typing import Any, Dict, Iterator, List, Optional, Union, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
 
 import inflection
 import yaml
@@ -191,6 +191,8 @@ def find_links(self, other_collection: str) -> Iterator[Tuple[str, str, str]]:
         :return:
         """
         # TODO
-        for obj, _, info in self.knowledge_source.find(collection=other_collection, include = ["embeddings", "documents", "metadatas"]):
+        for obj, _, info in self.knowledge_source.find(
+            collection=other_collection, include=["embeddings", "documents", "metadatas"]
+        ):
             embeddings = info["embeddings"]
             self.knowledge_source.find(embeddings, limit=10)
diff --git a/src/curate_gpt/agents/summarization_agent.py b/src/curate_gpt/agents/summarization_agent.py
index ab4cbec..12ae39f 100644
--- a/src/curate_gpt/agents/summarization_agent.py
+++ b/src/curate_gpt/agents/summarization_agent.py
@@ -1,12 +1,13 @@
 import logging
 from dataclasses import dataclass
-from typing import Dict, Union, List
+from typing import Dict, List, Union
 
 from curate_gpt.agents.base_agent import BaseAgent
 from curate_gpt.wrappers import BaseWrapper
 
 logger = logging.getLogger(__name__)
 
+
 @dataclass
 class SummarizationAgent(BaseAgent):
     """
@@ -15,7 +16,14 @@ class SummarizationAgent(BaseAgent):
     AKA SPINDOCTOR/TALISMAN
     """
 
-    def summarize(self, object_ids: List[str], description_field: str, name_field: str, strict: bool = False, system_prompt: str = None):
+    def summarize(
+        self,
+        object_ids: List[str],
+        description_field: str,
+        name_field: str,
+        strict: bool = False,
+        system_prompt: str = None,
+    ):
         """
         Summarize a list of objects.
 
@@ -48,9 +56,13 @@ def summarize(self, object_ids: List[str], description_field: str, name_field: s
         if isinstance(knowledge_source, BaseWrapper):
             object_iter = knowledge_source.objects(self.knowledge_source_collection, object_ids)
         else:
-            object_iter = knowledge_source.lookup_multiple(object_ids, collection=self.knowledge_source_collection)
+            object_iter = knowledge_source.lookup_multiple(
+                object_ids, collection=self.knowledge_source_collection
+            )
         objects = list(object_iter)
-        descriptions = [(obj.get(name_field, ""), obj.get(description_field, None)) for obj in objects]
+        descriptions = [
+            (obj.get(name_field, ""), obj.get(description_field, None)) for obj in objects
+        ]
         if any(desc[0] is None for desc in descriptions):
             raise ValueError(f"Missing name for objects: {objects}")
         if strict:
diff --git a/src/curate_gpt/app/app.py b/src/curate_gpt/app/app.py
index 71e6aa2..fcb7b6a 100644
--- a/src/curate_gpt/app/app.py
+++ b/src/curate_gpt/app/app.py
@@ -11,8 +11,8 @@
 from curate_gpt import BasicExtractor
 from curate_gpt.agents import MappingAgent
 from curate_gpt.agents.chat_agent import ChatAgent, ChatResponse
-from curate_gpt.agents.dragon_agent import DragonAgent
 from curate_gpt.agents.dase_agent import DatabaseAugmentedStructuredExtraction
+from curate_gpt.agents.dragon_agent import DragonAgent
 from curate_gpt.agents.evidence_agent import EvidenceAgent
 from curate_gpt.app.components import (
     DimensionalityReductionOptions,
diff --git a/src/curate_gpt/cli.py b/src/curate_gpt/cli.py
index b09217d..2f7a09f 100644
--- a/src/curate_gpt/cli.py
+++ b/src/curate_gpt/cli.py
@@ -5,13 +5,13 @@
 import logging
 import sys
 from pathlib import Path
-from typing import Dict, List, Union, Any
+from typing import Any, Dict, List, Union
 
 import click
 import pandas as pd
 import yaml
 from click_default_group import DefaultGroup
-from linkml_runtime.dumpers import yaml_dumper, json_dumper
+from linkml_runtime.dumpers import json_dumper, yaml_dumper
 from linkml_runtime.utils.yamlutils import YAMLRoot
 from llm import UnknownModelError, get_model, get_plugins
 from llm.cli import load_conversation
@@ -20,9 +20,9 @@
 
 from curate_gpt import ChromaDBAdapter, __version__
 from curate_gpt.agents.chat_agent import ChatAgent, ChatResponse
-from curate_gpt.agents.concept_recognition_agent import ConceptRecognitionAgent, AnnotationMethod
-from curate_gpt.agents.dragon_agent import DragonAgent
+from curate_gpt.agents.concept_recognition_agent import AnnotationMethod, ConceptRecognitionAgent
 from curate_gpt.agents.dase_agent import DatabaseAugmentedStructuredExtraction
+from curate_gpt.agents.dragon_agent import DragonAgent
 from curate_gpt.agents.evidence_agent import EvidenceAgent
 from curate_gpt.agents.summarization_agent import SummarizationAgent
 from curate_gpt.evaluation.dae_evaluator import DatabaseAugmentedCompletionEvaluator
@@ -69,10 +69,13 @@ def dump(obj: Union[str, AnnotatedObject, Dict], format="yaml") -> None:
         raise ValueError(f"Unknown format {format}")
     print(ser)
 
+
 # logger = logging.getLogger(__name__)
 
 path_option = click.option("-p", "--path", help="Path to a file or directory for database.")
-model_option = click.option("-m", "--model", help="Model to use for generation or embedding, e.g. gpt-4.")
+model_option = click.option(
+    "-m", "--model", help="Model to use for generation or embedding, e.g. gpt-4."
+)
 schema_option = click.option("-s", "--schema", help="Path to schema.")
 collection_option = click.option("-c", "--collection", help="Collection within the database.")
 output_format_option = click.option(
@@ -323,15 +326,26 @@ def search(query, path, collection, show_documents, **kwargs):
     multiple=True,
     help="Field to show from right collection.",
 )
-
 @output_format_option
-def all_by_all(path, collection, other_collection, other_path, threshold, ids_only, output_format, left_field, right_field, **kwargs):
+def all_by_all(
+    path,
+    collection,
+    other_collection,
+    other_path,
+    threshold,
+    ids_only,
+    output_format,
+    left_field,
+    right_field,
+    **kwargs,
+):
     """Match two collections."""
     db = ChromaDBAdapter(path)
     if other_path is None:
         other_path = path
     other_db = ChromaDBAdapter(other_path)
     results = match_collections(db, collection, other_collection, other_db)
+
     def _obj(obj: Dict, is_left=False) -> Any:
         if ids_only:
             obj = {"id": obj["id"]}
@@ -342,6 +356,7 @@ def _obj(obj: Dict, is_left=False) -> Any:
         side = "left" if is_left else "right"
         obj = {f"{side}_{k}": v for k, v in obj.items()}
         return obj
+
     i = 0
     for obj1, obj2, sim in results:
         if threshold and sim < threshold:
@@ -362,6 +377,7 @@ def _obj(obj: Dict, is_left=False) -> Any:
             dump(obj1, output_format)
             dump(obj2, output_format)
 
+
 @main.command()
 @path_option
 @collection_option
@@ -395,9 +411,7 @@ def matches(id, path, collection):
     "-L",
     help="Field to use as label (defaults to label).",
 )
-@click.option(
-    "-l", "--limit", default=50, show_default=True, help="Number of candidate terms."
-)
+@click.option("-l", "--limit", default=50, show_default=True, help="Number of candidate terms.")
 @click.option(
     "--input-file",
     "-i",
@@ -430,9 +444,20 @@ def matches(id, path, collection):
     multiple=True,
     help="Category/ies for candidate IDs.",
 )
-
 @click.argument("texts", nargs=-1)
-def annotate(texts, path, model, collection, input_file, split_sentences, category, prefix, identifier_field, label_field, **kwargs):
+def annotate(
+    texts,
+    path,
+    model,
+    collection,
+    input_file,
+    split_sentences,
+    category,
+    prefix,
+    identifier_field,
+    label_field,
+    **kwargs,
+):
     """Concept recognition."""
     db = ChromaDBAdapter(path)
     extractor = BasicExtractor()
@@ -460,6 +485,7 @@ def annotate(texts, path, model, collection, input_file, split_sentences, catego
         dump(ao)
         print(f"---\n")
 
+
 @main.command()
 @path_option
 @collection_option
@@ -784,6 +810,7 @@ def complete_multiple(
         print("---")
         dump(ao.object, format=output_format)
 
+
 @main.command()
 @path_option
 @collection_option
@@ -1266,7 +1293,7 @@ def citeseek(query, path, collection, model, show_references, _continue, convers
 @collection_option
 @path_option
 @model_option
-@click.option("--view", "-V", help="Name of the wrapper to use.") 
+@click.option("--view", "-V", help="Name of the wrapper to use.")
 @click.option("--name-field", help="Field for names.")
 @click.option("--description-field", help="Field for names.")
 @click.option("--system-prompt", help="System gpt prompt to use.")
@@ -1582,7 +1609,6 @@ def unwrap_objects(input_file, view, path, collection, output_format, **kwargs):
         dump(unwrapped, output_format)
 
 
-
 @view.command(name="search")
 @click.option("--view", "-V")
 @click.option("--source-locator")
diff --git a/src/curate_gpt/evaluation/runner.py b/src/curate_gpt/evaluation/runner.py
index 6569afb..cb408af 100644
--- a/src/curate_gpt/evaluation/runner.py
+++ b/src/curate_gpt/evaluation/runner.py
@@ -81,9 +81,7 @@ def run_task(
     task.executed_on = (
         f"{platform.system()}-{platform.release()}-{platform.version()}-{platform.machine()}"
     )
-    agent = DragonAgent(
-        knowledge_source=tdb, knowledge_source_collection="", extractor=extractor
-    )
+    agent = DragonAgent(knowledge_source=tdb, knowledge_source_collection="", extractor=extractor)
     if task.additional_collections:
         if len(task.additional_collections) > 1:
             raise NotImplementedError("Only one additional collection is supported")
diff --git a/src/curate_gpt/store/chromadb_adapter.py b/src/curate_gpt/store/chromadb_adapter.py
index 89e7430..f1c7fe6 100644
--- a/src/curate_gpt/store/chromadb_adapter.py
+++ b/src/curate_gpt/store/chromadb_adapter.py
@@ -8,7 +8,8 @@
 
 import chromadb
 import yaml
-from chromadb import ClientAPI as API, Settings
+from chromadb import ClientAPI as API
+from chromadb import Settings
 from chromadb.api import EmbeddingFunction
 from chromadb.types import Collection
 from chromadb.utils import embedding_functions
@@ -431,9 +432,13 @@ def find(
                     f"Empty metadata for item {i} [num: {len(metadatas)}] doc: {documents[i]}"
                 )
                 continue
-            obj = self._unjson(metadatas[i]), 0.0, {
-                "document": documents[i],
-            }
+            obj = (
+                self._unjson(metadatas[i]),
+                0.0,
+                {
+                    "document": documents[i],
+                },
+            )
             if embeddings:
                 obj[2]["_embeddings"] = embeddings[i]
             yield obj
@@ -558,7 +563,6 @@ def dump_then_load(self, collection: str = None, target: DBAdapter = None):
             logger.debug(f"Dumping {i} of {len(result['ids'])}")
             batched_obj = {}
             for k in ["ids", "metadatas", "documents", "embeddings"]:
-                batched_obj[k] = result[k][i: i + batch_size]
+                batched_obj[k] = result[k][i : i + batch_size]
             target_collection_obj.add(**batched_obj)
             i += batch_size
-
diff --git a/src/curate_gpt/store/db_adapter.py b/src/curate_gpt/store/db_adapter.py
index 9f0a7eb..ad455f5 100644
--- a/src/curate_gpt/store/db_adapter.py
+++ b/src/curate_gpt/store/db_adapter.py
@@ -317,7 +317,6 @@ def lookup_multiple(self, ids: List[str], **kwargs) -> Iterator[OBJECT]:
         """
         yield from [self.lookup(id, **kwargs) for id in ids]
 
-
     @abstractmethod
     def peek(self, collection: str = None, limit=5, **kwargs) -> Iterator[OBJECT]:
         """
@@ -398,7 +397,7 @@ def dump(
             format = "json"
         if not include:
             include = ["embeddings", "documents", "metadatas"]
-        #include = ["embeddings", "documents", "metadatas"]
+        # include = ["embeddings", "documents", "metadatas"]
         if not isinstance(include, list):
             include = list(include)
         objects = self.find(collection=collection, include=include, **kwargs)
diff --git a/src/curate_gpt/utils/search.py b/src/curate_gpt/utils/search.py
index b3e8633..eea436a 100644
--- a/src/curate_gpt/utils/search.py
+++ b/src/curate_gpt/utils/search.py
@@ -1,5 +1,3 @@
 import logging
 
 logger = logging.getLogger(__name__)
-
-
diff --git a/src/curate_gpt/utils/vector_algorithms.py b/src/curate_gpt/utils/vector_algorithms.py
index cef88ae..2222e9d 100644
--- a/src/curate_gpt/utils/vector_algorithms.py
+++ b/src/curate_gpt/utils/vector_algorithms.py
@@ -47,7 +47,9 @@ def top_matches(cosine_similarity_matrix: np.ndarray) -> Tuple[np.ndarray, np.nd
     return top_match_indices, top_match_values
 
 
-def top_n_matches(cosine_similarity_matrix: np.ndarray, n: int=10) -> Tuple[np.ndarray, np.ndarray]:
+def top_n_matches(
+    cosine_similarity_matrix: np.ndarray, n: int = 10
+) -> Tuple[np.ndarray, np.ndarray]:
     # Find the indices that would sort each row in descending order
     sorted_indices = np.argsort(-cosine_similarity_matrix, axis=1)
 
diff --git a/src/curate_gpt/utils/vectordb_operations.py b/src/curate_gpt/utils/vectordb_operations.py
index 2010afe..08e06fd 100644
--- a/src/curate_gpt/utils/vectordb_operations.py
+++ b/src/curate_gpt/utils/vectordb_operations.py
@@ -4,10 +4,12 @@
 
 from curate_gpt import DBAdapter
 from curate_gpt.utils.vector_algorithms import compute_cosine_similarity, top_matches
 
-
 logger = logging.getLogger(__name__)
 
-def match_collections(db: DBAdapter, left_collection: str, right_collection: str, other_db: DBAdapter=None) -> Iterator[Tuple[dict, dict, float]]:
+
+def match_collections(
+    db: DBAdapter, left_collection: str, right_collection: str, other_db: DBAdapter = None
+) -> Iterator[Tuple[dict, dict, float]]:
     """
     Match every element in left collection with every element in right collection.
 
diff --git a/src/curate_gpt/wrappers/__init__.py b/src/curate_gpt/wrappers/__init__.py
index 3af5a15..20a5aaf 100644
--- a/src/curate_gpt/wrappers/__init__.py
+++ b/src/curate_gpt/wrappers/__init__.py
@@ -40,6 +40,11 @@ def get_all_subclasses(cls):
 
 def get_wrapper(name: str, **kwargs) -> BaseWrapper:
     # NOTE: ORDER DEPENDENT. TODO: fix this
+    from curate_gpt.wrappers.bio.alliance_gene_wrapper import AllianceGeneWrapper  # noqa
+    from curate_gpt.wrappers.bio.bacdive_wrapper import BacDiveWrapper  # noqa
+    from curate_gpt.wrappers.bio.gocam_wrapper import GOCAMWrapper  # noqa
+    from curate_gpt.wrappers.bio.mediadive_wrapper import MediaDiveWrapper  # noqa
+    from curate_gpt.wrappers.bio.reactome_wrapper import ReactomeWrapper  # noqa
     from curate_gpt.wrappers.clinical.clinvar_wrapper import ClinVarWrapper  # noqa
     from curate_gpt.wrappers.clinical.hpoa_by_pub_wrapper import HPOAByPubWrapper  # noqa
     from curate_gpt.wrappers.clinical.hpoa_wrapper import HPOAWrapper  # noqa
@@ -50,6 +55,7 @@ def get_wrapper(name: str, **kwargs) -> BaseWrapper:
     from curate_gpt.wrappers.general.gspread_wrapper import GSpreadWrapper  # noqa
     from curate_gpt.wrappers.general.json_wrapper import JSONWrapper  # noqa
     from curate_gpt.wrappers.general.linkml_schema_wrapper import LinkMLSchemarapper  # noqa
+    from curate_gpt.wrappers.investigation.ess_deepdive_wrapper import ESSDeepDiveWrapper  # noqa
     from curate_gpt.wrappers.investigation.ncbi_bioproject_wrapper import (  # noqa
         NCBIBioprojectWrapper,
     )
@@ -57,20 +63,14 @@ def get_wrapper(name: str, **kwargs) -> BaseWrapper:
         NCBIBiosampleWrapper,
     )
     from curate_gpt.wrappers.investigation.nmdc_wrapper import NMDCWrapper  # noqa
-    from curate_gpt.wrappers.investigation.ess_deepdive_wrapper import ESSDeepDiveWrapper  # noqa
+    from curate_gpt.wrappers.legal.reusabledata_wrapper import ReusableDataWrapper  # noqa
     from curate_gpt.wrappers.literature.bioc_wrapper import BiocWrapper  # noqa
     from curate_gpt.wrappers.literature.pmc_wrapper import PMCWrapper  # noqa
     from curate_gpt.wrappers.literature.pubmed_wrapper import PubmedWrapper  # noqa
     from curate_gpt.wrappers.literature.wikipedia_wrapper import WikipediaWrapper  # noqa
     from curate_gpt.wrappers.ontology.bioportal_wrapper import BioportalWrapper  # noqa
-    from curate_gpt.wrappers.ontology.ontology_wrapper import OntologyWrapper  # noqa
     from curate_gpt.wrappers.ontology.oboformat_wrapper import OBOFormatWrapper  # noqa
-    from curate_gpt.wrappers.bio.gocam_wrapper import GOCAMWrapper  # noqa
-    from curate_gpt.wrappers.bio.alliance_gene_wrapper import AllianceGeneWrapper  # noqa
-    from curate_gpt.wrappers.bio.mediadive_wrapper import MediaDiveWrapper  # noqa
-    from curate_gpt.wrappers.bio.bacdive_wrapper import BacDiveWrapper  # noqa
-    from curate_gpt.wrappers.bio.reactome_wrapper import ReactomeWrapper  # noqa
-    from curate_gpt.wrappers.legal.reusabledata_wrapper import ReusableDataWrapper  # noqa
+    from curate_gpt.wrappers.ontology.ontology_wrapper import OntologyWrapper  # noqa
 
     for c in get_all_subclasses(BaseWrapper):
         if c.name == name:
diff --git a/src/curate_gpt/wrappers/base_wrapper.py b/src/curate_gpt/wrappers/base_wrapper.py
index 7a4b9f8..c774b05 100644
--- a/src/curate_gpt/wrappers/base_wrapper.py
+++ b/src/curate_gpt/wrappers/base_wrapper.py
@@ -171,7 +171,7 @@ def split_objects(self, objects: List[Dict], text_field="text", id_field="id") -
                 new_obj[id_field] = f"{obj_id}#{n}"
                 new_obj[text_field] = text[: self.max_text_length + self.text_overlap]
                 new_objects.append(new_obj)
-                text = text[self.max_text_length:]
+                text = text[self.max_text_length :]
         else:
             new_objects.append(obj)
     return new_objects
diff --git a/src/curate_gpt/wrappers/bio/alliance_gene_wrapper.py b/src/curate_gpt/wrappers/bio/alliance_gene_wrapper.py
index 059ac1f..a2252ea 100644
--- a/src/curate_gpt/wrappers/bio/alliance_gene_wrapper.py
+++ b/src/curate_gpt/wrappers/bio/alliance_gene_wrapper.py
@@ -1,16 +1,15 @@
 """Chat with a KB."""
 import gzip
-import os
-import requests
 import logging
+import os
 from dataclasses import dataclass, field
+from glob import glob
 from typing import ClassVar, Dict, Iterable, Iterator, Optional
 
+import requests
 import requests_cache
-from bs4 import BeautifulSoup
-from glob import glob
-
 import yaml
+from bs4 import BeautifulSoup
 from oaklib import BasicOntologyInterface, get_adapter
 
 from curate_gpt.wrappers import BaseWrapper
@@ -19,6 +18,7 @@
 
 BASE_URL = "https://www.alliancegenome.org/api"
 
+
 @dataclass
 class AllianceGeneWrapper(BaseWrapper):
     """
@@ -35,8 +35,7 @@ class AllianceGeneWrapper(BaseWrapper):
 
     taxon_id: str = field(default="NCBITaxon:9606")
 
-
-    def object_ids(self, taxon_id: str=None, **kwargs) -> Iterator[str]:
+    def object_ids(self, taxon_id: str = None, **kwargs) -> Iterator[str]:
         """
         Get all gene ids for a given taxon id
 
@@ -48,10 +47,13 @@ def object_ids(self, taxon_id: str=None, **kwargs) -> Iterator[str]:
         if not taxon_id:
             taxon_id = self.taxon_id
 
-        response = session.get(f"{BASE_URL}/geneMap/geneIDs", params={
-            "taxonID": taxon_id,
-            "rows": 50000,
-        })
+        response = session.get(
+            f"{BASE_URL}/geneMap/geneIDs",
+            params={
+                "taxonID": taxon_id,
+                "rows": 50000,
+            },
+        )
         response.raise_for_status()
         gene_ids = response.text.split(",")
         yield from gene_ids
@@ -82,4 +84,3 @@ def objects(
         response.raise_for_status()
         obj = response.json()
         yield obj
-
diff --git a/src/curate_gpt/wrappers/bio/bacdive_wrapper.py b/src/curate_gpt/wrappers/bio/bacdive_wrapper.py
index b33fb8d..5d87c45 100644
--- a/src/curate_gpt/wrappers/bio/bacdive_wrapper.py
+++ b/src/curate_gpt/wrappers/bio/bacdive_wrapper.py
@@ -40,7 +40,7 @@ def wrap_object(self, obj: Dict) -> Iterator[Dict]:
         general = obj["General"]
         name_info = obj["Name and taxonomic classification"]
         new_obj = {}
-        new_obj["id"] = self.create_curie(general['BacDive-ID'])
+        new_obj["id"] = self.create_curie(general["BacDive-ID"])
         new_obj["name"] = name_info.get("full scientific name", None)
         if not new_obj["name"]:
             new_obj["name"] = name_info["LPSN"].get("scientific name", None)
@@ -55,4 +55,3 @@ def wrap_object(self, obj: Dict) -> Iterator[Dict]:
                 break
         new_obj = {**new_obj, **obj}
         yield new_obj
-
diff --git a/src/curate_gpt/wrappers/bio/mediadive_wrapper.py b/src/curate_gpt/wrappers/bio/mediadive_wrapper.py
index 2fbfe39..0a810c9 100644
--- a/src/curate_gpt/wrappers/bio/mediadive_wrapper.py
+++ b/src/curate_gpt/wrappers/bio/mediadive_wrapper.py
@@ -3,7 +3,6 @@
 from typing import ClassVar, Dict, Iterable, Iterator, Optional
 
 import requests_cache
-
 from oaklib import BasicOntologyInterface, get_adapter
 
 from curate_gpt.wrappers import BaseWrapper
@@ -12,6 +11,7 @@
 
 BASE_URL = "https://mediadive.dsmz.de/rest"
 
+
 @dataclass
 class MediaDiveWrapper(BaseWrapper):
     """
@@ -69,4 +69,3 @@ def objects(
             else:
                 logger.warning(f"No solutions for {object_id}")
             yield obj
-
diff --git a/src/curate_gpt/wrappers/bio/reactome_wrapper.py b/src/curate_gpt/wrappers/bio/reactome_wrapper.py
index 7c07a58..ffd64af 100644
--- a/src/curate_gpt/wrappers/bio/reactome_wrapper.py
+++ b/src/curate_gpt/wrappers/bio/reactome_wrapper.py
@@ -1,16 +1,15 @@
 """Chat with a KB."""
 import gzip
-import os
-import requests
 import logging
+import os
 from dataclasses import dataclass, field
-from typing import ClassVar, Dict, Iterable, Iterator, Optional, List
-
-import requests_cache
-from bs4 import BeautifulSoup
 from glob import glob
+from typing import ClassVar, Dict, Iterable, Iterator, List, Optional
 
+import requests
+import requests_cache
 import yaml
+from bs4 import BeautifulSoup
 from oaklib import BasicOntologyInterface, get_adapter
 
 from curate_gpt.wrappers import BaseWrapper
@@ -19,6 +18,7 @@
 
 BASE_URL = "https://reactome.org/ContentService/data"
 
+
 def ids_from_tree(objs: List):
     """
     Recursively yield all ids from a tree of objects
@@ -91,7 +91,7 @@ class ReactomeWrapper(BaseWrapper):
 
     taxon_id: str = field(default="NCBITaxon:9606")
 
-    def object_ids(self, taxon_id: str=None, **kwargs) -> Iterator[str]:
+    def object_ids(self, taxon_id: str = None, **kwargs) -> Iterator[str]:
         """
         Get all object ids for a given taxon id
 
@@ -137,12 +137,13 @@ def objects(
             response = session.get(f"{BASE_URL}/query/{local_id}")
             obj = response.json()
             summations = obj["summation"]
-            new_obj = {"id": object_id,
-                       "label": obj["displayName"],
-                       "speciesName": obj["speciesName"],
-                       "description": "\n".join([x["text"] for x in summations]),
-                       "type": obj["schemaClass"],
-                       }
+            new_obj = {
+                "id": object_id,
+                "label": obj["displayName"],
+                "speciesName": obj["speciesName"],
+                "description": "\n".join([x["text"] for x in summations]),
+                "type": obj["schemaClass"],
+            }
             for key, func in OBJECT_FUNCTION_MAP.items():
                 if key in obj:
                     new_obj[key] = [func(x) for x in obj[key] if isinstance(x, dict)]
diff --git a/src/curate_gpt/wrappers/investigation/ess_deepdive_wrapper.py b/src/curate_gpt/wrappers/investigation/ess_deepdive_wrapper.py
index 583f957..3ff1582 100644
--- a/src/curate_gpt/wrappers/investigation/ess_deepdive_wrapper.py
+++ b/src/curate_gpt/wrappers/investigation/ess_deepdive_wrapper.py
@@ -27,7 +27,9 @@ def _get_records_chunk(session: requests_cache.CachedSession, cursor=1, limit=20
         raise ValueError(f"Could not download records from {url}")
 
 
-def get_records(session: requests_cache.CachedSession, cursor=1, limit=200, maximum: int = None) -> Iterator[dict]:
+def get_records(
+    session: requests_cache.CachedSession, cursor=1, limit=200, maximum: int = None
+) -> Iterator[dict]:
     """
     Iterate through all records in ESSDeepDive and download them.
 
@@ -61,7 +63,9 @@ class ESSDeepDiveWrapper(BaseWrapper):
 
     default_object_type = "Class"
 
-    session: requests_cache.CachedSession = field(default_factory=lambda: requests_cache.CachedSession("ess_deepdive"))
+    session: requests_cache.CachedSession = field(
+        default_factory=lambda: requests_cache.CachedSession("ess_deepdive")
+    )
 
     limit: int = field(default=50)
 
diff --git a/src/curate_gpt/wrappers/legal/reusabledata_wrapper.py b/src/curate_gpt/wrappers/legal/reusabledata_wrapper.py
index 970ab0a..9e0e930 100644
--- a/src/curate_gpt/wrappers/legal/reusabledata_wrapper.py
+++ b/src/curate_gpt/wrappers/legal/reusabledata_wrapper.py
@@ -1,16 +1,15 @@
 """Chat with a KB."""
 import gzip
-import os
-import requests
 import logging
+import os
 from dataclasses import dataclass, field
+from glob import glob
 from typing import ClassVar, Dict, Iterable, Iterator, Optional
 
+import requests
 import requests_cache
-from bs4 import BeautifulSoup
-from glob import glob
-
 import yaml
+from bs4 import BeautifulSoup
 from oaklib import BasicOntologyInterface, get_adapter
 
 from curate_gpt.wrappers import BaseWrapper
@@ -52,7 +51,15 @@ def objects(
             obj_id = obj["id"]
             license_link = obj.get("license-link", None)
             if license_link:
-                if license_link in ["TODO", "https://", "inconsistent", "https://civic.genome.wustl.edu/about", "http://www.supfam.org/about", "ftp://ftp.nextprot.org/pub/README", "ftp://ftp.jcvi.org/pub/data/TIGRFAMs/COPYRIGHT"]:
+                if license_link in [
+                    "TODO",
+                    "https://",
+                    "inconsistent",
+                    "https://civic.genome.wustl.edu/about",
+                    "http://www.supfam.org/about",
+                    "ftp://ftp.nextprot.org/pub/README",
+                    "ftp://ftp.jcvi.org/pub/data/TIGRFAMs/COPYRIGHT",
+                ]:
                     logger.warning(f"base link {license_link} for {obj_id}")
                     continue
                 if license_link.startswith("ftp://"):
@@ -63,8 +70,7 @@ def objects(
                     logger.warning(f"bad link {license_link} for {obj_id}")
                     continue
                 data = response.text
-                soup = BeautifulSoup(data, 'html.parser')
+                soup = BeautifulSoup(data, "html.parser")
                 license_text = soup.get_text()
                 obj["license_text"] = license_text
             yield obj
-
diff --git a/src/curate_gpt/wrappers/literature/pubmed_wrapper.py b/src/curate_gpt/wrappers/literature/pubmed_wrapper.py
index a0458fe..5f5b053 100644
--- a/src/curate_gpt/wrappers/literature/pubmed_wrapper.py
+++ b/src/curate_gpt/wrappers/literature/pubmed_wrapper.py
@@ -222,7 +222,7 @@ def fetch_full_text(self, object_id: str) -> Optional[str]:
         parsed_url = urlparse(download_url)
         if parsed_url.scheme not in ["http", "https", "ftp"]:
             continue
-        urlretrieve(download_url, local_file_path) # noqa S310
+        urlretrieve(download_url, local_file_path)  # noqa S310
 
         # Open and extract the tar.gz file
         with tarfile.open(local_file_path, "r:gz") as tar:
diff --git a/src/curate_gpt/wrappers/ontology/ontology_wrapper.py b/src/curate_gpt/wrappers/ontology/ontology_wrapper.py
index b5794aa..050a88a 100644
--- a/src/curate_gpt/wrappers/ontology/ontology_wrapper.py
+++ b/src/curate_gpt/wrappers/ontology/ontology_wrapper.py
@@ -184,8 +184,9 @@ def retrieve_shorthand_to_id_from_store(self, store: DBAdapter) -> Mapping[str,
     def unwrap_object(self, obj: Dict[str, Any], store: DBAdapter, **kwargs) -> og.Graph:
         return self.unwrap_objects([obj], store, **kwargs)
 
-
-    def unwrap_objects(self, objs: Iterable[Dict[str, Any]], store: DBAdapter, drop_dangling=False, **kwargs) -> og.GraphDocument:
+    def unwrap_objects(
+        self, objs: Iterable[Dict[str, Any]], store: DBAdapter, drop_dangling=False, **kwargs
+    ) -> og.GraphDocument:
         """
         Convert an object from the store to the view representation.
 
diff --git a/tests/agents/test_concept_recognizer.py b/tests/agents/test_concept_recognizer.py
index a97eb11..fc7d9c2 100644
--- a/tests/agents/test_concept_recognizer.py
+++ b/tests/agents/test_concept_recognizer.py
@@ -1,7 +1,7 @@
 import pytest
 import yaml
 
-from curate_gpt.agents.concept_recognition_agent import ConceptRecognitionAgent, AnnotationMethod
+from curate_gpt.agents.concept_recognition_agent import AnnotationMethod, ConceptRecognitionAgent
 from curate_gpt.extract.basic_extractor import BasicExtractor
 
 
@@ -9,33 +9,36 @@
     "text,categories,prefixes,expected",
     [
         (
-                "A metabolic process that results in the breakdown of chemicals in vacuolar structures.",
-                ["BiologicalProcess", "SubcellularStructure"],
-                ["GO"],
-                ["GO:0044237", "GO:0005773"],
+            "A metabolic process that results in the breakdown of chemicals in vacuolar structures.",
+            ["BiologicalProcess", "SubcellularStructure"],
+            ["GO"],
+            ["GO:0044237", "GO:0005773"],
         ),
         (
-                "A metabolic process that results in the breakdown of chemicals in vacuolar structures.",
-                ["BiologicalProcess", "SubcellularStructure"],
-                ["FAKE"],
-                [],
+            "A metabolic process that results in the breakdown of chemicals in vacuolar structures.",
+            ["BiologicalProcess", "SubcellularStructure"],
+            ["FAKE"],
+            [],
        ),
         (
-                "Protoplasm",
-                None,
-                None,
-                ["GO:0005622"],
+            "Protoplasm",
+            None,
+            None,
+            ["GO:0005622"],
         ),
-(
-                "The photosynthetic membrane of plants and algae",
-                ["BiologicalProcess", "SubcellularStructure", "OrganismTaxon"],
-                None,
-                ["GO:0005622"],
+        (
+            "The photosynthetic membrane of plants and algae",
+            ["BiologicalProcess", "SubcellularStructure", "OrganismTaxon"],
+            None,
+            ["GO:0005622"],
         ),
     ],
 )
-@pytest.mark.parametrize("method", [AnnotationMethod.CONCEPT_LIST, AnnotationMethod.CONCEPT_LIST, AnnotationMethod.TWO_PASS])
-def test_concept_recognizer(go_test_chroma_db, text, categories,prefixes, expected, method):
+@pytest.mark.parametrize(
+    "method",
+    [AnnotationMethod.CONCEPT_LIST, AnnotationMethod.CONCEPT_LIST, AnnotationMethod.TWO_PASS],
+)
+def test_concept_recognizer(go_test_chroma_db, text, categories, prefixes, expected, method):
     limit = 50
     if method == AnnotationMethod.TWO_PASS:
         limit = 10
@@ -44,7 +47,9 @@ def test_concept_recognizer(go_test_chroma_db, text, categories,prefixes, expect
     cr.prefixes = prefixes
     cr.identifier_field = "original_id"
     print(f"## METHOD: {method} CATEGORY: {categories} PREFIXES: {prefixes}")
-    ann = cr.annotate(text, collection="terms_go", method=method, categories=categories, limit=limit)
+    ann = cr.annotate(
+        text, collection="terms_go", method=method, categories=categories, limit=limit
+    )
     print("RESULT:")
     print(yaml.dump(ann.dict(), sort_keys=False))
     overlap = len(set(ann.concepts).intersection(set(expected)))
diff --git a/tests/agents/test_dase.py b/tests/agents/test_dase.py
index 8f1da80..dd76af7 100644
--- a/tests/agents/test_dase.py
+++ b/tests/agents/test_dase.py
@@ -1,8 +1,8 @@
 import pytest
 import yaml
 
-from curate_gpt.agents.dragon_agent import DragonAgent
 from curate_gpt.agents.dase_agent import DatabaseAugmentedStructuredExtraction
+from curate_gpt.agents.dragon_agent import DragonAgent
 from curate_gpt.extract.basic_extractor import BasicExtractor