Skip to content

Commit

Permalink
Format code using tox
Browse files Browse the repository at this point in the history
  • Loading branch information
hrshdhgd committed Dec 4, 2023
1 parent ca6454b commit 86d44bb
Show file tree
Hide file tree
Showing 23 changed files with 238 additions and 140 deletions.
99 changes: 69 additions & 30 deletions src/curate_gpt/agents/concept_recognition_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

class Span(BaseModel):
"""An individual span of text containing a single concept."""

text: str

start: Optional[int] = None
Expand All @@ -33,6 +34,7 @@ class Span(BaseModel):
is_suspect: bool = False
"""Potential hallucination due to ID/label mismatch."""


class GroundingResult(BaseModel):
"""Result of grounding text."""

Expand All @@ -48,6 +50,7 @@ class GroundingResult(BaseModel):

class AnnotationMethod(str, Enum):
"""Strategy or algorithm used for CR."""

INLINE = "inline"
"""LLM creates an annotated document"""

Expand Down Expand Up @@ -153,7 +156,7 @@ def parse_annotations(text, marker_char: str = None) -> List[CONCEPT]:
:return:
"""
# First Pass: Extract text within [ ... ]
pattern1 = r'\[([^\]]+)\]'
pattern1 = r"\[([^\]]+)\]"
matches = re.findall(pattern1, text)

# Second Pass: Parse the last token of each match
Expand All @@ -163,15 +166,15 @@ def parse_annotations(text, marker_char: str = None) -> List[CONCEPT]:
if marker_char:
toks = match.split(marker_char)
if len(toks) > 1:
annotation = ' '.join(toks[:-1]).strip()
annotation = " ".join(toks[:-1]).strip()
id = toks[-1].strip()
else:
annotation = match
id = None
else:
words = match.split()
if len(words) > 1 and ":" in words[-1]:
annotation = ' '.join(words[:-1])
annotation = " ".join(words[:-1])
id = words[-1]
else:
annotation = match
Expand All @@ -198,14 +201,19 @@ def parse_spans(text: str, concept_dict: Dict[str, str] = None) -> List[Span]:
concept_label = row[1].strip('"')
mention_text = ",".join(row[2:])
verified_concept_label = concept_dict.get(concept_id, None)
spans.append(Span(text=mention_text, concept_id=concept_id, concept_label=verified_concept_label,
is_suspect=verified_concept_label != concept_label))
spans.append(
Span(
text=mention_text,
concept_id=concept_id,
concept_label=verified_concept_label,
is_suspect=verified_concept_label != concept_label,
)
)
return spans


@dataclass
class ConceptRecognitionAgent(BaseAgent):

identifier_field: str = None
"""Field to use as identifier for objects."""

Expand All @@ -225,16 +233,17 @@ def ground_concept(
text: str,
collection: str = None,
categories: Optional[List[str]] = None,
include_category_in_search = True,
include_category_in_search=True,
context: str = None,
**kwargs,
) -> GroundingResult:

system_prompt = GROUND_PROMPT
query = text
if include_category_in_search and categories:
query += " Categories: " + ", ".join(categories)
concept_pairs, concept_prompt = self._label_id_pairs_prompt_section(query, collection, **kwargs)
concept_pairs, concept_prompt = self._label_id_pairs_prompt_section(
query, collection, **kwargs
)
concept_dict = {c[0]: c[1] for c in concept_pairs}
system_prompt += concept_prompt
model = self.extractor.model
Expand Down Expand Up @@ -262,9 +271,14 @@ def ground_concept(
concept_label = concept_dict[concept_id]
else:
concept_label = None
span = Span(text=text, concept_id=concept_id, concept_label=concept_label, is_suspect=provided_concept_label != concept_label)
span = Span(
text=text,
concept_id=concept_id,
concept_label=concept_label,
is_suspect=provided_concept_label != concept_label,
)
spans.append(span)
#spans = parse_spans(response.text(), concept_dict)
# spans = parse_spans(response.text(), concept_dict)
ann = GroundingResult(input_text=text, annotated_text=response.text(), spans=spans)
return ann

Expand Down Expand Up @@ -296,12 +310,16 @@ def annotate_two_pass(
system_prompt = "Your job is to parse the supplied text, identifying instances of concepts "
if len(categories) == 1:
system_prompt += f" that represent some kind of {categories[0]}. "
system_prompt += ("Mark up the concepts in square brackets, "
"preserving the original text inside the brackets. ")
system_prompt += (
"Mark up the concepts in square brackets, "
"preserving the original text inside the brackets. "
)
else:
system_prompt += " that represent one of the following categories: "
system_prompt += ", ".join(categories)
system_prompt += "Mark up the concepts in square brackets, with the category after the pipe symbol, "
system_prompt += (
"Mark up the concepts in square brackets, with the category after the pipe symbol, "
)
system_prompt += "Using the syntax [ORIGINAL TEXT | CATEGORY]."
logger.debug(f"Prompting with: {text}")
model = self.extractor.model
Expand All @@ -310,12 +328,24 @@ def annotate_two_pass(
anns = parse_annotations(marked_up_text, "|")
spans = []
for term, category in anns:
concept = self.ground_concept(term, collection, categories=[category] if category else None, context=text, **kwargs)
concept = self.ground_concept(
term,
collection,
categories=[category] if category else None,
context=text,
**kwargs,
)
if not concept.spans:
logger.debug(f"Unable to ground concept {term} in category {category}")
continue
main_span = concept.spans[0]
spans.append(Span(text=term, concept_id=main_span.concept_id, concept_label=main_span.concept_label))
spans.append(
Span(
text=term,
concept_id=main_span.concept_id,
concept_label=main_span.concept_label,
)
)
return AnnotatedText(
input_text=text,
annotated_text=marked_up_text,
Expand All @@ -329,23 +359,22 @@ def annotate_inline(
categories: List[str] = None,
**kwargs,
) -> AnnotatedText:

system_prompt = ANNOTATE_PROMPT
concept_pairs, concepts_prompt = self._label_id_pairs_prompt_section(text, collection, **kwargs)
concept_pairs, concepts_prompt = self._label_id_pairs_prompt_section(
text, collection, **kwargs
)
concept_dict = {c[0]: c[1] for c in concept_pairs}
system_prompt += concepts_prompt
model = self.extractor.model
logger.debug(f"Prompting with: {text}")
response = model.prompt(text, system=system_prompt)
anns = parse_annotations(response.text())
logger.info(f"Anns: {anns}")
spans = [Span(text=ann[0], concept_id=ann[1], concept_label=concept_dict.get(ann[1], None)) for ann in anns]
return AnnotatedText(
input_text=text,
spans=spans,
annotated_text=response.text()
)

spans = [
Span(text=ann[0], concept_id=ann[1], concept_label=concept_dict.get(ann[1], None))
for ann in anns
]
return AnnotatedText(input_text=text, spans=spans, annotated_text=response.text())

def annotate_concept_list(
self,
Expand All @@ -354,18 +383,28 @@ def annotate_concept_list(
categories: List[str] = None,
**kwargs,
) -> AnnotatedText:

system_prompt = MENTION_PROMPT
concept_pairs, concepts_prompt = self._label_id_pairs_prompt_section(text, collection, **kwargs)
concept_pairs, concepts_prompt = self._label_id_pairs_prompt_section(
text, collection, **kwargs
)
concept_dict = {c[0]: c[1] for c in concept_pairs}
system_prompt += concepts_prompt
model = self.extractor.model
logger.debug(f"Prompting with: {text}")
response = model.prompt(text, system=system_prompt)
spans = parse_spans(response.text(), concept_dict)
return AnnotatedText(input_text=text, summary=response.text(), spans=spans, prompt=system_prompt)
return AnnotatedText(
input_text=text, summary=response.text(), spans=spans, prompt=system_prompt
)

def _label_id_pairs_prompt_section(self, text:str, collection: str, prolog: str = None, relevance_factor: float = None, **kwargs) -> Tuple[List[CONCEPT], str]:
def _label_id_pairs_prompt_section(
self,
text: str,
collection: str,
prolog: str = None,
relevance_factor: float = None,
**kwargs,
) -> Tuple[List[CONCEPT], str]:
prompt = prolog
if not prompt:
prompt = "Here are the candidate concepts, as label // ConceptID pairs:\n"
Expand Down Expand Up @@ -395,4 +434,4 @@ def _label_id_pairs_prompt_section(self, text:str, collection: str, prolog: str
raise ValueError(f"Object {obj} has no label field {label_field}")
prompt += f"{label} // {id} \n"
concept_pairs.append((id, label))
return concept_pairs, prompt
return concept_pairs, prompt
6 changes: 4 additions & 2 deletions src/curate_gpt/agents/mapping_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass
from enum import Enum
from random import shuffle
from typing import Any, Dict, Iterator, List, Optional, Union, Tuple
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

import inflection
import yaml
Expand Down Expand Up @@ -191,6 +191,8 @@ def find_links(self, other_collection: str) -> Iterator[Tuple[str, str, str]]:
:return:
"""
# TODO
for obj, _, info in self.knowledge_source.find(collection=other_collection, include = ["embeddings", "documents", "metadatas"]):
for obj, _, info in self.knowledge_source.find(
collection=other_collection, include=["embeddings", "documents", "metadatas"]
):
embeddings = info["embeddings"]
self.knowledge_source.find(embeddings, limit=10)
20 changes: 16 additions & 4 deletions src/curate_gpt/agents/summarization_agent.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import logging
from dataclasses import dataclass
from typing import Dict, Union, List
from typing import Dict, List, Union

from curate_gpt.agents.base_agent import BaseAgent
from curate_gpt.wrappers import BaseWrapper

logger = logging.getLogger(__name__)


@dataclass
class SummarizationAgent(BaseAgent):
"""
Expand All @@ -15,7 +16,14 @@ class SummarizationAgent(BaseAgent):
AKA SPINDOCTOR/TALISMAN
"""

def summarize(self, object_ids: List[str], description_field: str, name_field: str, strict: bool = False, system_prompt: str = None):
def summarize(
self,
object_ids: List[str],
description_field: str,
name_field: str,
strict: bool = False,
system_prompt: str = None,
):
"""
Summarize a list of objects.
Expand Down Expand Up @@ -48,9 +56,13 @@ def summarize(self, object_ids: List[str], description_field: str, name_field: s
if isinstance(knowledge_source, BaseWrapper):
object_iter = knowledge_source.objects(self.knowledge_source_collection, object_ids)
else:
object_iter = knowledge_source.lookup_multiple(object_ids, collection=self.knowledge_source_collection)
object_iter = knowledge_source.lookup_multiple(
object_ids, collection=self.knowledge_source_collection
)
objects = list(object_iter)
descriptions = [(obj.get(name_field, ""), obj.get(description_field, None)) for obj in objects]
descriptions = [
(obj.get(name_field, ""), obj.get(description_field, None)) for obj in objects
]
if any(desc[0] is None for desc in descriptions):
raise ValueError(f"Missing name for objects: {objects}")
if strict:
Expand Down
2 changes: 1 addition & 1 deletion src/curate_gpt/app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from curate_gpt import BasicExtractor
from curate_gpt.agents import MappingAgent
from curate_gpt.agents.chat_agent import ChatAgent, ChatResponse
from curate_gpt.agents.dragon_agent import DragonAgent
from curate_gpt.agents.dase_agent import DatabaseAugmentedStructuredExtraction
from curate_gpt.agents.dragon_agent import DragonAgent
from curate_gpt.agents.evidence_agent import EvidenceAgent
from curate_gpt.app.components import (
DimensionalityReductionOptions,
Expand Down
Loading

0 comments on commit 86d44bb

Please sign in to comment.