diff --git a/llama-index-core/llama_index/core/agent/runner/base.py b/llama-index-core/llama_index/core/agent/runner/base.py
index 47061a3953b43..4f1ad21426f13 100644
--- a/llama-index-core/llama_index/core/agent/runner/base.py
+++ b/llama-index-core/llama_index/core/agent/runner/base.py
@@ -141,6 +141,14 @@ def finalize_response(
     ) -> AGENT_CHAT_RESPONSE_TYPE:
         """Finalize response."""

+    async def afinalize_response(
+        self,
+        task_id: str,
+        step_output: Optional[TaskStepOutput] = None,
+    ) -> AGENT_CHAT_RESPONSE_TYPE:
+        """Finalize response."""
+        return self.finalize_response(task_id, step_output)
+
     @abstractmethod
     def undo_step(self, task_id: str) -> None:
         """Undo previous step."""
@@ -559,6 +567,43 @@ def finalize_response(

         return cast(AGENT_CHAT_RESPONSE_TYPE, step_output.output)

+    @dispatcher.span
+    async def afinalize_response(
+        self,
+        task_id: str,
+        step_output: Optional[TaskStepOutput] = None,
+    ) -> AGENT_CHAT_RESPONSE_TYPE:
+        """Finalize response."""
+        if step_output is None:
+            step_output = self.state.get_completed_steps(task_id)[-1]
+        if not step_output.is_last:
+            raise ValueError(
+                "finalize_response can only be called on the last step output"
+            )
+
+        if not isinstance(
+            step_output.output,
+            (AgentChatResponse, StreamingAgentChatResponse),
+        ):
+            raise ValueError(
+                "When `is_last` is True, cur_step_output.output must be "
+                f"AGENT_CHAT_RESPONSE_TYPE: {step_output.output}"
+            )
+
+        # finalize task
+        await self.agent_worker.afinalize_task(self.state.get_task(task_id))
+
+        if self.delete_task_on_finish:
+            self.delete_task(task_id)
+
+        # Attach all sources generated across all steps
+        step_output.output.sources = self.get_task(task_id).extra_state.get(
+            "sources", []
+        )
+        step_output.output.set_source_nodes()
+
+        return cast(AGENT_CHAT_RESPONSE_TYPE, step_output.output)
+
     @dispatcher.span
     def _chat(
         self,
@@ -622,7 +667,7 @@ async def _achat(
             # ensure tool_choice does not cause endless loops
             tool_choice = "auto"

-        result = self.finalize_response(
+        result = await self.afinalize_response(
             task.task_id,
             result_output,
         )
diff --git a/llama-index-core/llama_index/core/base/agent/types.py b/llama-index-core/llama_index/core/base/agent/types.py
index 69ed99b0e505c..ab314a59ef285 100644
--- a/llama-index-core/llama_index/core/base/agent/types.py
+++ b/llama-index-core/llama_index/core/base/agent/types.py
@@ -237,6 +237,10 @@ async def astream_step(
     def finalize_task(self, task: Task, **kwargs: Any) -> None:
         """Finalize task, after all the steps are completed."""

+    async def afinalize_task(self, task: Task, **kwargs: Any) -> None:
+        """Finalize task, after all the steps are completed."""
+        self.finalize_task(task, **kwargs)
+
     def set_callback_manager(self, callback_manager: CallbackManager) -> None:
         """Set callback manager."""
         # TODO: make this abstractmethod (right now will break some agent impls)
diff --git a/llama-index-core/llama_index/core/chat_engine/types.py b/llama-index-core/llama_index/core/chat_engine/types.py
index 24ccdb7bba473..a100a697ee08c 100644
--- a/llama-index-core/llama_index/core/chat_engine/types.py
+++ b/llama-index-core/llama_index/core/chat_engine/types.py
@@ -4,6 +4,8 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from enum import Enum
+from functools import partial
+from inspect import iscoroutinefunction
 from queue import Queue, Empty
 from threading import Event
 from typing import AsyncGenerator, Callable, Generator, List, Optional, Union, Dict, Any
@@ -265,7 +267,14 @@ async def awrite_response_to_history(
         self.is_function_false_event.set()
         self.new_item_event.set()
         if on_stream_end_fn is not None and not self.is_function:
-            on_stream_end_fn()
+            if iscoroutinefunction(
+                on_stream_end_fn.func
+                if isinstance(on_stream_end_fn, partial)
+                else on_stream_end_fn
+            ):
+                await on_stream_end_fn()
+            else:
+                on_stream_end_fn()

     @property
     def response_gen(self) -> Generator[str, None, None]:
diff --git a/llama-index-core/llama_index/core/memory/types.py b/llama-index-core/llama_index/core/memory/types.py
index f78688dab6353..3b47734962d42 100644
--- a/llama-index-core/llama_index/core/memory/types.py
+++ b/llama-index-core/llama_index/core/memory/types.py
@@ -50,6 +50,11 @@ def put_messages(self, messages: List[ChatMessage]) -> None:
         for message in messages:
             self.put(message)

+    async def aput_messages(self, messages: List[ChatMessage]) -> None:
+        """Put chat history."""
+        for message in messages:
+            await self.aput(message)
+
     @abstractmethod
     def set(self, messages: List[ChatMessage]) -> None:
         """Set chat history."""
diff --git a/llama-index-integrations/agent/llama-index-agent-openai/llama_index/agent/openai/step.py b/llama-index-integrations/agent/llama-index-agent-openai/llama_index/agent/openai/step.py
index d772782b07382..cc13e9c3fdd76 100644
--- a/llama-index-integrations/agent/llama-index-agent-openai/llama_index/agent/openai/step.py
+++ b/llama-index-integrations/agent/llama-index-agent-openai/llama_index/agent/openai/step.py
@@ -280,7 +280,7 @@ async def _get_async_stream_ai_response(
         asyncio.create_task(
             chat_stream_response.awrite_response_to_history(
                 task.extra_state["new_memory"],
-                on_stream_end_fn=partial(self.finalize_task, task),
+                on_stream_end_fn=partial(self.afinalize_task, task),
             )
         )
         chat_stream_response._ensure_async_setup()
@@ -789,6 +789,13 @@ def finalize_task(self, task: Task, **kwargs: Any) -> None:
         # reset new memory
         task.extra_state["new_memory"].reset()

+    async def afinalize_task(self, task: Task, **kwargs: Any) -> None:
+        """Finalize task, after all the steps are completed."""
+        # add new messages to memory
+        await task.memory.aput_messages(task.extra_state["new_memory"].get_all())
+        # reset new memory
+        task.extra_state["new_memory"].reset()
+
     def undo_step(self, task: Task, **kwargs: Any) -> Optional[TaskStep]:
         """Undo step from task.
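
A minimal usage sketch of the async path this patch wires up (not part of the diff; the tool, the model name, and an OPENAI_API_KEY in the environment are illustrative assumptions): achat drives _achat, which now awaits afinalize_response; that awaits the worker's afinalize_task, and the OpenAI worker persists new messages through the non-blocking aput_messages. On the streaming side, awrite_response_to_history can now await on_stream_end_fn, since it unwraps a functools.partial before the iscoroutinefunction check.

import asyncio

from llama_index.agent.openai import OpenAIAgent
from llama_index.core.tools import FunctionTool
from llama_index.llms.openai import OpenAI


def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b


async def main() -> None:
    agent = OpenAIAgent.from_tools(
        [FunctionTool.from_defaults(fn=multiply)],
        llm=OpenAI(model="gpt-4o-mini"),  # illustrative model choice
    )

    # Non-streaming: _achat awaits afinalize_response, which awaits
    # agent_worker.afinalize_task instead of the synchronous finalize_task.
    response = await agent.achat("What is 6 times 7?")
    print(response)

    # Streaming: awrite_response_to_history invokes on_stream_end_fn
    # (a partial over the worker's afinalize_task) and awaits it because
    # the callback's underlying .func is a coroutine function.
    stream = await agent.astream_chat("And 7 times 8?")
    async for token in stream.async_response_gen():
        print(token, end="")


asyncio.run(main())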