-
Notifications
You must be signed in to change notification settings - Fork 5.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add ChatCompletionCache along with AbstractStore for caching completions
- Loading branch information
1 parent
a427b38
commit 1e18fb3
Showing
6 changed files
with
448 additions
and
1 deletion.
There are no files selected for viewing
3 changes: 3 additions & 0 deletions
3
python/packages/autogen-core/src/autogen_core/store/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .abstract_store_base import AbstractStore | ||
|
||
__all__ = ["AbstractStore"] |
34 changes: 34 additions & 0 deletions
34
python/packages/autogen-core/src/autogen_core/store/abstract_store_base.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from typing import Any, Optional, Protocol | ||
|
||
|
||
class AbstractStore(Protocol):
    """
    Structural interface for key-value store/cache backends.

    Any object exposing compatible ``get`` and ``set`` methods — such as a
    redis client or a diskcache instance — satisfies this protocol via
    duck typing; no explicit inheritance is required.
    """

    def get(self, key: Any, default: Optional[Any] = None) -> Optional[Any]:
        """
        Look up an item in the store.

        Args:
            key: The key identifying the stored item.
            default (optional): Value to return when the key is absent.
                Defaults to None.

        Returns:
            The stored value for the key, or the default when not present.
        """
        ...

    def set(self, key: Any, value: Any) -> None:
        """
        Store a value under a key.

        Args:
            key: The key under which the value is to be stored.
            value: The value to place in the store.
        """
        ...
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from unittest.mock import Mock | ||
|
||
from autogen_core.store import AbstractStore | ||
|
||
|
||
def test_set_and_get_object_key_value() -> None:
    # Use a spec'd Mock so only AbstractStore's methods are available.
    store = Mock(spec=AbstractStore)
    key, value = object(), object()
    store.set(key, value)
    store.get.return_value = value
    store.set.assert_called_with(key, value)
    assert store.get(key) == value
|
||
|
||
def test_get_non_existent_key() -> None:
    # A missing key should yield None from the spec'd store mock.
    store = Mock(spec=AbstractStore)
    missing_key = "non_existent_key"
    store.get.return_value = None
    assert store.get(missing_key) is None
|
||
|
||
def test_set_overwrite_existing_key() -> None:
    # Setting the same key twice should leave the most recent value visible.
    store = Mock(spec=AbstractStore)
    key = "test_key"
    store.set(key, "initial_value")
    store.set(key, "new_value")
    store.get.return_value = "new_value"
    store.set.assert_called_with(key, "new_value")
    assert store.get(key) == "new_value"
186 changes: 186 additions & 0 deletions
186
python/packages/autogen-ext/src/autogen_ext/models/cache.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,186 @@ | ||
import hashlib | ||
import json | ||
import warnings | ||
from typing import Any, AsyncGenerator, List, Mapping, Optional, Sequence, Union, cast | ||
|
||
from autogen_core import CancellationToken | ||
from autogen_core.models import ( | ||
ChatCompletionClient, | ||
CreateResult, | ||
LLMMessage, | ||
ModelCapabilities, # type: ignore | ||
ModelInfo, | ||
RequestUsage, | ||
) | ||
from autogen_core.store import AbstractStore | ||
from autogen_core.tools import Tool, ToolSchema | ||
|
||
|
||
class ChatCompletionCache(ChatCompletionClient):
    """
    A wrapper around a ChatCompletionClient that caches creation results from an underlying client.
    Cache hits do not contribute to token usage of the original client.
    """

    def __init__(self, client: ChatCompletionClient, store: AbstractStore):
        """
        Initialize a new ChatCompletionCache.

        Args:
            client (ChatCompletionClient): The original ChatCompletionClient to wrap.
            store (AbstractStore): A store object that implements get and set methods.
                The user is responsible for managing the store's lifecycle & clearing it (if needed).
        """
        self.client = client
        self.store = store

    def _check_cache(
        self,
        messages: Sequence[LLMMessage],
        tools: Sequence[Tool | ToolSchema],
        json_output: Optional[bool],
        extra_create_args: Mapping[str, Any],
        force_cache: bool,
        force_client: bool,
    ) -> tuple[Optional[Union[CreateResult, List[Union[str, CreateResult]]]], str]:
        """
        Helper function to check the cache for a result.
        Returns a tuple of (cached_result, cache_key).
        cached_result is None if the cache is empty or force_client is True.
        Raises an error if there is a cache miss and force_cache is True.
        """
        if force_client and force_cache:
            raise ValueError("force_cache and force_client cannot both be True")

        # The cache key is a stable digest over everything that can affect the
        # completion: messages, tools, json_output, and extra creation args.
        # sort_keys makes the serialization order-independent for dicts.
        data = {
            "messages": [message.model_dump() for message in messages],
            "tools": [(tool.schema if isinstance(tool, Tool) else tool) for tool in tools],
            "json_output": json_output,
            "extra_create_args": extra_create_args,
        }
        serialized_data = json.dumps(data, sort_keys=True)
        cache_key = hashlib.sha256(serialized_data.encode()).hexdigest()

        if not force_client:
            cached_result = cast(Optional[CreateResult], self.store.get(cache_key))
            # Compare to None explicitly: a falsy-but-present cached value
            # (e.g. an empty cached stream list) is still a valid cache hit.
            if cached_result is not None:
                return cached_result, cache_key
            elif force_cache:
                raise ValueError("Encountered cache miss for force_cache request")

        return None, cache_key

    async def create(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
        force_cache: bool = False,
        force_client: bool = False,
    ) -> CreateResult:
        """
        Cached version of ChatCompletionClient.create.
        If the result of a call to create has been cached, it will be returned immediately
        without invoking the underlying client.

        NOTE: cancellation_token is ignored for cached results.

        Additional parameters:
        - force_cache: If True, the cache will be used and an error will be raised if a result is unavailable.
        - force_client: If True, the cache will be bypassed and the underlying client will be called.
        """
        cached_result, cache_key = self._check_cache(
            messages, tools, json_output, extra_create_args, force_cache, force_client
        )
        # Explicit None check so a falsy cached result is still honored.
        if cached_result is not None:
            assert isinstance(cached_result, CreateResult)
            cached_result.cached = True
            return cached_result

        result = await self.client.create(
            messages,
            tools=tools,
            json_output=json_output,
            extra_create_args=extra_create_args,
            cancellation_token=cancellation_token,
        )
        self.store.set(cache_key, result)
        return result

    def create_stream(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
        force_cache: bool = False,
        force_client: bool = False,
    ) -> AsyncGenerator[Union[str, CreateResult], None]:
        """
        Cached version of ChatCompletionClient.create_stream.
        If the result of a call to create_stream has been cached, it will be returned
        without streaming from the underlying client.

        NOTE: cancellation_token is ignored for cached results.

        Additional parameters:
        - force_cache: If True, the cache will be used and an error will be raised if a result is unavailable.
        - force_client: If True, the cache will be bypassed and the underlying client will be called.
        """

        # Validate eagerly (before the generator is first iterated) so the
        # caller gets the error at call time rather than on first `anext`.
        if force_client and force_cache:
            raise ValueError("force_cache and force_client cannot both be True")

        async def _generator() -> AsyncGenerator[Union[str, CreateResult], None]:
            cached_result, cache_key = self._check_cache(
                messages, tools, json_output, extra_create_args, force_cache, force_client
            )
            if cached_result is not None:
                assert isinstance(cached_result, list)
                for result in cached_result:
                    if isinstance(result, CreateResult):
                        result.cached = True
                    yield result
                return

            result_stream = self.client.create_stream(
                messages,
                tools=tools,
                json_output=json_output,
                extra_create_args=extra_create_args,
            )

            output_results: List[Union[str, CreateResult]] = []
            async for result in result_stream:
                output_results.append(result)
                yield result

            # Cache only AFTER the stream has been fully consumed. Setting the
            # (empty) list before streaming would (a) leave a partial entry if
            # the consumer abandons the generator or the stream errors, and
            # (b) rely on in-place mutation of an already-stored list, which
            # breaks for stores that serialize values on set (e.g. redis).
            self.store.set(cache_key, output_results)

        return _generator()

    def actual_usage(self) -> RequestUsage:
        """Forward to the wrapped client; cache hits do not affect usage."""
        return self.client.actual_usage()

    def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
        """Forward token counting to the wrapped client."""
        return self.client.count_tokens(messages, tools=tools)

    @property
    def capabilities(self) -> ModelCapabilities:  # type: ignore
        warnings.warn("capabilities is deprecated, use model_info instead", DeprecationWarning, stacklevel=2)
        return self.client.capabilities

    @property
    def model_info(self) -> ModelInfo:
        """Model metadata of the wrapped client."""
        return self.client.model_info

    def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
        """Forward remaining-token computation to the wrapped client."""
        return self.client.remaining_tokens(messages, tools=tools)

    def total_usage(self) -> RequestUsage:
        """Forward to the wrapped client; cache hits do not affect usage."""
        return self.client.total_usage()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.