From d9e6a1d20c6d36a7f66a61221dbeddf52bccd9e0 Mon Sep 17 00:00:00 2001 From: Jon Bergland Date: Thu, 24 Oct 2024 19:15:29 +0200 Subject: [PATCH] refactor: optimize graphAgent --- core/graphAgent.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/core/graphAgent.py b/core/graphAgent.py index b0364f6..8be02ac 100644 --- a/core/graphAgent.py +++ b/core/graphAgent.py @@ -1,3 +1,4 @@ +from typing import Literal from langchain_openai import ChatOpenAI from graphstate import GraphState from tools.tools import get_tools @@ -14,6 +15,7 @@ import functools class Graph: + MAIN_AGENT = "chatbot" def __init__(self): LANGCHAIN_TRACING_V2: str = "true" @@ -22,17 +24,17 @@ def __init__(self): self.workflow = StateGraph(GraphState) - self.workflow.add_node("chatbot", self.chatbot) + self.workflow.add_node(self.MAIN_AGENT, self.chatbot) self.workflow.add_node("tools", ToolNode(get_tools())) - self.workflow.add_edge(START, "chatbot") - self.workflow.add_edge("tools", "chatbot") - self.workflow.add_edge("chatbot", END) + self.workflow.add_edge(START, self.MAIN_AGENT) + self.workflow.add_edge("tools", self.MAIN_AGENT) # Defining conditional edges self.workflow.add_conditional_edges( - "chatbot", - tools_condition + self.MAIN_AGENT, + tools_condition, + {"tools": "tools", "__end__": END} ) self.graph = self.workflow.compile() @@ -73,8 +75,10 @@ async def run(self, user_prompt: str, socketio): # There may be better events to base the response on if event_type == 'on_chain_stream' and event['name'] == 'LangGraph': chunk = event['data']['chunk'] - if 'chatbot' in chunk: - ai_message = event['data']['chunk']['chatbot']['messages'][-1] + + # Filters the stream to only get events by main agent + if self.MAIN_AGENT in chunk: + ai_message = event['data']['chunk'][self.MAIN_AGENT]['messages'][-1] if isinstance(ai_message, AIMessage): if 'tool_calls' in ai_message.additional_kwargs: