diff --git a/README.md b/README.md index 5c1afbc..9b0fad0 100644 --- a/README.md +++ b/README.md @@ -178,20 +178,20 @@ client.check_flag( ### Flag Check Options -By default, the client will do some local caching for flag checks. If you would like to change this behavior, you can do so using an initialization option to specify the max size of the cache (in bytes) and the max age of the cache (in seconds): +By default, the client will do some local caching for flag checks. If you would like to change this behavior, you can do so using an initialization option to specify the max size of the cache (in terms of number of entries) and the max age of the cache (in milliseconds): ```python from schematic.client import LocalCache, Schematic -cache_size_bytes = 1000000 +cache_size = 100 cache_ttl = 1000 # in milliseconds config = SchematicConfig( - cache_providers=[LocalCache[bool](cache_size_bytes, cache_ttl)], + cache_providers=[LocalCache[bool](cache_size, cache_ttl)], ) client = Schematic("YOUR_API_KEY", config) ``` -You can also disable local caching entirely with an initialization option; bear in mind that, in this case, every flag check will result in a network request: +You can also disable local caching entirely; bear in mind that, in this case, every flag check will result in a network request: ```python from schematic.client import Schematic @@ -200,7 +200,7 @@ config = SchematicConfig(cache_providers=[]) client = Schematic("YOUR_API_KEY", config) ``` -You may want to specify default flag values for your application, which will be used if there is a service interruption or if the client is running in offline mode (see below). 
You can do this using an initialization option: +You may want to specify default flag values for your application, which will be used if there is a service interruption or if the client is running in offline mode (see below): ```python from schematic.client import Schematic diff --git a/src/schematic/cache.py b/src/schematic/cache.py index 2bf701f..7b134f0 100644 --- a/src/schematic/cache.py +++ b/src/schematic/cache.py @@ -1,12 +1,11 @@ -import sys +from collections import OrderedDict import time -from typing import Dict, Generic, Optional, TypeVar - -DEFAULT_CACHE_SIZE = 10 * 1024 # 10KB -DEFAULT_CACHE_TTL = 5 # 5 seconds +from typing import Generic, Optional, TypeVar T = TypeVar("T") +DEFAULT_CACHE_SIZE = 1000 # 1000 items +DEFAULT_CACHE_TTL = 5000 # 5 seconds class CacheProvider(Generic[T]): def get(self, key: str) -> Optional[T]: @@ -15,42 +14,30 @@ def get(self, key: str) -> Optional[T]: def set(self, key: str, val: T, ttl_override: Optional[int] = None) -> None: pass - class CachedItem(Generic[T]): - def __init__(self, value: T, access_counter: int, size: int, expiration: float): + def __init__(self, value: T, expiration: float): self.value = value - self.access_counter = access_counter - self.size = size self.expiration = expiration - -class LocalCache(CacheProvider[T]): +class LocalCache(CacheProvider[T]): def __init__(self, max_size: int, ttl: int): - self.cache: Dict[str, CachedItem[T]] = {} + self.cache = OrderedDict() self.max_size = max_size - self.current_size = 0 - self.access_counter = 0 self.ttl = ttl def get(self, key: str) -> Optional[T]: - if self.max_size == 0: + if self.max_size == 0 or key not in self.cache: return None - item = self.cache.get(key) - if item is None: - return None + item = self.cache[key] + current_time = time.time() * 1000 - # Check if the item has expired - if time.time() > item.expiration: - self.current_size -= item.size + if current_time > item.expiration: del self.cache[key] return None - # Update the access counter for LRU 
eviction - self.access_counter += 1 - item.access_counter = self.access_counter - self.cache[key] = item - + # Move the accessed item to the end (most recently used) + self.cache.move_to_end(key) return item.value def set(self, key: str, val: T, ttl_override: Optional[int] = None) -> None: @@ -58,40 +45,20 @@ def set(self, key: str, val: T, ttl_override: Optional[int] = None) -> None: return ttl = self.ttl if ttl_override is None else ttl_override - size = sys.getsizeof(val) + expiration = time.time() * 1000 + ttl - # Check if the key already exists in the cache + # If the key already exists, update it and move it to the end if key in self.cache: - item = self.cache[key] - self.current_size -= item.size - self.current_size += size - self.access_counter += 1 - self.cache[key] = CachedItem(val, self.access_counter, size, time.time() + ttl) - return - - # Evict expired items - for k, item in list(self.cache.items()): - if time.time() > item.expiration: - self.current_size -= item.size - del self.cache[k] - - # Evict records if the cache size exceeds the max size - while self.current_size + size > self.max_size: - oldest_key = None - oldest_access_counter = float("inf") - - for k, v in self.cache.items(): - if v.access_counter < oldest_access_counter: - oldest_key = k - oldest_access_counter = v.access_counter - - if oldest_key is not None: - self.current_size -= self.cache[oldest_key].size - del self.cache[oldest_key] - else: - break - - # Add the new item to the cache - self.access_counter += 1 - self.cache[key] = CachedItem(val, self.access_counter, size, time.time() + ttl) - self.current_size += size + self.cache[key] = CachedItem(val, expiration) + self.cache.move_to_end(key) + else: + # If we're at capacity, remove the least recently used item + if len(self.cache) >= self.max_size: + self.cache.popitem(last=False) + + # Add the new item + self.cache[key] = CachedItem(val, expiration) + + def clean_expired(self): + current_time = time.time() * 1000 + self.cache = 
OrderedDict((k, v) for k, v in self.cache.items() if v.expiration > current_time) diff --git a/tests/custom/test_cache.py b/tests/custom/test_cache.py index 294f3a2..5f7ba7e 100644 --- a/tests/custom/test_cache.py +++ b/tests/custom/test_cache.py @@ -1,15 +1,13 @@ import unittest -from unittest.mock import patch from schematic.cache import LocalCache class TestLocalCache(unittest.TestCase): def setUp(self): - self.cache = LocalCache(max_size=1024, ttl=2) + self.cache = LocalCache(max_size=2, ttl=2000) - @patch("sys.getsizeof", return_value=512) - def test_cache_size_limit(self, mock_getsizeof): + def test_cache_size_limit(self): self.cache.set("key1", "value1") self.cache.set("key2", "value2")