Async memory management for OpenAIAgentWorker #17375

Merged · 15 commits · Jan 10, 2025
47 changes: 46 additions & 1 deletion llama-index-core/llama_index/core/agent/runner/base.py
@@ -141,6 +141,14 @@ def finalize_response(
     ) -> AGENT_CHAT_RESPONSE_TYPE:
         """Finalize response."""
 
+    async def afinalize_response(
+        self,
+        task_id: str,
+        step_output: Optional[TaskStepOutput] = None,
+    ) -> AGENT_CHAT_RESPONSE_TYPE:
+        """Finalize response."""
+        return self.finalize_response(task_id, step_output)
+
     @abstractmethod
     def undo_step(self, task_id: str) -> None:
         """Undo previous step."""
@@ -559,6 +567,43 @@ def finalize_response(
 
         return cast(AGENT_CHAT_RESPONSE_TYPE, step_output.output)
 
+    @dispatcher.span
+    async def afinalize_response(
+        self,
+        task_id: str,
+        step_output: Optional[TaskStepOutput] = None,
+    ) -> AGENT_CHAT_RESPONSE_TYPE:
+        """Finalize response."""
+        if step_output is None:
+            step_output = self.state.get_completed_steps(task_id)[-1]
+        if not step_output.is_last:
+            raise ValueError(
+                "finalize_response can only be called on the last step output"
+            )
+
+        if not isinstance(
+            step_output.output,
+            (AgentChatResponse, StreamingAgentChatResponse),
+        ):
+            raise ValueError(
+                "When `is_last` is True, cur_step_output.output must be "
+                f"AGENT_CHAT_RESPONSE_TYPE: {step_output.output}"
+            )
+
+        # finalize task
+        await self.agent_worker.afinalize_task(self.state.get_task(task_id))
+
+        if self.delete_task_on_finish:
+            self.delete_task(task_id)
+
+        # Attach all sources generated across all steps
+        step_output.output.sources = self.get_task(task_id).extra_state.get(
+            "sources", []
+        )
+        step_output.output.set_source_nodes()
+
+        return cast(AGENT_CHAT_RESPONSE_TYPE, step_output.output)
+
     @dispatcher.span
     def _chat(
         self,
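
The new afinalize_response mirrors the sync version but awaits the worker's afinalize_task. A hedged sketch of driving a task to completion asynchronously, assuming the runner's existing create_task/arun_step task API; afinalize_response accepts the final step output or None, in which case it looks up the last completed step itself:

    async def run_to_completion(agent, message: str):
        # Hypothetical driver over the AgentRunner task API shown above.
        task = agent.create_task(message)
        step_output = await agent.arun_step(task.task_id)
        while not step_output.is_last:
            step_output = await agent.arun_step(task.task_id)
        # Raises ValueError if given a non-final step output.
        return await agent.afinalize_response(task.task_id, step_output)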
@@ -622,7 +667,7 @@ async def _achat(
             # ensure tool_choice does not cause endless loops
             tool_choice = "auto"
 
-        result = self.finalize_response(
+        result = await self.afinalize_response(
             task.task_id,
             result_output,
         )
4 changes: 4 additions & 0 deletions llama-index-core/llama_index/core/base/agent/types.py
@@ -237,6 +237,10 @@ async def astream_step(
     def finalize_task(self, task: Task, **kwargs: Any) -> None:
         """Finalize task, after all the steps are completed."""
 
+    async def afinalize_task(self, task: Task, **kwargs: Any) -> None:
+        """Finalize task, after all the steps are completed."""
+        self.finalize_task(task, **kwargs)
+
     def set_callback_manager(self, callback_manager: CallbackManager) -> None:
         """Set callback manager."""
         # TODO: make this abstractmethod (right now will break some agent impls)
11 changes: 10 additions & 1 deletion llama-index-core/llama_index/core/chat_engine/types.py
@@ -4,6 +4,8 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from enum import Enum
+from functools import partial
+from inspect import iscoroutinefunction
 from queue import Queue, Empty
 from threading import Event
 from typing import AsyncGenerator, Callable, Generator, List, Optional, Union, Dict, Any
@@ -265,7 +267,14 @@ async def awrite_response_to_history(
                 self.is_function_false_event.set()
             self.new_item_event.set()
         if on_stream_end_fn is not None and not self.is_function:
-            on_stream_end_fn()
+            if iscoroutinefunction(
+                on_stream_end_fn.func
+                if isinstance(on_stream_end_fn, partial)
+                else on_stream_end_fn
+            ):
+                await on_stream_end_fn()
+            else:
+                on_stream_end_fn()
 
     @property
     def response_gen(self) -> Generator[str, None, None]:
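
The on_stream_end_fn callback can now be either sync or async, and it typically arrives wrapped in functools.partial (as the worker-side change later in this diff does), so the coroutine check targets the underlying .func. A standalone sketch of that check (recent CPython's inspect.iscoroutinefunction unwraps partials on its own, but the explicit check is robust across versions):

    import asyncio
    from functools import partial
    from inspect import iscoroutinefunction

    async def finalize(task_id: str) -> None:
        await asyncio.sleep(0)

    def finalize_sync(task_id: str) -> None:
        pass

    def is_async_callback(cb) -> bool:
        # Mirror the check above: inspect the wrapped function when the
        # callback was built with functools.partial.
        return iscoroutinefunction(cb.func if isinstance(cb, partial) else cb)

    print(is_async_callback(partial(finalize, "t1")))       # True  -> await it
    print(is_async_callback(partial(finalize_sync, "t1")))  # False -> call directly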
5 changes: 5 additions & 0 deletions llama-index-core/llama_index/core/memory/types.py
@@ -50,6 +50,11 @@ def put_messages(self, messages: List[ChatMessage]) -> None:
         for message in messages:
             self.put(message)
 
+    async def aput_messages(self, messages: List[ChatMessage]) -> None:
+        """Put chat history."""
+        for message in messages:
+            await self.aput(message)
+
     @abstractmethod
     def set(self, messages: List[ChatMessage]) -> None:
         """Set chat history."""
@@ -280,7 +280,7 @@ async def _get_async_stream_ai_response(
         asyncio.create_task(
             chat_stream_response.awrite_response_to_history(
                 task.extra_state["new_memory"],
-                on_stream_end_fn=partial(self.finalize_task, task),
+                on_stream_end_fn=partial(self.afinalize_task, task),
             )
         )
         chat_stream_response._ensure_async_setup()
@@ -789,6 +789,13 @@ def finalize_task(self, task: Task, **kwargs: Any) -> None:
         # reset new memory
         task.extra_state["new_memory"].reset()
 
+    async def afinalize_task(self, task: Task, **kwargs: Any) -> None:
+        """Finalize task, after all the steps are completed."""
+        # add new messages to memory
+        await task.memory.aput_messages(task.extra_state["new_memory"].get_all())
+        # reset new memory
+        task.extra_state["new_memory"].reset()
+
     def undo_step(self, task: Task, **kwargs: Any) -> Optional[TaskStep]:
         """Undo step from task.