Add ChatCompletionCache along with AbstractStore for caching completions
srjoglekar246 committed Jan 8, 2025
1 parent a427b38 commit 1e18fb3
Showing 6 changed files with 448 additions and 1 deletion.
@@ -0,0 +1,3 @@
from .abstract_store_base import AbstractStore

__all__ = ["AbstractStore"]
@@ -0,0 +1,34 @@
from typing import Any, Optional, Protocol


class AbstractStore(Protocol):
    """
    This protocol defines the basic interface for store/cache operations.
    Allows duck-typing with any object that implements the get and set methods,
    such as redis or diskcache interfaces.
    """

    def get(self, key: Any, default: Optional[Any] = None) -> Optional[Any]:
        """
        Retrieve an item from the store.

        Args:
            key: The key identifying the item in the store.
            default (optional): The default value to return if the key is not found.
                Defaults to None.

        Returns:
            The value associated with the key if found, else the default value.
        """
        ...

    def set(self, key: Any, value: Any) -> None:
        """
        Set an item in the store.

        Args:
            key: The key under which the item is to be stored.
            value: The value to be stored in the store.
        """
        ...
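
Example (not part of the diff): because AbstractStore is a Protocol, any object with compatible get and set methods satisfies it by duck typing; no inheritance is required. A minimal in-memory sketch (the DictStore name is hypothetical), exposing the same two-method shape as diskcache.Cache or a Redis client:

from typing import Any, Dict, Optional


class DictStore:
    """In-memory store; satisfies AbstractStore by duck typing (no inheritance needed)."""

    def __init__(self) -> None:
        self._data: Dict[Any, Any] = {}

    def get(self, key: Any, default: Optional[Any] = None) -> Optional[Any]:
        # Return the stored value for key, or the given default on a miss.
        return self._data.get(key, default)

    def set(self, key: Any, value: Any) -> None:
        # Overwrite any existing value for key.
        self._data[key] = value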
32 changes: 32 additions & 0 deletions python/packages/autogen-core/tests/test_abstract_store.py
@@ -0,0 +1,32 @@
from unittest.mock import Mock

from autogen_core.store import AbstractStore


def test_set_and_get_object_key_value() -> None:
    mock_store = Mock(spec=AbstractStore)
    test_key = object()
    test_value = object()
    mock_store.set(test_key, test_value)
    mock_store.get.return_value = test_value
    mock_store.set.assert_called_with(test_key, test_value)
    assert mock_store.get(test_key) == test_value


def test_get_non_existent_key() -> None:
    mock_store = Mock(spec=AbstractStore)
    key = "non_existent_key"
    mock_store.get.return_value = None
    assert mock_store.get(key) is None


def test_set_overwrite_existing_key() -> None:
    mock_store = Mock(spec=AbstractStore)
    key = "test_key"
    initial_value = "initial_value"
    new_value = "new_value"
    mock_store.set(key, initial_value)
    mock_store.set(key, new_value)
    mock_store.get.return_value = new_value
    mock_store.set.assert_called_with(key, new_value)
    assert mock_store.get(key) == new_value
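
Example (not part of the diff): because the protocol is purely structural, the same round-trip behavior can be checked against a concrete object instead of a Mock, for instance the hypothetical DictStore sketched after the protocol above:

def test_dict_store_roundtrip() -> None:
    store = DictStore()  # hypothetical in-memory implementation, not part of this commit
    assert store.get("missing") is None
    assert store.get("missing", "fallback") == "fallback"
    store.set("key", "first")
    store.set("key", "second")  # overwriting is allowed
    assert store.get("key") == "second"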
186 changes: 186 additions & 0 deletions python/packages/autogen-ext/src/autogen_ext/models/cache.py
@@ -0,0 +1,186 @@
import hashlib
import json
import warnings
from typing import Any, AsyncGenerator, List, Mapping, Optional, Sequence, Union, cast

from autogen_core import CancellationToken
from autogen_core.models import (
    ChatCompletionClient,
    CreateResult,
    LLMMessage,
    ModelCapabilities,  # type: ignore
    ModelInfo,
    RequestUsage,
)
from autogen_core.store import AbstractStore
from autogen_core.tools import Tool, ToolSchema


class ChatCompletionCache(ChatCompletionClient):
    """
    A wrapper around a ChatCompletionClient that caches creation results from an underlying client.
    Cache hits do not contribute to token usage of the original client.
    """

    def __init__(self, client: ChatCompletionClient, store: AbstractStore):
        """
        Initialize a new ChatCompletionCache.

        Args:
            client (ChatCompletionClient): The original ChatCompletionClient to wrap.
            store (AbstractStore): A store object that implements get and set methods.
                The user is responsible for managing the store's lifecycle & clearing it (if needed).
        """
        self.client = client
        self.store = store

    def _check_cache(
        self,
        messages: Sequence[LLMMessage],
        tools: Sequence[Tool | ToolSchema],
        json_output: Optional[bool],
        extra_create_args: Mapping[str, Any],
        force_cache: bool,
        force_client: bool,
    ) -> tuple[Optional[Union[CreateResult, List[Union[str, CreateResult]]]], str]:
        """
        Helper function to check the cache for a result.
        Returns a tuple of (cached_result, cache_key).
        cached_result is None if the cache is empty or force_client is True.
        Raises an error if there is a cache miss and force_cache is True.
        """
        if force_client and force_cache:
            raise ValueError("force_cache and force_client cannot both be True")

        data = {
            "messages": [message.model_dump() for message in messages],
            "tools": [(tool.schema if isinstance(tool, Tool) else tool) for tool in tools],
            "json_output": json_output,
            "extra_create_args": extra_create_args,
        }
        serialized_data = json.dumps(data, sort_keys=True)
        cache_key = hashlib.sha256(serialized_data.encode()).hexdigest()

        if not force_client:
            cached_result = cast(Optional[CreateResult], self.store.get(cache_key))
            if cached_result is not None:
                return cached_result, cache_key
            elif force_cache:
                raise ValueError("Encountered cache miss for force_cache request")

        return None, cache_key

    async def create(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
        force_cache: bool = False,
        force_client: bool = False,
    ) -> CreateResult:
        """
        Cached version of ChatCompletionClient.create.
        If the result of a call to create has been cached, it will be returned immediately
        without invoking the underlying client.

        NOTE: cancellation_token is ignored for cached results.

        Additional parameters:
        - force_cache: If True, the cache will be used and an error will be raised if a result is unavailable.
        - force_client: If True, the cache will be bypassed and the underlying client will be called.
        """
        cached_result, cache_key = self._check_cache(
            messages, tools, json_output, extra_create_args, force_cache, force_client
        )
        if cached_result:
            assert isinstance(cached_result, CreateResult)
            cached_result.cached = True
            return cached_result

        result = await self.client.create(
            messages,
            tools=tools,
            json_output=json_output,
            extra_create_args=extra_create_args,
            cancellation_token=cancellation_token,
        )
        self.store.set(cache_key, result)
        return result

    def create_stream(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
        force_cache: bool = False,
        force_client: bool = False,
    ) -> AsyncGenerator[Union[str, CreateResult], None]:
        """
        Cached version of ChatCompletionClient.create_stream.
        If the result of a call to create_stream has been cached, it will be returned
        without streaming from the underlying client.

        NOTE: cancellation_token is ignored for cached results.

        Additional parameters:
        - force_cache: If True, the cache will be used and an error will be raised if a result is unavailable.
        - force_client: If True, the cache will be bypassed and the underlying client will be called.
        """

        if force_client and force_cache:
            raise ValueError("force_cache and force_client cannot both be True")

        async def _generator() -> AsyncGenerator[Union[str, CreateResult], None]:
            cached_result, cache_key = self._check_cache(
                messages, tools, json_output, extra_create_args, force_cache, force_client
            )
            if cached_result:
                assert isinstance(cached_result, list)
                for result in cached_result:
                    if isinstance(result, CreateResult):
                        result.cached = True
                    yield result
                return

            result_stream = self.client.create_stream(
                messages,
                tools=tools,
                json_output=json_output,
                extra_create_args=extra_create_args,
            )

            # The (mutable) result list is registered in the store up front and
            # then filled in as chunks arrive from the underlying stream.
            output_results: List[Union[str, CreateResult]] = []
            self.store.set(cache_key, output_results)

            async for result in result_stream:
                output_results.append(result)
                yield result

        return _generator()

    def actual_usage(self) -> RequestUsage:
        return self.client.actual_usage()

    def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
        return self.client.count_tokens(messages, tools=tools)

    @property
    def capabilities(self) -> ModelCapabilities:  # type: ignore
        warnings.warn("capabilities is deprecated, use model_info instead", DeprecationWarning, stacklevel=2)
        return self.client.capabilities

    @property
    def model_info(self) -> ModelInfo:
        return self.client.model_info

    def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
        return self.client.remaining_tokens(messages, tools=tools)

    def total_usage(self) -> RequestUsage:
        return self.client.total_usage()
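
Example (not part of the diff): wrapping any ChatCompletionClient in ChatCompletionCache makes repeat requests free. Identical messages, tools, json_output, and extra_create_args hash to the same SHA-256 key, so the second call below is answered from the store with cached=True and never reaches the wrapped client. A minimal sketch, assuming the import path autogen_ext.models.cache implied by the file location above and reusing the hypothetical DictStore from earlier:

import asyncio
from unittest.mock import Mock

from autogen_core.models import ChatCompletionClient, CreateResult, RequestUsage, UserMessage
from autogen_ext.models.cache import ChatCompletionCache  # path inferred from the file location above


async def main() -> None:
    # Stand-in for a real client (OpenAI, etc.); only `create` is exercised here.
    base_client = Mock(spec=ChatCompletionClient)
    base_client.create.return_value = CreateResult(
        finish_reason="stop",
        content="4",
        usage=RequestUsage(prompt_tokens=10, completion_tokens=1),
        cached=False,
    )

    cached_client = ChatCompletionCache(base_client, DictStore())  # DictStore: hypothetical store sketched earlier
    messages = [UserMessage(content="What is 2 + 2?", source="user")]

    first = await cached_client.create(messages)   # forwarded to base_client, result stored under a SHA-256 key
    second = await cached_client.create(messages)  # identical request, served from the store
    assert first.content == "4"
    assert base_client.create.await_count == 1     # the wrapped client was only called once
    assert second.cached is True                   # cache hits are marked

    # force_client bypasses the store; force_cache instead raises ValueError on a miss.
    await cached_client.create(messages, force_client=True)
    assert base_client.create.await_count == 2


asyncio.run(main())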
@@ -129,6 +129,7 @@ def __init__(
        self._cur_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
        self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
        self._current_index = 0
        self._cached_bool_value = True

    async def create(
        self,
@@ -148,7 +149,9 @@ async def create(
        if isinstance(response, str):
            _, output_token_count = self._tokenize(response)
            self._cur_usage = RequestUsage(prompt_tokens=prompt_token_count, completion_tokens=output_token_count)
            response = CreateResult(finish_reason="stop", content=response, usage=self._cur_usage, cached=True)
            response = CreateResult(
                finish_reason="stop", content=response, usage=self._cur_usage, cached=self._cached_bool_value
            )
        else:
            self._cur_usage = RequestUsage(
                prompt_tokens=prompt_token_count, completion_tokens=response.usage.completion_tokens
@@ -207,6 +210,9 @@ def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
            0, self._total_available_tokens - self._total_usage.prompt_tokens - self._total_usage.completion_tokens
        )

    def set_cached_bool_value(self, value: bool) -> None:
        self._cached_bool_value = value

    def _tokenize(self, messages: Union[str, LLMMessage, Sequence[LLMMessage]]) -> tuple[list[str], int]:
        total_tokens = 0
        all_tokens: List[str] = []
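
The new _cached_bool_value knob lets this canned-response client report cached=False on fresh results, so tests can observe ChatCompletionCache flipping the flag to True only on store hits. A rough sketch of that pattern (the class name ReplayChatCompletionClient and its import path are assumptions, since this hunk does not show the file name; DictStore is the hypothetical store from earlier):

import asyncio

from autogen_core.models import UserMessage
from autogen_ext.models.cache import ChatCompletionCache  # path inferred from the file location above
from autogen_ext.models.replay import ReplayChatCompletionClient  # assumed class name and import path


async def check_cached_flag() -> None:
    # One canned response only: if the second request reached the client, it would have nothing to replay.
    replay_client = ReplayChatCompletionClient(["canned answer"])
    replay_client.set_cached_bool_value(False)  # added in this diff: fresh replies report cached=False

    cached = ChatCompletionCache(replay_client, DictStore())  # DictStore: hypothetical store sketched earlier
    messages = [UserMessage(content="hello", source="user")]

    fresh = await cached.create(messages)
    assert fresh.cached is False  # came straight from the replay client

    hit = await cached.create(messages)
    assert hit.cached is True  # served from the store; the wrapper marks cache hits


asyncio.run(check_cached_flag())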