diff --git a/backend/app/api/routes/dde.py b/backend/app/api/routes/dde.py index a03ed08..520b43e 100644 --- a/backend/app/api/routes/dde.py +++ b/backend/app/api/routes/dde.py @@ -1,12 +1,9 @@ -from typing import Any, Iterable, Literal, TypedDict -from app.lm.models.chat_completion import TokenLogprob -from app.lm.models import ChatCompletionResponse +from typing import Any, Literal, TypedDict from fastapi import APIRouter, HTTPException, status, UploadFile from fastapi.responses import JSONResponse from sqlmodel import func, select from sqlalchemy.exc import IntegrityError import json -import math import io from pydantic import create_model, ValidationError from app.api.deps import CurrentUser, SessionDep @@ -463,44 +460,59 @@ async def extract_from_file( full_system_content = f"{system_prompt}\n{examples_text}" messages = [ - ChatCompletionMessage(role="system", content=full_system_content), - ChatCompletionMessage(role="user", content=f"Maintenant, faites la même extraction sur un nouveau document d'input:\n####\nINPUT:{prompt}") - ] - - pydantic_reponse=create_pydantic_model(json.loads(document_data_extractor.response_template)) - format_response={"type": "json_schema", - "json_schema":{ - "schema":to_strict_json_schema(pydantic_reponse), - "name":'response', - 'strict':True}} - - chat_completion_request = ChatCompletionRequest( - model='gpt-4o-2024-08-06', - messages=messages, - max_tokens=2000, - temperature=0.1, - logprobs=True, - top_logprobs= 5, - response_format=format_response - - ).model_dump(exclude_unset=True) - - chat_completion_response = await ArenaHandler(session, document_data_extractor.owner, chat_completion_request).process_request() - extracted_data=chat_completion_response.choices[0].message.content + ChatCompletionMessage(role="system", content=full_system_content), + ChatCompletionMessage( + role="user", + content=f"Maintenant, faites la même extraction sur un nouveau document d'input:\n####\nINPUT:{prompt}", + ), + ] + + pydantic_reponse = create_pydantic_model( + json.loads(document_data_extractor.response_template) + ) + format_response = { + "type": "json_schema", + "json_schema": { + "schema": to_strict_json_schema(pydantic_reponse), + "name": "response", + "strict": True, + }, + } + + chat_completion_request = ChatCompletionRequest( + model="gpt-4o-2024-08-06", + messages=messages, + max_tokens=2000, + temperature=0.1, + logprobs=True, + top_logprobs=5, + response_format=format_response, + ).model_dump(exclude_unset=True) + + chat_completion_response = await ArenaHandler( + session, document_data_extractor.owner, chat_completion_request + ).process_request() + extracted_data = chat_completion_response.choices[0].message.content extracted_data_token = chat_completion_response.choices[0].logprobs.content - #TODO: handle refusal or case in which content was not correctly done + # TODO: handle refusal or case in which content was not correctly done # TODO: Improve the prompt to ensure the output is always a valid JSON # - json_string = extracted_data[extracted_data.find('{'):extracted_data.rfind('}')+1] - #keys = list(pydantic_reponse.__fields__.keys()) - #value_indices = extract_tokens_indices_for_each_key(keys, extracted_data_token) - #logprobs = extract_logprobs_from_indices(value_indices, extracted_data_token) - return {'extracted_data': json.loads(json_string)} + json_string = extracted_data[ + extracted_data.find("{") : extracted_data.rfind("}") + 1 + ] + # keys = list(pydantic_reponse.__fields__.keys()) + # value_indices = extract_tokens_indices_for_each_key(keys, extracted_data_token) + # logprobs = extract_logprobs_from_indices(value_indices, extracted_data_token) + return {"extracted_data": json.loads(json_string)} + class Token(TypedDict): token: str - -def extract_tokens_indices_for_each_key(keys: list[str], token_list:list[Token]) -> dict[str, list[int]]: + + +def extract_tokens_indices_for_each_key( + keys: list[str], token_list: list[Token] +) -> dict[str, list[int]]: """ Extracts the indices of tokens corresponding to extracted data related to a list of specified keys. @@ -521,50 +533,66 @@ def extract_tokens_indices_for_each_key(keys: list[str], token_list:list[Token]) current_key = "" matched_key = None remaining_keys = keys.copy() - saving_indices = False + saving_indices = False for i, token_object in enumerate(token_list): token = token_object.token if matched_key is not None: if saving_indices: if token == '","' or token == ',"': - next_token = token_list[i + 1].token if i + 1 < len(token_list) else None - if next_token is not None and any(key.startswith(next_token) for key in remaining_keys): - value_indices[matched_key].append(i - 1) #stop saving indices when token is "," and the next token is the start of one of the keys - matched_key = None - saving_indices = False - current_key = "" - continue - elif token_list[i + 1].token == '}': - value_indices[matched_key].append(i) #stop saving indices when the next token is '}' - matched_key = None + next_token = ( + token_list[i + 1].token + if i + 1 < len(token_list) + else None + ) + if next_token is not None and any( + key.startswith(next_token) for key in remaining_keys + ): + value_indices[matched_key].append( + i - 1 + ) # stop saving indices when token is "," and the next token is the start of one of the keys + matched_key = None saving_indices = False - current_key = "" - continue + current_key = "" + continue + elif token_list[i + 1].token == "}": + value_indices[matched_key].append( + i + ) # stop saving indices when the next token is '}' + matched_key = None + saving_indices = False + current_key = "" + continue continue elif token == '":' or token == '":"': - value_indices[matched_key].append(i + 1) #start saving indices after tokens '":' or '":"' + value_indices[matched_key].append( + i + 1 + ) # start saving indices after tokens '":' or '":"' saving_indices = True else: current_key += token for key in remaining_keys: if key.startswith(current_key): if current_key == key: - matched_key = key #full key matched + matched_key = key # full key matched remaining_keys.remove(key) break else: current_key = "" return value_indices - -def extract_logprobs_from_indices(value_indices: dict[str, list[int]], token_list: list[Token]) -> dict[str, list[Any]]: + + +def extract_logprobs_from_indices( + value_indices: dict[str, list[int]], token_list: list[Token] +) -> dict[str, list[Any]]: logprobs = {key: [] for key in value_indices} for key, indices in value_indices.items(): - start_idx = indices[0] - end_idx = indices[-1] - for i in range(start_idx, end_idx + 1): - logprobs[key].append(token_list[i].top_logprobs[0].logprob) + start_idx = indices[0] + end_idx = indices[-1] + for i in range(start_idx, end_idx + 1): + logprobs[key].append(token_list[i].top_logprobs[0].logprob) return logprobs + def create_pydantic_model( schema: dict[ str, diff --git a/backend/app/ops/computation.py b/backend/app/ops/computation.py index d35a18c..ddea669 100644 --- a/backend/app/ops/computation.py +++ b/backend/app/ops/computation.py @@ -17,7 +17,8 @@ A = TypeVar("A") B = TypeVar("B") -logger=logging.getLogger('uvicorn.error') +logger = logging.getLogger("uvicorn.error") + # A mixin class to add hashability to pydantic models class Hashable: @@ -199,7 +200,9 @@ async def call(self) -> B: All tasks should have been created """ args = [await arg.task for arg in self.args] - logger.info(f'Executing op {type(self.op)} with arguments of type {[type(el) for el in args]}') + logger.info( + f"Executing op {type(self.op)} with arguments of type {[type(el) for el in args]}" + ) return await self.op.call(*args) def tasks(self, task_group: TaskGroup): diff --git a/backend/app/services/masking.py b/backend/app/services/masking.py index 654f60b..a95f731 100644 --- a/backend/app/services/masking.py +++ b/backend/app/services/masking.py @@ -51,7 +51,7 @@ async def analyze( json=req.model_dump(exclude_none=True), timeout=1000, ) - + return analyzer_response.validate_python( response.raise_for_status().json() ) @@ -152,8 +152,7 @@ async def anonymize(self, req: AnonymizerRequest) -> AnonymizerResponse: url=f"{self.url}", json=req.model_dump(exclude_none=True), timeout=1000, - ) + ) return AnonymizerResponse.model_validate( response.raise_for_status().json() ) - diff --git a/backend/app/tests/api/routes/test_dde.py b/backend/app/tests/api/routes/test_dde.py index 02ecc32..0c40031 100644 --- a/backend/app/tests/api/routes/test_dde.py +++ b/backend/app/tests/api/routes/test_dde.py @@ -6,7 +6,10 @@ import pytest import json from typing import Generator, Any -from app.api.routes.dde import extract_tokens_indices_for_each_key, extract_logprobs_from_indices +from app.api.routes.dde import ( + extract_tokens_indices_for_each_key, + extract_logprobs_from_indices, +) @pytest.fixture(scope="module") @@ -15,17 +18,16 @@ def document_data_extractor( ) -> Generator[dict[str, Any], None, None]: fake_name = "Test_dde" fake_prompt = "Extract the adresse from document" - response_template = { - "adresse": [ - "str", - "required" - ] - } + response_template = {"adresse": ["str", "required"]} - payload = {"name": fake_name, "prompt": fake_prompt, 'response_template':response_template} + payload = { + "name": fake_name, + "prompt": fake_prompt, + "response_template": response_template, + } headers = superuser_token_headers - + response = client.post( f"{settings.API_V1_STR}/dde", headers=headers, @@ -39,7 +41,7 @@ def document_data_extractor( timestamp="2024-10-03T09:31:33.748765", owner_id=1, document_data_examples=[], - response_template = json.dumps(response_template) + response_template=json.dumps(response_template), ) assert response.status_code == 200 @@ -91,7 +93,9 @@ def test_update_document_data_extractor( update_payload = { "name": updated_name, "prompt": document_data_extractor["prompt"], - "response_template": json.loads(document_data_extractor["response_template"]) + "response_template": json.loads( + document_data_extractor["response_template"] + ), } response = client.put( @@ -108,11 +112,18 @@ def test_update_document_data_extractor( assert response_data["prompt"] == document_data_extractor["prompt"] assert response_data["timestamp"] == document_data_extractor["timestamp"] assert response_data["owner_id"] == document_data_extractor["owner_id"] - assert response_data["response_template"] == document_data_extractor["response_template"] + assert ( + response_data["response_template"] + == document_data_extractor["response_template"] + ) assert len(response_data["document_data_examples"]) == 0 - -def test_create_document_data_example(client: TestClient, superuser_token_headers: dict[str, str], document_data_extractor: dict[str, Any]): + +def test_create_document_data_example( + client: TestClient, + superuser_token_headers: dict[str, str], + document_data_extractor: dict[str, Any], +): name = document_data_extractor["name"] with patch.object(Documents, "exists", return_value=True): @@ -155,7 +166,7 @@ def test_update_document_data_example( ): name_dde = document_data_extractor["name"] id_example = document_data_extractor["document_data_examples"][0]["id"] - updated_data = {'adresse': '2 ALLEE DES HORTENSIAS'} + updated_data = {"adresse": "2 ALLEE DES HORTENSIAS"} update_payload = { "document_id": document_data_extractor["document_data_examples"][0][ @@ -165,8 +176,12 @@ def test_update_document_data_example( "document_data_extractor_id": document_data_extractor[ "document_data_examples" ][0]["document_data_extractor_id"], - "start_page":document_data_extractor["document_data_examples"][0]['start_page'], - "end_page":document_data_extractor["document_data_examples"][0]['end_page'] + "start_page": document_data_extractor["document_data_examples"][0][ + "start_page" + ], + "end_page": document_data_extractor["document_data_examples"][0][ + "end_page" + ], } document_data_extractor["document_data_examples"][0]["data"] = updated_data @@ -186,59 +201,111 @@ def test_update_document_data_example( assert response_data["document_data_extractor_id"] == 1 assert response_data["id"] == 1 + class TopLogprob: def __init__(self, logprob: float): self.logprob = logprob + class Token: def __init__(self, token: str, top_logprobs: list[TopLogprob]): self.token = token - self.top_logprobs = top_logprobs - + self.top_logprobs = top_logprobs + + @pytest.fixture() def token_list() -> list[Token]: - tokens = ['{', 'na', 'me', '":"', 'Mi', 'cha', 'el', '","', 'smi', 'th', '","', 'age', '":', '35', ',"' , 'natio', 'nality', '":"', 'Eng', 'lish', '}'] + tokens = [ + "{", + "na", + "me", + '":"', + "Mi", + "cha", + "el", + '","', + "smi", + "th", + '","', + "age", + '":', + "35", + ',"', + "natio", + "nality", + '":"', + "Eng", + "lish", + "}", + ] token_list = [Token(token, []) for token in tokens] return token_list + @pytest.fixture() def token_logprobs_list() -> list[Token]: - tokens = ['{', 'na', 'me', '":"', 'Mi', 'cha', 'el', '","', 'smi', 'th', '","', 'age', '":', '35', ',"' , 'natio', 'nality', '":"', 'Eng', 'lish', '}'] + tokens = [ + "{", + "na", + "me", + '":"', + "Mi", + "cha", + "el", + '","', + "smi", + "th", + '","', + "age", + '":', + "35", + ',"', + "natio", + "nality", + '":"', + "Eng", + "lish", + "}", + ] logprobs = [-5.42e-05, -10.50, -11.50, -12.50, -13.750] logprobs_list = [TopLogprob(logprob) for logprob in logprobs] token_list = [Token(token, logprobs_list) for token in tokens] return token_list - + + def test_extract_tokens_indices_for_each_key(token_list: list[Token]): - keys = ['name', 'age', 'nationality'] + keys = ["name", "age", "nationality"] expected_result = { - 'name': [4, 9], - 'age': [13, 13], - 'nationality': [18, 19] + "name": [4, 9], + "age": [13, 13], + "nationality": [18, 19], } result = extract_tokens_indices_for_each_key(keys, token_list) - + assert result == expected_result + def test_extract_logprobs_from_indices(token_logprobs_list: list[Token]): expected_value_indices = { - 'name': [4, 9], - 'age': [13, 13], - 'nationality': [18, 19] + "name": [4, 9], + "age": [13, 13], + "nationality": [18, 19], } expected_logprobs = -5.42e-05 expected_result = { - 'name': [expected_logprobs] * 6, - 'age': [expected_logprobs], - 'nationality': [expected_logprobs] * 2 + "name": [expected_logprobs] * 6, + "age": [expected_logprobs], + "nationality": [expected_logprobs] * 2, } - result = extract_logprobs_from_indices(expected_value_indices, token_logprobs_list) + result = extract_logprobs_from_indices( + expected_value_indices, token_logprobs_list + ) assert result == expected_result - -#TODO: test extract_from_file - + + +# TODO: test extract_from_file # TODO: test extract_from_file diff --git a/backend/app/tests/conftest.py b/backend/app/tests/conftest.py index 59f432e..d18b8a4 100644 --- a/backend/app/tests/conftest.py +++ b/backend/app/tests/conftest.py @@ -21,7 +21,7 @@ Event, EventIdentifier, Attribute, - EventAttribute + EventAttribute, ) from app.tests.utils.user import authentication_token_from_email from app.tests.utils.utils import get_superuser_token_headers @@ -29,7 +29,13 @@ LMApiKeys, openai, mistral, - anthropic,Choice,TokenLogprob,ChoiceLogprobs,TopLogprob,CompletionUsage,Message + anthropic, + Choice, + TokenLogprob, + ChoiceLogprobs, + TopLogprob, + CompletionUsage, + Message, ) @@ -164,9 +170,7 @@ def chat_completion_openai() -> openai.ChatCompletionResponse: ), ] ), - message=Message( - role="assistant", content="Hello world!" - ), + message=Message(role="assistant", content="Hello world!"), ) ], created=1672463200, @@ -222,9 +226,7 @@ def chat_completion_mistral() -> mistral.ChatCompletionResponse: token=".", logprob=-0.100103, top_logprobs=[ - TopLogprob( - token=".", logprob=-0.100103 - ) + TopLogprob(token=".", logprob=-0.100103) ], ) ]