Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
lccnl committed Oct 25, 2024
1 parent 679010c commit 994e4b7
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 107 deletions.
142 changes: 85 additions & 57 deletions backend/app/api/routes/dde.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
from typing import Any, Iterable, Literal, TypedDict
from app.lm.models.chat_completion import TokenLogprob
from app.lm.models import ChatCompletionResponse
from typing import Any, Literal, TypedDict
from fastapi import APIRouter, HTTPException, status, UploadFile
from fastapi.responses import JSONResponse
from sqlmodel import func, select
from sqlalchemy.exc import IntegrityError
import json
import math
import io
from pydantic import create_model, ValidationError
from app.api.deps import CurrentUser, SessionDep
Expand Down Expand Up @@ -463,44 +460,59 @@ async def extract_from_file(
full_system_content = f"{system_prompt}\n{examples_text}"

messages = [
ChatCompletionMessage(role="system", content=full_system_content),
ChatCompletionMessage(role="user", content=f"Maintenant, faites la même extraction sur un nouveau document d'input:\n####\nINPUT:{prompt}")
]

pydantic_reponse=create_pydantic_model(json.loads(document_data_extractor.response_template))
format_response={"type": "json_schema",
"json_schema":{
"schema":to_strict_json_schema(pydantic_reponse),
"name":'response',
'strict':True}}

chat_completion_request = ChatCompletionRequest(
model='gpt-4o-2024-08-06',
messages=messages,
max_tokens=2000,
temperature=0.1,
logprobs=True,
top_logprobs= 5,
response_format=format_response

).model_dump(exclude_unset=True)

chat_completion_response = await ArenaHandler(session, document_data_extractor.owner, chat_completion_request).process_request()
extracted_data=chat_completion_response.choices[0].message.content
ChatCompletionMessage(role="system", content=full_system_content),
ChatCompletionMessage(
role="user",
content=f"Maintenant, faites la même extraction sur un nouveau document d'input:\n####\nINPUT:{prompt}",
),
]

pydantic_reponse = create_pydantic_model(
json.loads(document_data_extractor.response_template)
)
format_response = {
"type": "json_schema",
"json_schema": {
"schema": to_strict_json_schema(pydantic_reponse),
"name": "response",
"strict": True,
},
}

chat_completion_request = ChatCompletionRequest(
model="gpt-4o-2024-08-06",
messages=messages,
max_tokens=2000,
temperature=0.1,
logprobs=True,
top_logprobs=5,
response_format=format_response,
).model_dump(exclude_unset=True)

chat_completion_response = await ArenaHandler(
session, document_data_extractor.owner, chat_completion_request
).process_request()
extracted_data = chat_completion_response.choices[0].message.content
extracted_data_token = chat_completion_response.choices[0].logprobs.content
#TODO: handle refusal or case in which content was not correctly done
# TODO: handle refusal or case in which content was not correctly done
# TODO: Improve the prompt to ensure the output is always a valid JSON
#
json_string = extracted_data[extracted_data.find('{'):extracted_data.rfind('}')+1]
#keys = list(pydantic_reponse.__fields__.keys())
#value_indices = extract_tokens_indices_for_each_key(keys, extracted_data_token)
#logprobs = extract_logprobs_from_indices(value_indices, extracted_data_token)
return {'extracted_data': json.loads(json_string)}
json_string = extracted_data[
extracted_data.find("{") : extracted_data.rfind("}") + 1
]
# keys = list(pydantic_reponse.__fields__.keys())
# value_indices = extract_tokens_indices_for_each_key(keys, extracted_data_token)
# logprobs = extract_logprobs_from_indices(value_indices, extracted_data_token)
return {"extracted_data": json.loads(json_string)}


class Token(TypedDict):
    """Minimal typed shape of one completion token: only its text is declared."""

    # NOTE(review): the helper functions in this file read `.token` and
    # `.top_logprobs[0].logprob` as attributes, which a TypedDict (a plain dict
    # at runtime) does not provide. Presumably the objects actually passed are
    # the chat-completion logprob entries; confirm against the caller.
    # Text of the token.
    token: str

def extract_tokens_indices_for_each_key(keys: list[str], token_list:list[Token]) -> dict[str, list[int]]:


def extract_tokens_indices_for_each_key(
keys: list[str], token_list: list[Token]
) -> dict[str, list[int]]:
"""
Extracts the indices of tokens corresponding to extracted data related to a list of specified keys.
Expand All @@ -521,50 +533,66 @@ def extract_tokens_indices_for_each_key(keys: list[str], token_list:list[Token])
current_key = ""
matched_key = None
remaining_keys = keys.copy()
saving_indices = False
saving_indices = False
for i, token_object in enumerate(token_list):
token = token_object.token
if matched_key is not None:
if saving_indices:
if token == '","' or token == ',"':
next_token = token_list[i + 1].token if i + 1 < len(token_list) else None
if next_token is not None and any(key.startswith(next_token) for key in remaining_keys):
value_indices[matched_key].append(i - 1) #stop saving indices when token is "," and the next token is the start of one of the keys
matched_key = None
saving_indices = False
current_key = ""
continue
elif token_list[i + 1].token == '}':
value_indices[matched_key].append(i) #stop saving indices when the next token is '}'
matched_key = None
next_token = (
token_list[i + 1].token
if i + 1 < len(token_list)
else None
)
if next_token is not None and any(
key.startswith(next_token) for key in remaining_keys
):
value_indices[matched_key].append(
i - 1
) # stop saving indices when token is "," and the next token is the start of one of the keys
matched_key = None
saving_indices = False
current_key = ""
continue
current_key = ""
continue
elif token_list[i + 1].token == "}":
value_indices[matched_key].append(
i
) # stop saving indices when the next token is '}'
matched_key = None
saving_indices = False
current_key = ""
continue
continue
elif token == '":' or token == '":"':
value_indices[matched_key].append(i + 1) #start saving indices after tokens '":' or '":"'
value_indices[matched_key].append(
i + 1
) # start saving indices after tokens '":' or '":"'
saving_indices = True
else:
current_key += token
for key in remaining_keys:
if key.startswith(current_key):
if current_key == key:
matched_key = key #full key matched
matched_key = key # full key matched
remaining_keys.remove(key)
break
else:
current_key = ""
return value_indices

def extract_logprobs_from_indices(
    value_indices: dict[str, list[int]], token_list: list[Token]
) -> dict[str, list[Any]]:
    """Collect the log-probabilities of the tokens spanning each key's value.

    Args:
        value_indices: Maps each key to a list of token positions. Only the
            first and last entries are used; they are taken as the inclusive
            start and end of the span to read from ``token_list``.
        token_list: Token objects; each one in a span must expose
            ``top_logprobs[0].logprob``.

    Returns:
        A dict with the same keys as ``value_indices``. Each value is the list
        of ``top_logprobs[0].logprob`` for positions start..end inclusive. A
        key whose index list is empty maps to an empty list.

    Raises:
        IndexError: If a position inside a span is outside ``token_list``.
    """
    logprobs: dict[str, list[Any]] = {}
    for key, indices in value_indices.items():
        if not indices:
            # No span was recorded for this key; indexing indices[0] would
            # raise, so report "no logprobs" instead of crashing.
            logprobs[key] = []
            continue
        start_idx, end_idx = indices[0], indices[-1]
        # NOTE(review): this reads the first `top_logprobs` entry, not the
        # token's own logprob; presumably these coincide when the emitted token
        # is the most likely one. Confirm this is the intended value.
        logprobs[key] = [
            token_list[i].top_logprobs[0].logprob
            for i in range(start_idx, end_idx + 1)
        ]
    return logprobs


def create_pydantic_model(
schema: dict[
str,
Expand Down
7 changes: 5 additions & 2 deletions backend/app/ops/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
A = TypeVar("A")
B = TypeVar("B")

logger=logging.getLogger('uvicorn.error')
logger = logging.getLogger("uvicorn.error")


# A mixin class to add hashability to pydantic models
class Hashable:
Expand Down Expand Up @@ -199,7 +200,9 @@ async def call(self) -> B:
All tasks should have been created
"""
args = [await arg.task for arg in self.args]
logger.info(f'Executing op {type(self.op)} with arguments of type {[type(el) for el in args]}')
logger.info(
f"Executing op {type(self.op)} with arguments of type {[type(el) for el in args]}"
)
return await self.op.call(*args)

def tasks(self, task_group: TaskGroup):
Expand Down
5 changes: 2 additions & 3 deletions backend/app/services/masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async def analyze(
json=req.model_dump(exclude_none=True),
timeout=1000,
)

return analyzer_response.validate_python(
response.raise_for_status().json()
)
Expand Down Expand Up @@ -152,8 +152,7 @@ async def anonymize(self, req: AnonymizerRequest) -> AnonymizerResponse:
url=f"{self.url}",
json=req.model_dump(exclude_none=True),
timeout=1000,
)
)
return AnonymizerResponse.model_validate(
response.raise_for_status().json()
)

Loading

0 comments on commit 994e4b7

Please sign in to comment.