
Commit

Merge pull request #32 from arena-ai/pydantic_for_prompt
Add template for entity extraction in DocumentDataExtractor
ngrislain authored Oct 15, 2024
2 parents ba45256 + c83ba1e commit 8afbc1d
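
In short, this PR lets a DocumentDataExtractor carry a response_template: a mapping from entity names to a (type, required/optional) pair. The backend turns that mapping into a Pydantic model and into a strict JSON schema passed as the chat-completion response_format. A rough illustration of the expected shape (the entity names below are invented for the example):

    # Hypothetical response_template sent when creating a DocumentDataExtractor.
    # Keys are entity names; values are (type, "required" | "optional") tuples.
    response_template = {
        "invoice_number": ("str", "required"),
        "total_amount": ("float", "optional"),
    }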
Showing 10 changed files with 148 additions and 27 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -1,6 +1,9 @@
.DS_Store
.env
.env.*
.vscode/
*.lock
*.tgz
generated/
generated/

local_*.md
@@ -0,0 +1,29 @@
"""Add column to dde to account for json_schema
Revision ID: 5b09eca9fc4d
Revises: 964e3df77cf0
Create Date: 2024-10-11 10:37:47.325026
"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes


# revision identifiers, used by Alembic.
revision = '5b09eca9fc4d'
down_revision = '964e3df77cf0'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('documentdataextractor', sa.Column('response_template', sqlmodel.sql.sqltypes.AutoString(), nullable=False))
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('documentdataextractor', 'response_template')
# ### end Alembic commands ###
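
As a usage note, assuming the project's standard Alembic setup, the new revision would be applied with the usual command:

    alembic upgrade head

Since response_template is added as a non-nullable column without a server default, existing documentdataextractor rows would need to be handled (or the table to be empty) for the upgrade to succeed.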
75 changes: 65 additions & 10 deletions backend/app/api/routes/dde.py
@@ -1,17 +1,14 @@
from typing import Any, Iterable
from typing import Any, Iterable,Literal
from app.lm.models.chat_completion import TokenLogprob
from app.lm.models import ChatCompletionResponse
from fastapi import APIRouter, HTTPException, status, UploadFile
from fastapi.responses import JSONResponse
from sqlmodel import func, select
from sqlalchemy.exc import IntegrityError
import pyarrow as pa
import pyarrow.parquet as pq
import pyarrow.csv as pc
import json
import math
import io

from pydantic import create_model,ValidationError
from app.api.deps import CurrentUser, SessionDep
from app.services import crud
from app.lm.models import ChatCompletionResponse, ChatCompletionRequest, Message as ChatCompletionMessage
@@ -22,7 +19,7 @@
from app.ops.documents import as_text
from app.models import (Message, DocumentDataExtractorCreate, DocumentDataExtractorUpdate, DocumentDataExtractor, DocumentDataExtractorOut, DocumentDataExtractorsOut,
DocumentDataExampleCreate, DocumentDataExampleUpdate, DocumentDataExample, DocumentDataExampleOut)

from openai.lib._pydantic import to_strict_json_schema
router = APIRouter()


@@ -77,7 +74,11 @@ def create_document_data_extractor(
"""
Create a new DocumentDataExtractor.
"""
document_data_extractor = DocumentDataExtractor.model_validate(document_data_extractor_in, update={"owner_id": current_user.id})
try:
create_pydantic_model(document_data_extractor_in.response_template)
except KeyError:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="received incorrect response template")
document_data_extractor = DocumentDataExtractor.model_validate(document_data_extractor_in, update={"owner_id": current_user.id,"response_template":json.dumps(document_data_extractor_in.response_template)})
try:
session.add(document_data_extractor)
session.commit()
@@ -101,6 +102,13 @@ def update_document_data_extractor(
if not current_user.is_superuser and (document_data_extractor.owner_id != current_user.id):
raise HTTPException(status_code=400, detail="Not enough permissions")
update_dict = document_data_extractor_in.model_dump(exclude_unset=True)
pydantic_dict=update_dict.pop('response_template', None)
if pydantic_dict is not None:
try:
create_pydantic_model(pydantic_dict)
except KeyError:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="received incorrect response template")
update_dict['response_template']=json.dumps(pydantic_dict)
document_data_extractor.sqlmodel_update(update_dict)
session.add(document_data_extractor)
session.commit()
@@ -152,7 +160,13 @@ def create_document_data_example(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="DocumentDataExtractor not found")
if not current_user.is_superuser and (document_data_extractor.owner_id != current_user.id):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not enough permissions")
document_data_example = DocumentDataExample.model_validate(document_data_example_in, update={"document_data_extractor_id": document_data_extractor.id})
#verify the example matches the template of the document data extractor
pyd_model=create_pydantic_model(json.loads(document_data_extractor.response_template))
try:
pyd_model.model_validate(document_data_example_in.data)
except ValidationError:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Example data does not match the DocumentDataExtractor response template")
document_data_example = DocumentDataExample.model_validate(document_data_example_in, update={"document_data_extractor_id": document_data_extractor.id,'data':json.dumps(document_data_example_in.data)})
session.add(document_data_example)
session.commit()
session.refresh(document_data_example)
@@ -179,6 +193,15 @@ def update_document_data_example(
if document_data_example.document_data_extractor_id != document_data_extractor.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="DocumentDataExample not found in this DocumentDataExtractor")
update_dict = document_data_example_in.model_dump(exclude_unset=True)
data=update_dict.pop('data')
if data is not None:
pyd_model=create_pydantic_model(json.loads(document_data_extractor.response_template))
try:
pyd_model.model_validate(document_data_example_in.data)
except ValidationError:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Example data does not match the DocumentDataExtractor response template")
else:
update_dict['data']=json.dumps(data)
document_data_example.sqlmodel_update(update_dict)
session.add(document_data_example)
session.commit()
@@ -236,24 +259,57 @@ async def extract_from_file(*, session: SessionDep, current_user: CurrentUser, n
ChatCompletionMessage(role="system", content=full_system_content),
ChatCompletionMessage(role="user", content=f"Maintenant, faites la même extraction sur un nouveau document d'input:\n####\nINPUT:{prompt}")
]
pydantic_response=create_pydantic_model(json.loads(document_data_extractor.response_template))
format_response={"type": "json_schema",
"json_schema":{
"schema":to_strict_json_schema(pydantic_reponse),
"name":'response',
'strict':True}}

chat_completion_request = ChatCompletionRequest(
model='gpt-4o-2024-08-06',
messages=messages,
max_tokens=2000,
temperature=0.1,
logprobs=True,
top_logprobs= 5,
response_format=format_response

).model_dump(exclude_unset=True)

chat_completion_response = await ArenaHandler(session, document_data_extractor.owner, chat_completion_request).process_request()
extracted_info=chat_completion_response.choices[0].message.content
# TODO: handle refusals or cases where the content was not extracted correctly
# TODO: Improve the prompt to ensure the output is always a valid JSON
json_string = extracted_info[extracted_info.find('{'):extracted_info.rfind('}')+1]
extracted_data = {k: v for k, v in json.loads(json_string).items() if k not in ('source', 'year')}
logprob_data = extract_logprobs_from_response(chat_completion_response, extracted_data)
return {'extracted_info': json.loads(json_string), 'logprob_data': logprob_data}


def create_pydantic_model(schema:dict[str,tuple[Literal['str','int','bool','float'],Literal['required','optional']]])->Any:
"""Creates a pydantic model from an input dictionary where
keys are names of entities to be retrieved, each value is a tuple specifying
the type of the entity and whether it is required or optional"""
# Convert string type names to actual Python types
field_types = {
'str': (str, ...), # ... means the field is required
'int': (int, ...),
'float': (float, ...),
'bool': (bool, ...),
}
optional_field_types={
'str': (str|None, ...), # ... keeps the field required, but None is an accepted value
'int': (int|None, ...),
'float': (float|None, ...),
'bool': (bool|None, ...),}

# Dynamically create a Pydantic model using create_model
fields = {name: field_types[ftype[0]] if ftype[1]=='required' else optional_field_types[ftype[0]] for name, ftype in schema.items()}
dynamic_model = create_model('DataExtractorSchema', **fields)
return dynamic_model


def validate_extracted_text(text: str):
if text == "":
raise HTTPException(status_code=500, detail="The extracted text from the document is empty. Please check if the document is corrupted.")
@@ -335,5 +391,4 @@ def combine_tokens(tokens_info: list[TokenLogprob], start_index: int) -> tuple[s
combined_token += tokens_info[i].token
combined_logprob += tokens_info[i].logprob

return combined_token, combined_logprob

return combined_token, combined_logprob
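
For reference, here is a sketch of what the new create_pydantic_model helper produces; the field names are invented for the example. Note that with the implementation above an 'optional' field may be None but must still be present, since both branches use ... as the field default.

    ExampleModel = create_pydantic_model({
        "invoice_number": ("str", "required"),
        "total_amount": ("float", "optional"),
    })

    # Passes: the required field is present, the optional one is explicitly None.
    ExampleModel.model_validate({"invoice_number": "F-2024-001", "total_amount": None})

    # Raises pydantic.ValidationError: "invoice_number" is missing.
    ExampleModel.model_validate({"total_amount": 12.5})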
2 changes: 1 addition & 1 deletion backend/app/lm/models/__init__.py
@@ -2,7 +2,7 @@
from app.lm.models.chat_completion import (
LMApiKeys,
Function, FunctionDefinition,
ChatCompletionToolParam, Message, ResponseFormat, ChatCompletionRequest,
ChatCompletionToolParam, Message, ResponseFormatBase,ResponseFormat, ChatCompletionRequest,
TopLogprob, TokenLogprob, ChoiceLogprobs, Choice, CompletionUsage, ChatCompletionResponse,
)
from app.lm.models.evaluation import Evaluation, Score
2 changes: 1 addition & 1 deletion backend/app/lm/models/anthropic.py
@@ -3,7 +3,7 @@
from pydantic import BaseModel

from app.lm import models
from app.lm.models import Function, ChatCompletionToolParam, Message, ResponseFormat, TopLogprob, TokenLogprob, ChoiceLogprobs, Choice
from app.lm.models import Function, ChatCompletionToolParam, Message, TopLogprob, TokenLogprob, ChoiceLogprobs, Choice
"""
ChatCompletionCreate -> anthropic MessageCreateParams -> anthropic Message -> ChatCompletion
"""
35 changes: 31 additions & 4 deletions backend/app/lm/models/chat_completion.py
@@ -1,8 +1,7 @@
from typing import Literal, Mapping, Sequence, Any
from typing import Literal, Mapping, Sequence, Any,Required,TypeAlias,Dict, Optional
from pydantic import BaseModel
from app.lm.models.settings import LMConfig


from typing_extensions import TypedDict
"""All LanguageModels"""

class LMApiKeys(BaseModel):
@@ -48,9 +47,37 @@ class Message(BaseModel):
"""The tool calls generated by the model, such as function calls."""


class ResponseFormat(BaseModel):

class JSONSchema(TypedDict, total=False):
name: Required[str]
"""The name of the response format.
Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length
of 64.
"""

schema: Dict[str, object]
"""The schema for the response format, described as a JSON Schema object."""

strict: Optional[bool]
"""Whether to enable strict schema adherence when generating the output.
If set to true, the model will always follow the exact schema defined in the
`schema` field. Only a subset of JSON Schema is supported when `strict` is
`true`.
"""


class ResponseFormatJSONSchema(TypedDict, total=False):
json_schema: Required[JSONSchema]

type: Required[Literal["json_schema"]]
"""The type of response format being defined: `json_schema`"""

class ResponseFormatBase(BaseModel):
type: Literal["text", "json_object"] | None = None

ResponseFormat:TypeAlias = ResponseFormatJSONSchema|ResponseFormatBase

class ChatCompletionRequest(BaseModel):
"""
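
For reference, payloads matching the two branches of the new ResponseFormat alias might look like this (a sketch; the schema content is invented):

    # ResponseFormatJSONSchema branch: a TypedDict, passed as a plain dict.
    structured = {
        "type": "json_schema",
        "json_schema": {
            "name": "response",
            "schema": {"type": "object", "properties": {"total": {"type": "number"}}},
            "strict": True,
        },
    }

    # ResponseFormatBase branch: a pydantic model, as kept for the Mistral mapping below.
    plain = ResponseFormatBase(type="json_object")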
4 changes: 2 additions & 2 deletions backend/app/lm/models/mistral.py
@@ -3,7 +3,7 @@
from pydantic import BaseModel

from app.lm import models
from app.lm.models import Function, FunctionDefinition, ChatCompletionToolParam, Message, ResponseFormat, TopLogprob, TokenLogprob, ChoiceLogprobs, Choice, CompletionUsage
from app.lm.models import Function, FunctionDefinition, ChatCompletionToolParam, Message, ResponseFormatBase, TopLogprob, TokenLogprob, ChoiceLogprobs, Choice, CompletionUsage

"""
models.ChatCompletionCreate -> ChatCompletionCreate -> ChatCompletion -> models.ChatCompletion
@@ -21,7 +21,7 @@ class ChatCompletionRequest(BaseModel):
messages: Sequence[Message]
model: str | Literal[*MODELS]
max_tokens: int | None = None
response_format: ResponseFormat | None = None
response_format: ResponseFormatBase | None = None
safe_prompt: bool | None = None
random_seed: int | None = None
temperature: float | None = None
19 changes: 14 additions & 5 deletions backend/app/models.py
@@ -1,7 +1,8 @@
from typing import Optional, Literal
from typing import Optional, Literal,Any
from datetime import datetime
import re
from sqlmodel import Field, Relationship, UniqueConstraint, SQLModel, func, Column, Integer, ForeignKey
from pydantic import BaseModel
from pydantic import field_validator

# Shared properties
@@ -222,18 +223,23 @@ def name_validator(cls, name: str) -> str:
class DocumentDataExtractorCreate(DocumentDataExtractorBase):
name: str
prompt: str
response_template:dict[str,tuple[Literal['str','int','bool','float'],Literal['required','optional']]]


# Properties to receive on DocumentDataExtractor update
class DocumentDataExtractorUpdate(DocumentDataExtractorBase):
name: str | None = None
prompt: str | None = None
response_template: dict[str,tuple[Literal['str','int','bool','float'],Literal['required','optional']]] | None = None



class DocumentDataExtractor(DocumentDataExtractorBase, table=True):
id: int | None = Field(default=None, primary_key=True)
timestamp: datetime | None = Field(default=func.now())
owner_id: int | None = Field(sa_column=Column(Integer, ForeignKey("user.id", ondelete="CASCADE"), default=None))
owner: User | None = Relationship(back_populates="document_data_extractors")
response_template: str
document_data_examples: list["DocumentDataExample"] = Relationship(back_populates="document_data_extractor", sa_relationship_kwargs={"cascade": "all, delete"})


@@ -243,6 +249,7 @@ class DocumentDataExtractorOut(DocumentDataExtractorBase):
timestamp: datetime
owner_id: int
document_data_examples: list["DocumentDataExample"]
response_template:str

class DocumentDataExtractorsOut(SQLModel):
data: list[DocumentDataExtractorOut]
@@ -252,18 +259,18 @@ class DocumentDataExtractorsOut(SQLModel):
# Examples
class DocumentDataExampleBase(SQLModel):
document_id: str
data: str
data: dict[str,str|None]
document_data_extractor_id: int | None = None

class DocumentDataExampleCreate(DocumentDataExampleBase):
document_id: str
data: str
data: dict[str,str|None]
start_page: int = 0
end_page: int | None = None

class DocumentDataExampleUpdate(DocumentDataExampleBase):
document_id: str | None = None
data: str | None = None
data: dict[str,str|None] | None = None
start_page: int | None = None
end_page: int | None = None

@@ -277,4 +284,6 @@ class DocumentDataExample(SQLModel, table=True):
end_page: int | None = None

class DocumentDataExampleOut(DocumentDataExampleBase):
id: int
id: int
data: str
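
With this change, example data is submitted as a dict (validated against the owning extractor's response_template) and stored as a JSON string. A hypothetical creation payload, with invented document_id and field values:

    example_in = DocumentDataExampleCreate(
        document_id="abc123",  # invented identifier
        data={"invoice_number": "F-2024-001", "total_amount": None},
        start_page=0,
    )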

2 changes: 1 addition & 1 deletion backend/app/tests/conftest.py
@@ -168,7 +168,7 @@ def chat_completion_create_mistral() -> mistral.ChatCompletionRequest:
],
model="mistral-medium-2312",
max_tokens=100,
response_format=mistral.ResponseFormat(type="text"),
response_format=mistral.ResponseFormatBase(type="text"),
safe_prompt=True,
random_seed=0,
temperature=1.0,
2 changes: 0 additions & 2 deletions backend/pyproject.toml
@@ -40,8 +40,6 @@ pymupdf = "^1.24.10"
pytesseract = "^0.3.13"
pdf2image = "^1.17.0"



[tool.poetry.group.dev.dependencies]
pytest = "^8.3"
pytest-cov = "^5.0"
