diff --git a/python/packages/autogen-core/src/autogen_core/store/__init__.py b/python/packages/autogen-core/src/autogen_core/store/__init__.py
new file mode 100644
index 000000000000..83986fabe13e
--- /dev/null
+++ b/python/packages/autogen-core/src/autogen_core/store/__init__.py
@@ -0,0 +1,3 @@
+from .abstract_store_base import AbstractStore
+
+__all__ = ["AbstractStore"]
diff --git a/python/packages/autogen-core/src/autogen_core/store/abstract_store_base.py b/python/packages/autogen-core/src/autogen_core/store/abstract_store_base.py
new file mode 100644
index 000000000000..29b2b06b71d6
--- /dev/null
+++ b/python/packages/autogen-core/src/autogen_core/store/abstract_store_base.py
@@ -0,0 +1,34 @@
+from typing import Any, Optional, Protocol
+
+
+class AbstractStore(Protocol):
+    """
+    This protocol defines the basic interface for store/cache operations.
+
+    Allows duck-typing with any object that implements the get and set methods,
+    such as redis or diskcache interfaces.
+    """
+
+    def get(self, key: Any, default: Optional[Any] = None) -> Optional[Any]:
+        """
+        Retrieve an item from the store.
+
+        Args:
+            key: The key identifying the item in the store.
+            default (optional): The default value to return if the key is not found.
+                Defaults to None.
+
+        Returns:
+            The value associated with the key if found, else the default value.
+        """
+        ...
+
+    def set(self, key: Any, value: Any) -> None:
+        """
+        Set an item in the store.
+
+        Args:
+            key: The key under which the item is to be stored.
+            value: The value to be stored in the store.
+        """
+        ...
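Because AbstractStore is a Protocol, any object that already exposes compatible get and set methods can be used directly, with no subclassing required. A minimal duck-typing sketch, assuming the third-party diskcache package (which is not a dependency added by this change):

# Sketch only: diskcache and the cache directory below are assumptions for illustration.
from diskcache import Cache

store = Cache("/tmp/autogen_cache_demo")  # hypothetical on-disk cache directory
store.set("greeting", "hello")                         # matches AbstractStore.set(key, value)
assert store.get("greeting") == "hello"                # matches AbstractStore.get(key, default=None)
assert store.get("missing", "fallback") == "fallback"  # the default is returned on a miss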
diff --git a/python/packages/autogen-core/tests/test_abstract_store.py b/python/packages/autogen-core/tests/test_abstract_store.py
new file mode 100644
index 000000000000..8e1c0786ff3c
--- /dev/null
+++ b/python/packages/autogen-core/tests/test_abstract_store.py
@@ -0,0 +1,32 @@
+from unittest.mock import Mock
+
+from autogen_core.store import AbstractStore
+
+
+def test_set_and_get_object_key_value() -> None:
+    mock_store = Mock(spec=AbstractStore)
+    test_key = object()
+    test_value = object()
+    mock_store.set(test_key, test_value)
+    mock_store.get.return_value = test_value
+    mock_store.set.assert_called_with(test_key, test_value)
+    assert mock_store.get(test_key) == test_value
+
+
+def test_get_non_existent_key() -> None:
+    mock_store = Mock(spec=AbstractStore)
+    key = "non_existent_key"
+    mock_store.get.return_value = None
+    assert mock_store.get(key) is None
+
+
+def test_set_overwrite_existing_key() -> None:
+    mock_store = Mock(spec=AbstractStore)
+    key = "test_key"
+    initial_value = "initial_value"
+    new_value = "new_value"
+    mock_store.set(key, initial_value)
+    mock_store.set(key, new_value)
+    mock_store.get.return_value = new_value
+    mock_store.set.assert_called_with(key, new_value)
+    assert mock_store.get(key) == new_value
diff --git a/python/packages/autogen-ext/src/autogen_ext/models/cache.py b/python/packages/autogen-ext/src/autogen_ext/models/cache.py
new file mode 100644
index 000000000000..ba683c1f6763
--- /dev/null
+++ b/python/packages/autogen-ext/src/autogen_ext/models/cache.py
@@ -0,0 +1,186 @@
+import hashlib
+import json
+import warnings
+from typing import Any, AsyncGenerator, List, Mapping, Optional, Sequence, Union, cast
+
+from autogen_core import CancellationToken
+from autogen_core.models import (
+    ChatCompletionClient,
+    CreateResult,
+    LLMMessage,
+    ModelCapabilities,  # type: ignore
+    ModelInfo,
+    RequestUsage,
+)
+from autogen_core.store import AbstractStore
+from autogen_core.tools import Tool, ToolSchema
+
+
+class ChatCompletionCache(ChatCompletionClient):
+    """
+    A wrapper around a ChatCompletionClient that caches creation results from an underlying client.
+    Cache hits do not contribute to token usage of the original client.
+    """
+
+    def __init__(self, client: ChatCompletionClient, store: AbstractStore):
+        """
+        Initialize a new ChatCompletionCache.
+
+        Args:
+            client (ChatCompletionClient): The original ChatCompletionClient to wrap.
+            store (AbstractStore): A store object that implements get and set methods.
+                The user is responsible for managing the store's lifecycle & clearing it (if needed).
+        """
+        self.client = client
+        self.store = store
+
+    def _check_cache(
+        self,
+        messages: Sequence[LLMMessage],
+        tools: Sequence[Tool | ToolSchema],
+        json_output: Optional[bool],
+        extra_create_args: Mapping[str, Any],
+        force_cache: bool,
+        force_client: bool,
+    ) -> tuple[Optional[Union[CreateResult, List[Union[str, CreateResult]]]], str]:
+        """
+        Helper function to check the cache for a result.
+        Returns a tuple of (cached_result, cache_key).
+        cached_result is None if the cache is empty or force_client is True.
+        Raises an error if there is a cache miss and force_cache is True.
+        """
+        if force_client and force_cache:
+            raise ValueError("force_cache and force_client cannot both be True")
+
+        data = {
+            "messages": [message.model_dump() for message in messages],
+            "tools": [(tool.schema if isinstance(tool, Tool) else tool) for tool in tools],
+            "json_output": json_output,
+            "extra_create_args": extra_create_args,
+        }
+        serialized_data = json.dumps(data, sort_keys=True)
+        cache_key = hashlib.sha256(serialized_data.encode()).hexdigest()
+
+        if not force_client:
+            cached_result = cast(Optional[CreateResult], self.store.get(cache_key))
+            if cached_result is not None:
+                return cached_result, cache_key
+            elif force_cache:
+                raise ValueError("Encountered cache miss for force_cache request")
+
+        return None, cache_key
+
+    async def create(
+        self,
+        messages: Sequence[LLMMessage],
+        *,
+        tools: Sequence[Tool | ToolSchema] = [],
+        json_output: Optional[bool] = None,
+        extra_create_args: Mapping[str, Any] = {},
+        cancellation_token: Optional[CancellationToken] = None,
+        force_cache: bool = False,
+        force_client: bool = False,
+    ) -> CreateResult:
+        """
+        Cached version of ChatCompletionClient.create.
+        If the result of a call to create has been cached, it will be returned immediately
+        without invoking the underlying client.
+
+        NOTE: cancellation_token is ignored for cached results.
+
+        Additional parameters:
+        - force_cache: If True, the cache will be used and an error will be raised if a result is unavailable.
+        - force_client: If True, the cache will be bypassed and the underlying client will be called.
+        """
+        cached_result, cache_key = self._check_cache(
+            messages, tools, json_output, extra_create_args, force_cache, force_client
+        )
+        if cached_result:
+            assert isinstance(cached_result, CreateResult)
+            cached_result.cached = True
+            return cached_result
+
+        result = await self.client.create(
+            messages,
+            tools=tools,
+            json_output=json_output,
+            extra_create_args=extra_create_args,
+            cancellation_token=cancellation_token,
+        )
+        self.store.set(cache_key, result)
+        return result
+
+    def create_stream(
+        self,
+        messages: Sequence[LLMMessage],
+        *,
+        tools: Sequence[Tool | ToolSchema] = [],
+        json_output: Optional[bool] = None,
+        extra_create_args: Mapping[str, Any] = {},
+        cancellation_token: Optional[CancellationToken] = None,
+        force_cache: bool = False,
+        force_client: bool = False,
+    ) -> AsyncGenerator[Union[str, CreateResult], None]:
+        """
+        Cached version of ChatCompletionClient.create_stream.
+        If the result of a call to create_stream has been cached, it will be returned
+        without streaming from the underlying client.
+
+        NOTE: cancellation_token is ignored for cached results.
+
+        Additional parameters:
+        - force_cache: If True, the cache will be used and an error will be raised if a result is unavailable.
+        - force_client: If True, the cache will be bypassed and the underlying client will be called.
+        """
+
+        if force_client and force_cache:
+            raise ValueError("force_cache and force_client cannot both be True")
+
+        async def _generator() -> AsyncGenerator[Union[str, CreateResult], None]:
+            cached_result, cache_key = self._check_cache(
+                messages, tools, json_output, extra_create_args, force_cache, force_client
+            )
+            if cached_result:
+                assert isinstance(cached_result, list)
+                for result in cached_result:
+                    if isinstance(result, CreateResult):
+                        result.cached = True
+                    yield result
+                return
+
+            result_stream = self.client.create_stream(
+                messages,
+                tools=tools,
+                json_output=json_output,
+                extra_create_args=extra_create_args,
+            )
+
+            output_results: List[Union[str, CreateResult]] = []
+            self.store.set(cache_key, output_results)
+
+            async for result in result_stream:
+                output_results.append(result)
+                yield result
+
+        return _generator()
+
+    def actual_usage(self) -> RequestUsage:
+        return self.client.actual_usage()
+
+    def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
+        return self.client.count_tokens(messages, tools=tools)
+
+    @property
+    def capabilities(self) -> ModelCapabilities:  # type: ignore
+        warnings.warn("capabilities is deprecated, use model_info instead", DeprecationWarning, stacklevel=2)
+        return self.client.capabilities
+
+    @property
+    def model_info(self) -> ModelInfo:
+        return self.client.model_info
+
+    def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
+        return self.client.remaining_tokens(messages, tools=tools)
+
+    def total_usage(self) -> RequestUsage:
+        return self.client.total_usage()
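A usage sketch for the wrapper (illustrative only; the dict-backed store, prompt, and responses below are stand-ins, with ReplayChatCompletionClient playing the role of a real model client): repeated create calls with identical arguments are served from the store, while force_client and force_cache override the lookup behaviour.

# Illustrative sketch, not taken from this change's tests.
import asyncio
from typing import Any, Optional

from autogen_core.models import UserMessage
from autogen_ext.models.cache import ChatCompletionCache
from autogen_ext.models.replay import ReplayChatCompletionClient


class DictStore:
    """Minimal in-memory store; satisfies AbstractStore by duck typing."""

    def __init__(self) -> None:
        self._data: dict[Any, Any] = {}

    def get(self, key: Any, default: Optional[Any] = None) -> Optional[Any]:
        return self._data.get(key, default)

    def set(self, key: Any, value: Any) -> None:
        self._data[key] = value


async def main() -> None:
    replay_client = ReplayChatCompletionClient(["Paris", "Paris (again)"])
    replay_client.set_cached_bool_value(False)  # so cached=True can only come from the wrapper
    cached_client = ChatCompletionCache(replay_client, store=DictStore())

    messages = [UserMessage(content="What is the capital of France?", source="user")]
    first = await cached_client.create(messages)  # served by the wrapped client
    assert not first.cached

    second = await cached_client.create(messages)  # identical arguments: served from the store
    assert second.cached and second.content == first.content

    # force_client bypasses the lookup and calls the wrapped client again;
    # force_cache would instead raise on a cache miss.
    fresh = await cached_client.create(messages, force_client=True)
    assert not fresh.cached


asyncio.run(main())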
+ """ + cached_result, cache_key = self._check_cache( + messages, tools, json_output, extra_create_args, force_cache, force_client + ) + if cached_result: + assert isinstance(cached_result, CreateResult) + cached_result.cached = True + return cached_result + + result = await self.client.create( + messages, + tools=tools, + json_output=json_output, + extra_create_args=extra_create_args, + cancellation_token=cancellation_token, + ) + self.store.set(cache_key, result) + return result + + def create_stream( + self, + messages: Sequence[LLMMessage], + *, + tools: Sequence[Tool | ToolSchema] = [], + json_output: Optional[bool] = None, + extra_create_args: Mapping[str, Any] = {}, + cancellation_token: Optional[CancellationToken] = None, + force_cache: bool = False, + force_client: bool = False, + ) -> AsyncGenerator[Union[str, CreateResult], None]: + """ + Cached version of ChatCompletionClient.create_stream. + If the result of a call to create_stream has been cached, it will be returned + without streaming from the underlying client. + + NOTE: cancellation_token is ignored for cached results. + + Additional parameters: + - force_cache: If True, the cache will be used and an error will be raised if a result is unavailable. + - force_client: If True, the cache will be bypassed and the underlying client will be called. + """ + + if force_client and force_cache: + raise ValueError("force_cache and force_client cannot both be True") + + async def _generator() -> AsyncGenerator[Union[str, CreateResult], None]: + cached_result, cache_key = self._check_cache( + messages, tools, json_output, extra_create_args, force_cache, force_client + ) + if cached_result: + assert isinstance(cached_result, list) + for result in cached_result: + if isinstance(result, CreateResult): + result.cached = True + yield result + return + + result_stream = self.client.create_stream( + messages, + tools=tools, + json_output=json_output, + extra_create_args=extra_create_args, + ) + + output_results: List[Union[str, CreateResult]] = [] + self.store.set(cache_key, output_results) + + async for result in result_stream: + output_results.append(result) + yield result + + return _generator() + + def actual_usage(self) -> RequestUsage: + return self.client.actual_usage() + + def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int: + return self.client.count_tokens(messages, tools=tools) + + @property + def capabilities(self) -> ModelCapabilities: # type: ignore + warnings.warn("capabilities is deprecated, use model_info instead", DeprecationWarning, stacklevel=2) + return self.client.capabilities + + @property + def model_info(self) -> ModelInfo: + return self.client.model_info + + def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int: + return self.client.remaining_tokens(messages, tools=tools) + + def total_usage(self) -> RequestUsage: + return self.client.total_usage() diff --git a/python/packages/autogen-ext/src/autogen_ext/models/replay/_replay_chat_completion_client.py b/python/packages/autogen-ext/src/autogen_ext/models/replay/_replay_chat_completion_client.py index b62084b646b1..68d005232f7f 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/replay/_replay_chat_completion_client.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/replay/_replay_chat_completion_client.py @@ -129,6 +129,7 @@ def __init__( self._cur_usage = RequestUsage(prompt_tokens=0, completion_tokens=0) self._total_usage = 
diff --git a/python/packages/autogen-ext/tests/models/test_chat_completion_cache.py b/python/packages/autogen-ext/tests/models/test_chat_completion_cache.py
new file mode 100644
index 000000000000..1891661b5189
--- /dev/null
+++ b/python/packages/autogen-ext/tests/models/test_chat_completion_cache.py
@@ -0,0 +1,186 @@
+import copy
+from typing import Any, List, Optional, Tuple, Union
+
+import pytest
+from autogen_core.models import ChatCompletionClient, CreateResult, LLMMessage, SystemMessage, UserMessage
+from autogen_core.store import AbstractStore
+from autogen_ext.models.cache import ChatCompletionCache
+from autogen_ext.models.replay import ReplayChatCompletionClient
+
+
+class DictStore(AbstractStore):
+    def __init__(self) -> None:
+        self._store: dict[Any, Any] = {}
+
+    def get(self, key: Any, default: Optional[Any] = None) -> Optional[Any]:
+        return self._store.get(key, default)
+
+    def set(self, key: Any, value: Any) -> None:
+        self._store[key] = value
+
+
+def get_test_data() -> Tuple[list[str], list[str], SystemMessage, ChatCompletionClient, ChatCompletionCache]:
+    num_messages = 3
+    responses = [f"This is dummy message number {i}" for i in range(num_messages)]
+    prompts = [f"This is dummy prompt number {i}" for i in range(num_messages)]
+    system_prompt = SystemMessage(content="This is a system prompt")
+    replay_client = ReplayChatCompletionClient(responses)
+    replay_client.set_cached_bool_value(False)
+    cached_client = ChatCompletionCache(replay_client, store=DictStore())
+
+    return responses, prompts, system_prompt, replay_client, cached_client
+
+
+@pytest.mark.asyncio
+async def test_cache_basic_with_args() -> None:
+    responses, prompts, system_prompt, _, cached_client = get_test_data()
+
+    response0 = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")])
+    assert isinstance(response0, CreateResult)
+    assert not response0.cached
+    assert response0.content == responses[0]
+
+    response1 = await cached_client.create([system_prompt, UserMessage(content=prompts[1], source="user")])
+    assert not response1.cached
+    assert response1.content == responses[1]
+
+    # Cached output.
+    response0_cached = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")])
+    assert isinstance(response0_cached, CreateResult)
+    assert response0_cached.cached
+    assert response0_cached.content == responses[0]
+
+    # Cache miss if args change.
+    response2 = await cached_client.create(
+        [system_prompt, UserMessage(content=prompts[0], source="user")], json_output=True
+    )
+    assert isinstance(response2, CreateResult)
+    assert not response2.cached
+    assert response2.content == responses[2]
+
+
+@pytest.mark.asyncio
+async def test_cache_model_and_count_api() -> None:
+    _, prompts, system_prompt, replay_client, cached_client = get_test_data()
+
+    assert replay_client.model_info == cached_client.model_info
+
+    messages: List[LLMMessage] = [system_prompt, UserMessage(content=prompts[0], source="user")]
+    assert replay_client.count_tokens(messages) == cached_client.count_tokens(messages)
+    assert replay_client.remaining_tokens(messages) == cached_client.remaining_tokens(messages)
+
+
+@pytest.mark.asyncio
+async def test_cache_force_cache() -> None:
+    responses, prompts, system_prompt, _, cached_client = get_test_data()
+
+    response0 = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")])
+    assert isinstance(response0, CreateResult)
+    assert not response0.cached
+    assert response0.content == responses[0]
+
+    response0_cached = await cached_client.create(
+        [system_prompt, UserMessage(content=prompts[0], source="user")], force_cache=True
+    )
+    assert isinstance(response0_cached, CreateResult)
+    assert response0_cached.cached
+    assert response0_cached.content == responses[0]
+
+    # Ensure error when force_cache=True and cache miss.
+    with pytest.raises(ValueError, match="Encountered cache miss for force_cache request"):
+        await cached_client.create([system_prompt, UserMessage(content=prompts[1], source="user")], force_cache=True)
+
+
+@pytest.mark.asyncio
+async def test_cache_force_client() -> None:
+    responses, prompts, system_prompt, _, cached_client = get_test_data()
+
+    response0 = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")])
+    assert isinstance(response0, CreateResult)
+    assert not response0.cached
+    assert response0.content == responses[0]
+
+    response1 = await cached_client.create(
+        [system_prompt, UserMessage(content=prompts[0], source="user")], force_client=True
+    )
+    assert isinstance(response1, CreateResult)
+    assert not response1.cached
+    assert response1.content == responses[1]
+
+    response2 = await cached_client.create(
+        [system_prompt, UserMessage(content=prompts[1], source="user")], force_client=True
+    )
+    assert isinstance(response2, CreateResult)
+    assert not response2.cached
+    assert response2.content == responses[2]
+
+    with pytest.raises(ValueError, match="force_cache and force_client cannot both be True"):
+        await cached_client.create(
+            [system_prompt, UserMessage(content=prompts[2], source="user")], force_cache=True, force_client=True
+        )
+
+
+@pytest.mark.asyncio
+async def test_cache_token_usage() -> None:
+    responses, prompts, system_prompt, replay_client, cached_client = get_test_data()
+
+    response0 = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")])
+    assert isinstance(response0, CreateResult)
+    assert not response0.cached
+    assert response0.content == responses[0]
+    actual_usage0 = copy.copy(cached_client.actual_usage())
+    total_usage0 = copy.copy(cached_client.total_usage())
+
+    response1 = await cached_client.create([system_prompt, UserMessage(content=prompts[1], source="user")])
+    assert not response1.cached
+    assert response1.content == responses[1]
+    actual_usage1 = copy.copy(cached_client.actual_usage())
+    total_usage1 = copy.copy(cached_client.total_usage())
+    assert total_usage1.prompt_tokens > total_usage0.prompt_tokens
+    assert total_usage1.completion_tokens > total_usage0.completion_tokens
+    assert actual_usage1.prompt_tokens == actual_usage0.prompt_tokens
+    assert actual_usage1.completion_tokens == actual_usage0.completion_tokens
+
+    # Cached output.
+    response0_cached = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")])
+    assert isinstance(response0_cached, CreateResult)
+    assert response0_cached.cached
+    assert response0_cached.content == responses[0]
+    total_usage2 = copy.copy(cached_client.total_usage())
+    assert total_usage2.prompt_tokens == total_usage1.prompt_tokens
+    assert total_usage2.completion_tokens == total_usage1.completion_tokens
+
+    assert cached_client.actual_usage() == replay_client.actual_usage()
+    assert cached_client.total_usage() == replay_client.total_usage()
+
+
+@pytest.mark.asyncio
+async def test_cache_create_stream() -> None:
+    _, prompts, system_prompt, _, cached_client = get_test_data()
+
+    original_streamed_results: List[Union[str, CreateResult]] = []
+    async for completion in cached_client.create_stream(
+        [system_prompt, UserMessage(content=prompts[0], source="user")]
+    ):
+        original_streamed_results.append(completion)
+    total_usage0 = copy.copy(cached_client.total_usage())
+
+    cached_completion_results: List[Union[str, CreateResult]] = []
+    async for completion in cached_client.create_stream(
+        [system_prompt, UserMessage(content=prompts[0], source="user")]
+    ):
+        cached_completion_results.append(completion)
+    total_usage1 = copy.copy(cached_client.total_usage())
+
+    assert total_usage1.prompt_tokens == total_usage0.prompt_tokens
+    assert total_usage1.completion_tokens == total_usage0.completion_tokens
+
+    for original, cached in zip(original_streamed_results, cached_completion_results, strict=False):
+        if isinstance(original, str):
+            assert original == cached
+        elif isinstance(original, CreateResult) and isinstance(cached, CreateResult):
+            assert original.content == cached.content
+            assert cached.cached
+            assert not original.cached
+        else:
+            raise ValueError(f"Unexpected types: {type(original)} and {type(cached)}")
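A closing note on key derivation (mirroring _check_cache above; the message content is illustrative): the cache key is a SHA-256 hash of the serialized request arguments, which is why test_cache_basic_with_args sees a cache miss as soon as json_output changes.

# Sketch of how ChatCompletionCache builds its cache key.
import hashlib
import json

from autogen_core.models import UserMessage

messages = [UserMessage(content="This is dummy prompt number 0", source="user")]
data = {
    "messages": [message.model_dump() for message in messages],
    "tools": [],
    "json_output": None,
    "extra_create_args": {},
}
cache_key = hashlib.sha256(json.dumps(data, sort_keys=True).encode()).hexdigest()
# Identical messages, tools, json_output, and extra_create_args always map to the same key;
# changing any of them (e.g. json_output=True) produces a different key and a cache miss.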