Skip to content

Commit

Permalink
fix: GoogleLLM, agent and handler according to the new genai SDK
Browse files Browse the repository at this point in the history
  • Loading branch information
siiddhantt committed Jan 18, 2025
1 parent ec270a3 commit 904b0bf
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 158 deletions.
199 changes: 82 additions & 117 deletions application/llm/google_ai.py
Original file line number Diff line number Diff line change
@@ -1,86 +1,61 @@
import google.generativeai as genai
from google import genai
from google.genai import types

from application.core.settings import settings
from application.llm.base import BaseLLM


class GoogleLLM(BaseLLM):

def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
    """Initialize the Google genai client.

    Args:
        api_key: Explicit Google API key; falls back to ``settings.API_KEY``.
        user_api_key: Per-user key, stored for parity with other LLM backends.
    """
    super().__init__(*args, **kwargs)
    # SECURITY: never commit a literal API key to source control. Resolve the
    # credential from the explicit argument first, then application settings.
    self.api_key = api_key or settings.API_KEY
    self.user_api_key = user_api_key
    self.client = genai.Client(api_key=self.api_key)

def _clean_messages_google(self, messages):
    """Convert OpenAI-style message dicts into google-genai ``Content`` objects.

    Messages with a missing role or ``None`` content are skipped. Roles the
    genai API does not accept ("system"/"assistant") are normalized to
    "model", matching the pre-migration behavior of this method.

    Args:
        messages: List of ``{"role": ..., "content": ...}`` dicts.

    Returns:
        List of ``types.Content`` ready for ``generate_content``.

    Raises:
        ValueError: If a message's content is neither a string nor a list.
    """
    cleaned_messages = []
    for message in messages:
        role = message.get("role")
        content = message.get("content")
        if not role or content is None:
            continue
        # genai only understands the "user" and "model" roles; the previous
        # implementation mapped "system" to "model", so keep that contract.
        if role in ("system", "assistant"):
            role = "model"
        if isinstance(content, str):
            parts = [types.Part.from_text(content)]
        elif isinstance(content, list):
            # Assumed to already be a list of Part objects — TODO confirm
            # against callers that build multimodal content.
            parts = content
        else:
            raise ValueError(f"Unexpected content type: {type(content)}")
        cleaned_messages.append(types.Content(role=role, parts=parts))
    return cleaned_messages

def _clean_tools_format(self, tools_list):
    """Convert OpenAI-style tool specs into google-genai ``Tool`` objects.

    Only entries of type "function" are converted; any other entry kind is
    ignored rather than raising.

    Args:
        tools_list: List of OpenAI-format tool dicts.

    Returns:
        List of ``types.Tool`` objects with one function declaration each.
    """
    genai_tools = []
    for tool_data in tools_list:
        # Use .get so malformed/other-typed entries are skipped, not KeyError'd.
        if tool_data.get("type") != "function":
            continue
        function = tool_data["function"]
        parameters = function.get("parameters", {})
        # Upper-case each property's JSON-schema type for the genai API,
        # tolerating properties that omit "type" entirely.
        properties = {
            prop_name: {
                **prop_schema,
                "type": prop_schema["type"].upper()
                if prop_schema.get("type")
                else None,
            }
            for prop_name, prop_schema in parameters.get("properties", {}).items()
        }
        genai_function = dict(
            name=function["name"],
            description=function.get("description", ""),
            parameters={
                "type": "OBJECT",
                "properties": properties,
                "required": parameters.get("required", []),
            },
        )
        genai_tools.append(types.Tool(function_declarations=[genai_function]))
    return genai_tools

def _raw_gen(
self,
Expand All @@ -90,61 +65,51 @@ def _raw_gen(
stream=False,
tools=None,
formatting="openai",
**kwargs
**kwargs,
):
config = {}
model_name = "gemini-2.0-flash-exp"
client = self.client
if formatting == "openai":
messages = self._clean_messages_google(messages)
config = types.GenerateContentConfig()

Check warning on line 73 in application/llm/google_ai.py

View check run for this annotation

Codecov / codecov/patch

application/llm/google_ai.py#L70-L73

Added lines #L70 - L73 were not covered by tests

if formatting == "raw":
client = genai.GenerativeModel(model_name=model_name)
response = client.generate_content(contents=messages)
return response.text
if tools:
cleaned_tools = self._clean_tools_format(tools)
config.tools = cleaned_tools

Check warning on line 77 in application/llm/google_ai.py

View check run for this annotation

Codecov / codecov/patch

application/llm/google_ai.py#L75-L77

Added lines #L75 - L77 were not covered by tests
response = client.models.generate_content(
model=model,
contents=messages,
config=config,
)
return response

Check warning on line 83 in application/llm/google_ai.py

View check run for this annotation

Codecov / codecov/patch

application/llm/google_ai.py#L83

Added line #L83 was not covered by tests
else:
if tools:
client = genai.GenerativeModel(
model_name=model_name,
generation_config=config,
system_instruction=messages[0]["content"],
tools=self._clean_tools_format(tools),
)
chat_session = gen_model.start_chat(
history=self._clean_messages_google(messages)[:-1]
)
response = chat_session.send_message(
self._clean_messages_google(messages)[-1]
)
return response
else:
gen_model = genai.GenerativeModel(
model_name=model_name,
generation_config=config,
system_instruction=messages[0]["content"],
)
chat_session = gen_model.start_chat(
history=self._clean_messages_google(messages)[:-1]
)
response = chat_session.send_message(
self._clean_messages_google(messages)[-1]
)
return response.text
response = client.models.generate_content(

Check warning on line 85 in application/llm/google_ai.py

View check run for this annotation

Codecov / codecov/patch

application/llm/google_ai.py#L85

Added line #L85 was not covered by tests
model=model, contents=messages, config=config
)
return response.text

def _raw_gen_stream(
self, baseself, model, messages, stream=True, tools=None, **kwargs
self,
baseself,
model,
messages,
stream=True,
tools=None,
formatting="openai",
**kwargs,
):
config = {}
model_name = "gemini-2.0-flash-exp"
client = self.client
if formatting == "openai":
cleaned_messages = self._clean_messages_google(messages)
config = types.GenerateContentConfig()

Check warning on line 103 in application/llm/google_ai.py

View check run for this annotation

Codecov / codecov/patch

application/llm/google_ai.py#L100-L103

Added lines #L100 - L103 were not covered by tests

gen_model = genai.GenerativeModel(
model_name=model_name,
generation_config=config,
system_instruction=messages[0]["content"],
tools=self._clean_tools_format(tools),
)
chat_session = gen_model.start_chat(
history=self._clean_messages_google(messages)[:-1],
)
response = chat_session.send_message(
self._clean_messages_google(messages)[-1], stream=stream
if tools:
cleaned_tools = self._clean_tools_format(tools)
config.tools = cleaned_tools

Check warning on line 107 in application/llm/google_ai.py

View check run for this annotation

Codecov / codecov/patch

application/llm/google_ai.py#L105-L107

Added lines #L105 - L107 were not covered by tests

response = client.models.generate_content_stream(

Check warning on line 109 in application/llm/google_ai.py

View check run for this annotation

Codecov / codecov/patch

application/llm/google_ai.py#L109

Added line #L109 was not covered by tests
model=model,
contents=cleaned_messages,
config=config,
)
for chunk in response:
if chunk.text is not None:
Expand Down
2 changes: 0 additions & 2 deletions application/tools/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ def _simple_tool_agent(self, messages):

resp = self.llm_handler.handle_response(self, resp, tools_dict, messages)

Check warning on line 96 in application/tools/agent.py

View check run for this annotation

Codecov / codecov/patch

application/tools/agent.py#L96

Added line #L96 was not covered by tests

# If no tool calls are needed, generate the final response
if isinstance(resp, str):
yield resp
elif hasattr(resp, "message") and hasattr(resp.message, "content"):

Check warning on line 100 in application/tools/agent.py

View check run for this annotation

Codecov / codecov/patch

application/tools/agent.py#L100

Added line #L100 was not covered by tests
Expand All @@ -110,7 +109,6 @@ def _simple_tool_agent(self, messages):
return

def gen(self, messages):
# Generate initial response from the LLM
if self.llm.supports_tools():
resp = self._simple_tool_agent(messages)
for line in resp:
Expand Down
64 changes: 31 additions & 33 deletions application/tools/llm_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,43 +47,41 @@ def handle_response(self, agent, resp, tools_dict, messages):

class GoogleLLMHandler(LLMHandler):
    """Drives the tool-calling loop for responses from the Google genai SDK."""

    def handle_response(self, agent, resp, tools_dict, messages):
        """Resolve function calls in ``resp`` until the model produces text.

        Executes each requested tool, appends the model's call and the tool
        result to ``messages``, then re-queries the LLM; repeats until no
        further tool calls are requested.

        Args:
            agent: Agent providing ``llm``, ``gpt_model``, ``tools`` and
                ``_execute_tool_action``.
            resp: Initial genai response to process.
            tools_dict: Mapping used to dispatch tool calls.
            messages: Conversation history, mutated in place.

        Returns:
            The final text reply, the final content parts if they carry no
            text, or the raw response when it has no candidates/parts.
        """
        from google.genai import types

        # Start from the response we were handed instead of immediately
        # re-querying the model — avoids one redundant LLM round-trip.
        response = resp
        while True:
            if not (response.candidates and response.candidates[0].content.parts):
                return response

            parts = response.candidates[0].content.parts
            tool_call_found = False
            for part in parts:
                if part.function_call:
                    tool_call_found = True
                    tool_response, call_id = agent._execute_tool_action(
                        tools_dict, part.function_call
                    )
                    function_response_part = types.Part.from_function_response(
                        name=part.function_call.name,
                        response={"result": tool_response},
                    )
                    # Record both the model's call and the tool's answer so
                    # the follow-up generation sees the full exchange.
                    messages.append({"role": "model", "content": [part]})
                    messages.append(
                        {"role": "tool", "content": [function_response_part]}
                    )

            if not tool_call_found:
                if parts and parts[0].text:
                    return parts[0].text
                return parts

            # Tools were executed; ask the model to continue with the results.
            response = agent.llm.gen(
                model=agent.gpt_model, messages=messages, tools=agent.tools
            )

def get_llm_handler(llm_type):
Expand Down
9 changes: 3 additions & 6 deletions application/tools/tool_action_parser.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import json

from google.protobuf.json_format import MessageToDict


class ToolActionParser:
def __init__(self, llm_type):
Expand All @@ -22,8 +20,7 @@ def _parse_openai_llm(self, call):
return tool_id, action_name, call_args

Check warning on line 20 in application/tools/tool_action_parser.py

View check run for this annotation

Codecov / codecov/patch

application/tools/tool_action_parser.py#L17-L20

Added lines #L17 - L20 were not covered by tests

def _parse_google_llm(self, call):
call = MessageToDict(call._pb)
call_args = call["args"]
tool_id = call["name"].split("_")[-1]
action_name = call["name"].rsplit("_", 1)[0]
call_args = call.args
tool_id = call.name.split("_")[-1]
action_name = call.name.rsplit("_", 1)[0]
return tool_id, action_name, call_args

Check warning on line 26 in application/tools/tool_action_parser.py

View check run for this annotation

Codecov / codecov/patch

application/tools/tool_action_parser.py#L23-L26

Added lines #L23 - L26 were not covered by tests

0 comments on commit 904b0bf

Please sign in to comment.