From f4955499494082d0073a5982664ac6bfaae9a3a4 Mon Sep 17 00:00:00 2001 From: Marco Perini Date: Wed, 2 Oct 2024 12:33:00 +0200 Subject: [PATCH] feat: added return types --- brickllm/graphs/brickschema_graph.py | 18 ++++++++++++++---- examples/brickschema_ttl.py | 12 +++++++----- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/brickllm/graphs/brickschema_graph.py b/brickllm/graphs/brickschema_graph.py index a9f9b92..658d07b 100644 --- a/brickllm/graphs/brickschema_graph.py +++ b/brickllm/graphs/brickschema_graph.py @@ -1,4 +1,3 @@ -import json from langgraph.graph import START, END, StateGraph from .. import State, GraphConfig from ..nodes import ( @@ -51,7 +50,7 @@ def _compiled_graph(self): raise ValueError("Graph is not compiled yet. Please compile the graph first.") return self.graph - def display(self, filename="graph.png"): + def display(self, filename="graph.png") -> None: """Display the compiled graph as an image. Args: @@ -80,10 +79,21 @@ def run(self, prompt, stream=False): input_data = {"user_prompt": prompt} if stream: + events = [] # Stream the content of the graph state at each node for event in self.graph.stream(input_data, self.config, stream_mode="values"): - print(json.dumps(event, indent=2)) + events.append(event) + return events else: # Invoke the graph without streaming result = self.graph.invoke(input_data, self.config) - print(json.dumps(result, indent=2)) + return result + + def get_state_snapshots(self) -> list: + """Get all the state snapshots from the graph execution.""" + + all_states = [] + for state in self.graph.get_state_history(self.config): + all_states.append(state) + + return all_states \ No newline at end of file diff --git a/examples/brickschema_ttl.py b/examples/brickschema_ttl.py index e0948be..0bc9fee 100644 --- a/examples/brickschema_ttl.py +++ b/examples/brickschema_ttl.py @@ -1,8 +1,5 @@ from brickllm.graphs import BrickSchemaGraph -# Create an instance of BrickSchemaGraph -brick_graph = BrickSchemaGraph() - # Specify the user prompt building_description = """ I have a building located in Bolzano. @@ -13,11 +10,16 @@ - CO sensor. """ +# Create an instance of BrickSchemaGraph +brick_graph = BrickSchemaGraph() + # Display the graph brick_graph.display() # Run the graph without streaming -brick_graph.run( +result = brick_graph.run( prompt=building_description, stream=False -) \ No newline at end of file +) + +print(result) \ No newline at end of file