Skip to content

Commit

Permalink
Merge pull request #39 from arena-ai/remove_none_log_info
Browse files Browse the repository at this point in the history
Remove none log info
  • Loading branch information
ngrislain authored Oct 25, 2024
2 parents 7263dab + 994e4b7 commit a89a160
Show file tree
Hide file tree
Showing 5 changed files with 213 additions and 225 deletions.
254 changes: 86 additions & 168 deletions backend/app/api/routes/dde.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
from typing import Any, Iterable, Literal, TypedDict
from app.lm.models.chat_completion import TokenLogprob
from app.lm.models import ChatCompletionResponse
from typing import Any, Literal, TypedDict
from fastapi import APIRouter, HTTPException, status, UploadFile
from fastapi.responses import JSONResponse
from sqlmodel import func, select
from sqlalchemy.exc import IntegrityError
import json
import math
import io
from pydantic import create_model, ValidationError
from app.api.deps import CurrentUser, SessionDep
Expand Down Expand Up @@ -463,43 +460,59 @@ async def extract_from_file(
full_system_content = f"{system_prompt}\n{examples_text}"

messages = [
ChatCompletionMessage(role="system", content=full_system_content),
ChatCompletionMessage(role="user", content=f"Maintenant, faites la même extraction sur un nouveau document d'input:\n####\nINPUT:{prompt}")
]

pydantic_reponse=create_pydantic_model(json.loads(document_data_extractor.response_template))
format_response={"type": "json_schema",
"json_schema":{
"schema":to_strict_json_schema(pydantic_reponse),
"name":'response',
'strict':True}}

chat_completion_request = ChatCompletionRequest(
model='gpt-4o-2024-08-06',
messages=messages,
max_tokens=2000,
temperature=0.1,
logprobs=True,
top_logprobs= 5,
response_format=format_response

).model_dump(exclude_unset=True)

chat_completion_response = await ArenaHandler(session, document_data_extractor.owner, chat_completion_request).process_request()
extracted_data=chat_completion_response.choices[0].message.content
ChatCompletionMessage(role="system", content=full_system_content),
ChatCompletionMessage(
role="user",
content=f"Maintenant, faites la même extraction sur un nouveau document d'input:\n####\nINPUT:{prompt}",
),
]

pydantic_reponse = create_pydantic_model(
json.loads(document_data_extractor.response_template)
)
format_response = {
"type": "json_schema",
"json_schema": {
"schema": to_strict_json_schema(pydantic_reponse),
"name": "response",
"strict": True,
},
}

chat_completion_request = ChatCompletionRequest(
model="gpt-4o-2024-08-06",
messages=messages,
max_tokens=2000,
temperature=0.1,
logprobs=True,
top_logprobs=5,
response_format=format_response,
).model_dump(exclude_unset=True)

chat_completion_response = await ArenaHandler(
session, document_data_extractor.owner, chat_completion_request
).process_request()
extracted_data = chat_completion_response.choices[0].message.content
extracted_data_token = chat_completion_response.choices[0].logprobs.content
#TODO: handle refusal or case in which content was not correctly done
# TODO: handle refusal or case in which content was not correctly done
# TODO: Improve the prompt to ensure the output is always a valid JSON
json_string = extracted_data[extracted_data.find('{'):extracted_data.rfind('}')+1]
keys = list(pydantic_reponse.__fields__.keys())
value_indices = extract_tokens_indices_for_each_key(keys, extracted_data_token)
logprobs = extract_logprobs_from_indices(value_indices, extracted_data_token)
return {'extracted_data': json.loads(json_string), 'logprobs': logprobs}
#
json_string = extracted_data[
extracted_data.find("{") : extracted_data.rfind("}") + 1
]
# keys = list(pydantic_reponse.__fields__.keys())
# value_indices = extract_tokens_indices_for_each_key(keys, extracted_data_token)
# logprobs = extract_logprobs_from_indices(value_indices, extracted_data_token)
return {"extracted_data": json.loads(json_string)}


class Token(TypedDict):
    """Structural type for one entry of the model's logprobs content list."""

    # Raw token text as returned by the model.
    token: str
def extract_tokens_indices_for_each_key(keys: list[str], token_list:list[Token]) -> dict[str, list[int]]:


def extract_tokens_indices_for_each_key(
keys: list[str], token_list: list[Token]
) -> dict[str, list[int]]:
"""
Extracts the indices of tokens corresponding to extracted data related to a list of specified keys.
Expand All @@ -520,50 +533,66 @@ def extract_tokens_indices_for_each_key(keys: list[str], token_list:list[Token])
current_key = ""
matched_key = None
remaining_keys = keys.copy()
saving_indices = False
saving_indices = False
for i, token_object in enumerate(token_list):
token = token_object.token
if matched_key is not None:
if saving_indices:
if token == '","' or token == ',"':
next_token = token_list[i + 1].token if i + 1 < len(token_list) else None
if next_token is not None and any(key.startswith(next_token) for key in remaining_keys):
value_indices[matched_key].append(i - 1) #stop saving indices when token is "," and the next token is the start of one of the keys
matched_key = None
next_token = (
token_list[i + 1].token
if i + 1 < len(token_list)
else None
)
if next_token is not None and any(
key.startswith(next_token) for key in remaining_keys
):
value_indices[matched_key].append(
i - 1
) # stop saving indices when token is "," and the next token is the start of one of the keys
matched_key = None
saving_indices = False
current_key = ""
continue
elif token_list[i + 1].token == '}':
value_indices[matched_key].append(i) #stop saving indices when the next token is '}'
matched_key = None
saving_indices = False
current_key = ""
continue
current_key = ""
continue
elif token_list[i + 1].token == "}":
value_indices[matched_key].append(
i
) # stop saving indices when the next token is '}'
matched_key = None
saving_indices = False
current_key = ""
continue
continue
elif token == '":' or token == '":"':
value_indices[matched_key].append(i + 1) #start saving indices after tokens '":' or '":"'
value_indices[matched_key].append(
i + 1
) # start saving indices after tokens '":' or '":"'
saving_indices = True
else:
current_key += token
for key in remaining_keys:
if key.startswith(current_key):
if current_key == key:
matched_key = key #full key matched
matched_key = key # full key matched
remaining_keys.remove(key)
break
else:
current_key = ""
return value_indices

def extract_logprobs_from_indices(value_indices: dict[str, list[int]], token_list: list[Token]) -> dict[str, list[Any]]:


def extract_logprobs_from_indices(
value_indices: dict[str, list[int]], token_list: list[Token]
) -> dict[str, list[Any]]:
logprobs = {key: [] for key in value_indices}
for key, indices in value_indices.items():
start_idx = indices[0]
end_idx = indices[-1]
for i in range(start_idx, end_idx + 1):
logprobs[key].append(token_list[i].top_logprobs[0].logprob)
start_idx = indices[0]
end_idx = indices[-1]
for i in range(start_idx, end_idx + 1):
logprobs[key].append(token_list[i].top_logprobs[0].logprob)
return logprobs


def create_pydantic_model(
schema: dict[
str,
Expand Down Expand Up @@ -607,114 +636,3 @@ def validate_extracted_text(text: str):
status_code=500,
detail="The extracted text from the document is empty. Please check if the document is corrupted.",
)


# TODO: Optimize the entire process of extracting and handling log probabilities from OpenAI for the identified tokens.
def is_equal_ignore_sign(a, b) -> bool:
    """Compare two values numerically, ignoring their sign.

    Both arguments are coerced to ``float``; if either coercion fails the
    values are considered unequal.  Magnitudes are compared because logits
    are associated only with numerical tokens, so values are considered in
    their absolute form, ignoring the sign.
    """
    try:
        left, right = float(a), float(b)
    except ValueError:
        return False
    return abs(left) == abs(right)


def combined_token_in_extracted_data(
    combined_token: str, extracted_data: Iterable
) -> bool:
    """Return True when *combined_token* parses as a number that matches
    (sign ignored) at least one numeric value in *extracted_data*.

    Non-numeric tokens and non-numeric values never match.
    """
    try:
        numeric_token = float(combined_token)
    except ValueError:
        return False
    numeric_values = (
        value for value in extracted_data if isinstance(value, (int, float))
    )
    return any(
        is_equal_ignore_sign(numeric_token, value) for value in numeric_values
    )


def find_key_by_value(
    combined_token: str, extracted_data: dict[str, Any]
) -> str | None:
    """Return the first key of *extracted_data* whose numeric value matches
    *combined_token* (sign ignored).

    Returns ``None`` when the token is not numeric or when no numeric value
    matches.
    """
    try:
        numeric_token = float(combined_token)
    except ValueError:
        return None
    for key, value in extracted_data.items():
        if isinstance(value, (int, float)) and is_equal_ignore_sign(
            numeric_token, value
        ):
            return key
    return None


def extract_logprobs_from_response(
    response: ChatCompletionResponse, extracted_data: dict[str, Any]
) -> dict[str, float | list[float]]:
    """Collect token-level probability details for the numeric values of
    *extracted_data*.

    For every numeric value found in the (possibly nested) extracted data,
    the matching run of digit tokens in the model response is located and
    the probability of its first and second token, together with their
    top-logprob candidates, is recorded under keys derived from the value's
    dot-separated path.

    Args:
        response: chat completion response carrying per-token logprobs.
        extracted_data: extracted values, possibly containing nested dicts.

    Returns:
        Mapping from ``"<path>_prob_first_token"``,
        ``"<path>_prob_second_token"``, ``"<path>_first_token_toplogprobs"``
        and ``"<path>_second_token_toplogprobs"`` to the recorded values.
    """
    logprob_data = {}
    extracted_data_token = response.choices[0].logprobs.content

    def process_numeric_values(extracted_data: dict[str, Any], path=""):
        # Scan the token stream for digit tokens, merge consecutive digit
        # tokens into one candidate number and, when that number matches a
        # numeric value in the extracted data, record probability details
        # for its first two tokens.
        for i in range(len(extracted_data_token) - 1):
            token = extracted_data_token[i].token
            if not token.isdigit():  # only numeric tokens are of interest
                continue
            combined_token, combined_logprob = combine_tokens(
                extracted_data_token, i
            )
            # Does the combined token match any numeric value (sign ignored)?
            if not combined_token_in_extracted_data(
                combined_token, extracted_data.values()
            ):
                continue
            # Key in 'extracted_data' whose value matches the combined token.
            key = find_key_by_value(combined_token, extracted_data)
            if key:
                full_key = path + key
                logprob_data[full_key + "_prob_first_token"] = math.exp(
                    extracted_data_token[i].logprob
                )
                logprob_data[full_key + "_prob_second_token"] = math.exp(
                    extracted_data_token[i + 1].logprob
                )
                logprob_data[full_key + "_first_token_toplogprobs"] = [
                    top.logprob
                    for top in extracted_data_token[i].top_logprobs
                ]
                logprob_data[full_key + "_second_token_toplogprobs"] = [
                    top.logprob
                    for top in extracted_data_token[i + 1].top_logprobs
                ]

    def traverse_and_extract(data: dict, path=""):
        # Process each dict level at most once (the previous version
        # re-scanned the whole token list for every numeric value, and
        # printed debug output, producing identical entries repeatedly).
        if any(isinstance(v, (int, float)) for v in data.values()):
            process_numeric_values(data, path)
        for key, value in data.items():
            if isinstance(value, dict):
                traverse_and_extract(value, path + key + ".")

    traverse_and_extract(extracted_data)
    return logprob_data


def combine_tokens(
    extracted_data_token: list[TokenLogprob], start_index: int
) -> tuple[str, float]:
    """Merge the token at *start_index* with every immediately following
    digit token.

    Returns the concatenated token string and the sum of the individual
    token logprobs (the joint log-probability of the digit run).
    """
    tokens = extracted_data_token
    combined = tokens[start_index].token
    total_logprob = tokens[start_index].logprob
    index = start_index + 1
    # Keep combining tokens as long as the next token is a digit.
    while index < len(tokens) and tokens[index].token.isdigit():
        combined += tokens[index].token
        total_logprob += tokens[index].logprob
        index += 1
    return combined, total_logprob
6 changes: 6 additions & 0 deletions backend/app/ops/computation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Generic, TypeVar, TypeVarTuple
import logging
from types import NoneType
from abc import ABC, abstractmethod
from time import time
Expand All @@ -16,6 +17,8 @@
A = TypeVar("A")
B = TypeVar("B")

logger = logging.getLogger("uvicorn.error")


# A mixin class to add hashability to pydantic models
class Hashable:
Expand Down Expand Up @@ -197,6 +200,9 @@ async def call(self) -> B:
All tasks should have been created
"""
args = [await arg.task for arg in self.args]
logger.info(
f"Executing op {type(self.op)} with arguments of type {[type(el) for el in args]}"
)
return await self.op.call(*args)

def tasks(self, task_group: TaskGroup):
Expand Down
19 changes: 7 additions & 12 deletions backend/app/services/masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,10 @@ async def analyze(
json=req.model_dump(exclude_none=True),
timeout=1000,
)
try:
return analyzer_response.validate_python(
response.raise_for_status().json()
)
except httpx.HTTPStatusError:
return None

return analyzer_response.validate_python(
response.raise_for_status().json()
)


class Replace(BaseModel):
Expand Down Expand Up @@ -155,9 +153,6 @@ async def anonymize(self, req: AnonymizerRequest) -> AnonymizerResponse:
json=req.model_dump(exclude_none=True),
timeout=1000,
)
try:
return AnonymizerResponse.model_validate(
response.raise_for_status().json()
)
except httpx.HTTPStatusError:
return None
return AnonymizerResponse.model_validate(
response.raise_for_status().json()
)
Loading

0 comments on commit a89a160

Please sign in to comment.