Commit

Merge pull request #15 from CogitoNTNU/astream_events_experiment
Merge astream_events branch into main
WilliamMRS authored Oct 17, 2024
2 parents 60ee229 + 8eeb862 commit e2de081
Showing 2 changed files with 145 additions and 13 deletions.
119 changes: 119 additions & 0 deletions core/agent.py
@@ -0,0 +1,119 @@
from langchain_openai import ChatOpenAI
from graphstate import GraphState
from tools.tools import get_tools
from langgraph.graph import StateGraph, START, END
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_core.messages import BaseMessage, AIMessageChunk, HumanMessage
from models import Model
import json
import os
from config import OPENAI_API_KEY



class Agent1:
    llm = ChatOpenAI(
        model=Model.gpt_4o,
        temperature=0,
        max_tokens=512,
        # streaming=True,  # Disabled: usage metadata is not returned when streaming
    )


class Agent:
    def __init__(self, model_type) -> None:
        # LangSmith tracing
        os.environ["LANGCHAIN_TRACING_V2"] = "true"

        self.llm = ChatOpenAI(
            model=model_type,
            temperature=0,
            max_tokens=512,
        )

        self.llm_with_tools = self.llm.bind_tools(get_tools())

        self.workflow = StateGraph(GraphState)

        # Adding nodes to the workflow
        self.workflow.add_node("chatbot", self.chatbot)
        self.workflow.add_node("tools", ToolNode(get_tools()))
        # TODO: Visualize these tools

        # Defining edges between nodes
        self.workflow.add_edge(START, "chatbot")
        self.workflow.add_edge("tools", "chatbot")
        self.workflow.add_edge("chatbot", END)

        # Defining conditional edges
        self.workflow.add_conditional_edges(
            "chatbot",
            tools_condition
        )

        self.graph = self.workflow.compile()  # Compiles the workflow into a graph

        # Saving an image of the graph node network; comment in and out as needed
        # with open("core/graph_node_network.png", 'wb') as f:
        #     f.write(self.graph.get_graph().draw_mermaid_png())

    def chatbot(self, state: GraphState):
        """
        Simple bot that invokes the list of previous messages
        and returns the result which will be added to the list of messages.
        """
        # state_of_chatbot = self.llm_with_tools.invoke(state["messages"]).tool_calls
        # print("Tools called: " + state_of_chatbot["name"][-1].content)

        return {"messages": [self.llm_with_tools.invoke(state["messages"])]}


    # UNFINISHED
    def run_stream_only(self, user_prompt: str):
        """
        Run the agent, returning a token stream.
        """
        print('Running stream...')
        print(user_prompt)
        print(type(user_prompt))
        for chunk in self.llm.stream(user_prompt):
            yield chunk.content

    # For running the agent; comment out for testing in the terminal
    def run(self, user_prompt: str) -> tuple[str, int]:
        """
        Run the agent with a user prompt and return a tuple containing the llm
        response and the total amount of tokens used.
        """
        gathered = ""
        total_tokens = 0
        for event in self.graph.stream({"messages": [("user", user_prompt)]}):
            for value in event.values():
                message = value["messages"][-1]
                if isinstance(message, BaseMessage):
                    # usage_metadata is set on AIMessages returned by the model
                    if hasattr(message, 'usage_metadata') and message.usage_metadata:
                        total_tokens = message.usage_metadata.get('total_tokens', 0)
                    gathered += message.content
                else:
                    print(f"Warning: Message of type {type(message)} does not have usage_metadata")

        return gathered, total_tokens

    # For testing in the terminal
    """ def run(self, user_prompt: str):
            for event in self.graph.stream({"messages": [("user", user_prompt)]}):
                for value in event.values():
                    if isinstance(value["messages"][-1], BaseMessage):
                        print("Assistant:", value["messages"][-1].content)

    if __name__ == "__main__":
        agent = Agent("gpt-4o-mini")
        while True:
            user_prompt = input("User: ")
            agent.run(user_prompt) """

# Counting tokens: https://python.langchain.com/docs/how_to/llm_token_usage_tracking/
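
For reference on the link above, a minimal sketch (not part of this commit) of where those token counts come from, assuming langchain-openai is installed and OPENAI_API_KEY is set in the environment:

# Hypothetical standalone example: AIMessage.usage_metadata carries the
# token counts that run() aggregates above.
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
response = llm.invoke("tell me about orcas")

# usage_metadata is a dict with input_tokens, output_tokens and total_tokens.
if response.usage_metadata:
    print(response.usage_metadata.get("total_tokens", 0))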
39 changes: 26 additions & 13 deletions core/graphAgent.py
@@ -3,7 +3,7 @@
 from tools.tools import get_tools
 from langgraph.graph import StateGraph, START, END
 from langgraph.prebuilt import ToolNode, tools_condition
-from langchain_core.messages import BaseMessage, AIMessageChunk, HumanMessage
+from langchain_core.messages import BaseMessage, AIMessageChunk, HumanMessage, AIMessage
 from models import Model
 import json
 from config import OPENAI_API_KEY
@@ -29,6 +29,7 @@ def __init__(self):
         # Defining edges between nodes
         self.workflow.add_edge(START, "chatbot")
         self.workflow.add_edge("tools", "chatbot")
+        self.workflow.add_edge("chatbot", END)

         # Defining conditional edges
         self.workflow.add_conditional_edges(
@@ -60,24 +61,36 @@ def run_stream_only(self, user_prompt: str):
             yield chunk.content

     # For running the agent; comment out for testing in the terminal
-    async def run(self, user_prompt: str, socketio) -> tuple[str, int]:
+    async def run(self, user_prompt: str, socketio):
         """
-        Run the agent with a user prompt and return a tuple containing the llm
-        response and the total amount of tokens used.
+        Run the agent with a user prompt and emit the response and total tokens via socket.
         """
         try:
             input = {"messages": [("human", user_prompt)]}
             socketio.emit("start_message", " ")
-            async for chunk in self.graph.astream(input, stream_mode="values"):
-                if type(chunk["messages"][-1]) == HumanMessage:
+            async for event in self.graph.astream_events(input, version='v2'):
+                event_type = event.get('event')
+
+                # Pass over start events
+                if event_type == 'on_chain_start':
+                    print("This event is on_chain_start")
                     continue
-                event_message = chunk["messages"][-1].content
-                event_message = event_message.split(" ")
-                for word in event_message:
-                    sleep(0.05)
-                    socketio.emit("chunk", word+" ")
-                socketio.emit("chunk", "<br>")
-                socketio.emit("tokens", 0) # a way to emit ending of the message
+
+                # Emit the AI response once the chain ends
+                # TODO: stream chunks as they arrive rather than the full AIMessage
+                if event_type == 'on_chain_end':
+                    print(event['data'])
+                    for message in event['data']['output']['messages']:
+                        if isinstance(message, AIMessage):
+                            data = message.content
+                            socketio.emit("chunk", data)
+
+                            if hasattr(message, 'usage_metadata'):
+                                usage_metadata = message.usage_metadata
+                                if usage_metadata:
+                                    total_tokens = usage_metadata.get('total_tokens')
+                                    socketio.emit("tokens", total_tokens)

             return "success"
         except Exception as e:
             print(e)
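
The TODO above is still open: the response is emitted only at on_chain_end, as one full AIMessage. A minimal sketch of what chunk-level streaming could look like with the same astream_events v2 API (not part of this commit; on_chat_model_stream and its payload shape follow the LangChain docs, the rest is assumed):

# Hypothetical variant of run() that forwards chunks as they arrive.
async def run_streaming(self, user_prompt: str, socketio):
    input = {"messages": [("human", user_prompt)]}
    socketio.emit("start_message", " ")
    async for event in self.graph.astream_events(input, version='v2'):
        # on_chat_model_stream events carry an AIMessageChunk in event['data']['chunk']
        if event.get('event') == 'on_chat_model_stream':
            chunk = event['data']['chunk']
            if chunk.content:
                socketio.emit("chunk", chunk.content)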
