diff --git a/application/llm/google_ai.py b/application/llm/google_ai.py
index 24043d9cd..c14c89b99 100644
--- a/application/llm/google_ai.py
+++ b/application/llm/google_ai.py
@@ -1,86 +1,61 @@
-import google.generativeai as genai
+from google import genai
+from google.genai import types
 
 from application.core.settings import settings
 from application.llm.base import BaseLLM
 
 
 class GoogleLLM(BaseLLM):
-
     def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.api_key = settings.API_KEY
-        genai.configure(api_key=self.api_key)
+        self.api_key = api_key or settings.API_KEY
+        self.client = genai.Client(api_key=self.api_key)
 
     def _clean_messages_google(self, messages):
         cleaned_messages = []
-        for message in messages[1:]:
-            cleaned_messages.append(
-                {
-                    "role": "model" if message["role"] == "system" else message["role"],
-                    "parts": [message["content"]],
-                }
-            )
+        for message in messages:
+            role = message.get("role")
+            content = message.get("content")
+
+            if role and content is not None:
+                if isinstance(content, str):
+                    parts = [types.Part.from_text(text=content)]
+                elif isinstance(content, list):
+                    parts = content
+                else:
+                    raise ValueError(f"Unexpected content type: {type(content)}")
+
+                cleaned_messages.append(types.Content(role=role, parts=parts))
+
         return cleaned_messages
 
-    def _clean_tools_format(self, tools_data):
-        if isinstance(tools_data, list):
-            return [self._clean_tools_format(item) for item in tools_data]
-        elif isinstance(tools_data, dict):
-            if (
-                "function" in tools_data
-                and "type" in tools_data
-                and tools_data["type"] == "function"
-            ):
-                # Handle the case where tools are nested under 'function'
-                cleaned_function = self._clean_tools_format(tools_data["function"])
-                return {"function_declarations": [cleaned_function]}
-            elif (
-                "function" in tools_data
-                and "type_" in tools_data
-                and tools_data["type_"] == "function"
-            ):
-                # Handle the case where tools are nested under 'function' and type is already 'type_'
-                cleaned_function = self._clean_tools_format(tools_data["function"])
-                return {"function_declarations": [cleaned_function]}
-            else:
-                new_tools_data = {}
-                for key, value in tools_data.items():
-                    if key == "type":
-                        if value == "string":
-                            new_tools_data["type_"] = "STRING"
-                        elif value == "object":
-                            new_tools_data["type_"] = "OBJECT"
-                    elif key == "additionalProperties":
-                        continue
-                    elif key == "properties":
-                        if isinstance(value, dict):
-                            new_properties = {}
-                            for prop_name, prop_value in value.items():
-                                if (
-                                    isinstance(prop_value, dict)
-                                    and "type" in prop_value
-                                ):
-                                    if prop_value["type"] == "string":
-                                        new_properties[prop_name] = {
-                                            "type_": "STRING",
-                                            "description": prop_value.get(
-                                                "description"
-                                            ),
-                                        }
-                                    # Add more type mappings as needed
-                                else:
-                                    new_properties[prop_name] = (
-                                        self._clean_tools_format(prop_value)
-                                    )
-                            new_tools_data[key] = new_properties
-                        else:
-                            new_tools_data[key] = self._clean_tools_format(value)
+    def _clean_tools_format(self, tools_list):
+        genai_tools = []
+        for tool_data in tools_list:
+            if tool_data["type"] == "function":
+                function = tool_data["function"]
+                genai_function = dict(
+                    name=function["name"],
+                    description=function["description"],
+                    parameters={
+                        "type": "OBJECT",
+                        "properties": {
+                            k: {
+                                **v,
+                                "type": v["type"].upper() if v["type"] else None,
+                            }
+                            for k, v in function["parameters"]["properties"].items()
+                        },
+                        "required": (
+                            function["parameters"]["required"]
+                            if "required" in function["parameters"]
+                            else []
+                        ),
+                    },
+                )
+                genai_tool = types.Tool(function_declarations=[genai_function])
+                genai_tools.append(genai_tool)
-                    else:
-                        new_tools_data[key] = self._clean_tools_format(value)
-                return new_tools_data
-        else:
-            return tools_data
+        return genai_tools
 
     def _raw_gen(
         self,
@@ -90,61 +65,51 @@ def _raw_gen(
         stream=False,
         tools=None,
         formatting="openai",
-        **kwargs
+        **kwargs,
     ):
-        config = {}
-        model_name = "gemini-2.0-flash-exp"
+        client = self.client
+        config = types.GenerateContentConfig()
+        if formatting == "openai":
+            messages = self._clean_messages_google(messages)
 
-        if formatting == "raw":
-            client = genai.GenerativeModel(model_name=model_name)
-            response = client.generate_content(contents=messages)
-            return response.text
+            if tools:
+                cleaned_tools = self._clean_tools_format(tools)
+                config.tools = cleaned_tools
+            response = client.models.generate_content(
+                model=model,
+                contents=messages,
+                config=config,
+            )
+            return response
         else:
-            if tools:
-                client = genai.GenerativeModel(
-                    model_name=model_name,
-                    generation_config=config,
-                    system_instruction=messages[0]["content"],
-                    tools=self._clean_tools_format(tools),
-                )
-                chat_session = gen_model.start_chat(
-                    history=self._clean_messages_google(messages)[:-1]
-                )
-                response = chat_session.send_message(
-                    self._clean_messages_google(messages)[-1]
-                )
-                return response
-            else:
-                gen_model = genai.GenerativeModel(
-                    model_name=model_name,
-                    generation_config=config,
-                    system_instruction=messages[0]["content"],
-                )
-                chat_session = gen_model.start_chat(
-                    history=self._clean_messages_google(messages)[:-1]
-                )
-                response = chat_session.send_message(
-                    self._clean_messages_google(messages)[-1]
-                )
-                return response.text
+            response = client.models.generate_content(
+                model=model, contents=messages, config=config
+            )
+            return response.text
 
     def _raw_gen_stream(
-        self, baseself, model, messages, stream=True, tools=None, **kwargs
+        self,
+        baseself,
+        model,
+        messages,
+        stream=True,
+        tools=None,
+        formatting="openai",
+        **kwargs,
     ):
-        config = {}
-        model_name = "gemini-2.0-flash-exp"
+        client = self.client
+        config = types.GenerateContentConfig()
+        if formatting == "openai":
+            cleaned_messages = self._clean_messages_google(messages)
+        else:
+            cleaned_messages = messages
 
-        gen_model = genai.GenerativeModel(
-            model_name=model_name,
-            generation_config=config,
-            system_instruction=messages[0]["content"],
-            tools=self._clean_tools_format(tools),
-        )
-        chat_session = gen_model.start_chat(
-            history=self._clean_messages_google(messages)[:-1],
-        )
-        response = chat_session.send_message(
-            self._clean_messages_google(messages)[-1], stream=stream
+        if tools:
+            cleaned_tools = self._clean_tools_format(tools)
+            config.tools = cleaned_tools
+
+        response = client.models.generate_content_stream(
+            model=model,
+            contents=cleaned_messages,
+            config=config,
         )
         for chunk in response:
             if chunk.text is not None:
diff --git a/application/tools/agent.py b/application/tools/agent.py
index bbf6bcac3..209184d2c 100644
--- a/application/tools/agent.py
+++ b/application/tools/agent.py
@@ -95,7 +95,6 @@ def _simple_tool_agent(self, messages):
         resp = self.llm_handler.handle_response(self, resp, tools_dict, messages)
 
-        # If no tool calls are needed, generate the final response
         if isinstance(resp, str):
             yield resp
         elif hasattr(resp, "message") and hasattr(resp.message, "content"):
@@ -110,7 +109,6 @@ def _simple_tool_agent(self, messages):
         return
 
     def gen(self, messages):
-        # Generate initial response from the LLM
         if self.llm.supports_tools():
             resp = self._simple_tool_agent(messages)
             for line in resp:
diff --git a/application/tools/llm_handler.py b/application/tools/llm_handler.py
index 6be89ad79..2383d3f55 100644
--- a/application/tools/llm_handler.py
+++ b/application/tools/llm_handler.py
@@ -47,43 +47,41 @@ def handle_response(self, agent, resp, tools_dict, messages):
 
 class GoogleLLMHandler(LLMHandler):
     def handle_response(self, agent, resp, tools_dict, messages):
-        import google.generativeai as genai
+        from google.genai import types
 
-        while (
-            hasattr(resp.candidates[0].content.parts[0], "function_call")
-            and resp.candidates[0].content.parts[0].function_call
-        ):
-            responses = {}
-            for part in resp.candidates[0].content.parts:
-                if hasattr(part, "function_call") and part.function_call:
-                    function_call_part = part
-                    messages.append(
-                        genai.protos.Part(
-                            function_call=genai.protos.FunctionCall(
-                                name=function_call_part.function_call.name,
-                                args=function_call_part.function_call.args,
-                            )
-                        )
-                    )
-                    tool_response, call_id = agent._execute_tool_action(
-                        tools_dict, function_call_part.function_call
-                    )
-                    responses[function_call_part.function_call.name] = tool_response
-            response_parts = [
-                genai.protos.Part(
-                    function_response=genai.protos.FunctionResponse(
-                        name=tool_name, response={"result": response}
-                    )
-                )
-                for tool_name, response in responses.items()
-            ]
-            if response_parts:
-                messages.append(response_parts)
-            resp = agent.llm.gen(
+        while True:
+            response = agent.llm.gen(
                 model=agent.gpt_model, messages=messages, tools=agent.tools
             )
+            if response.candidates and response.candidates[0].content.parts:
+                tool_call_found = False
+                for part in response.candidates[0].content.parts:
+                    if part.function_call:
+                        tool_call_found = True
+                        tool_response, call_id = agent._execute_tool_action(
+                            tools_dict, part.function_call
+                        )
+
+                        function_response_part = types.Part.from_function_response(
+                            name=part.function_call.name,
+                            response={"result": tool_response},
+                        )
+                        messages.append({"role": "model", "content": [part]})
+                        messages.append(
+                            {"role": "tool", "content": [function_response_part]}
+                        )
+
+                if (
+                    not tool_call_found
+                    and response.candidates[0].content.parts
+                    and response.candidates[0].content.parts[0].text
+                ):
+                    return response.candidates[0].content.parts[0].text
+                elif not tool_call_found:
+                    return response.candidates[0].content.parts
 
-        return resp.text
+            else:
+                return response
 
 
 def get_llm_handler(llm_type):
diff --git a/application/tools/tool_action_parser.py b/application/tools/tool_action_parser.py
index 254c13b4c..ac0a70c16 100644
--- a/application/tools/tool_action_parser.py
+++ b/application/tools/tool_action_parser.py
@@ -1,7 +1,5 @@
 import json
 
-from google.protobuf.json_format import MessageToDict
-
 
 class ToolActionParser:
     def __init__(self, llm_type):
@@ -22,8 +20,7 @@ def _parse_openai_llm(self, call):
         return tool_id, action_name, call_args
 
     def _parse_google_llm(self, call):
-        call = MessageToDict(call._pb)
-        call_args = call["args"]
-        tool_id = call["name"].split("_")[-1]
-        action_name = call["name"].rsplit("_", 1)[0]
+        call_args = call.args
+        tool_id = call.name.split("_")[-1]
+        action_name = call.name.rsplit("_", 1)[0]
         return tool_id, action_name, call_args
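
Reviewer note: a minimal smoke-test sketch of the migrated class, assuming the BaseLLM gen/gen_stream wrappers route to _raw_gen/_raw_gen_stream and that settings.API_KEY holds a valid Gemini key. The model name, message, and tool definition are illustrative only, not part of this diff:

    from application.llm.google_ai import GoogleLLM

    llm = GoogleLLM()  # falls back to settings.API_KEY when no key is passed

    messages = [{"role": "user", "content": "What's the weather in Paris?"}]
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather_a1b2",  # hypothetical "<action>_<tool_id>" name
                "description": "Look up the current weather for a city.",
                "parameters": {
                    "properties": {
                        "city": {"type": "string", "description": "City name"}
                    },
                    "required": ["city"],
                },
            },
        }
    ]

    # With tools and the default formatting="openai", _raw_gen returns the raw
    # google-genai response so GoogleLLMHandler can inspect function_call parts.
    resp = llm.gen(model="gemini-2.0-flash", messages=messages, tools=tools)

    # The streaming path yields plain text chunks.
    for chunk in llm.gen_stream(model="gemini-2.0-flash", messages=messages):
        print(chunk, end="")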
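
The new handler loop depends on _clean_messages_google passing list-valued content through unchanged, so tool round trips can be appended as plain dicts. A hedged sketch of one iteration, with a stand-in tool result in place of a real _execute_tool_action return value:

    from google.genai import types

    # part is a response part whose part.function_call is set
    messages.append({"role": "model", "content": [part]})
    messages.append(
        {
            "role": "tool",
            "content": [
                types.Part.from_function_response(
                    name=part.function_call.name,
                    response={"result": "72F, sunny"},  # stand-in tool output
                )
            ],
        }
    )
    # The next agent.llm.gen(...) call then sees both the call and its result.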
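
The parser change reads call.args and call.name directly off the google-genai function call instead of round-tripping through protobuf, and still assumes call names are encoded as "<action_name>_<tool_id>". A tiny illustration with the hypothetical name from the sketch above:

    name = "get_weather_a1b2"
    tool_id = name.split("_")[-1]         # "a1b2"
    action_name = name.rsplit("_", 1)[0]  # "get_weather"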