From 11da4ef89b6fa1a4573e8e0c5bc0b79489fd649c Mon Sep 17 00:00:00 2001 From: Krukov Date: Sun, 1 Dec 2024 19:13:38 +0300 Subject: [PATCH] feat: serializer composition --- cashews/backends/diskcache.py | 35 +++++---- cashews/backends/interface.py | 5 +- cashews/backends/memory.py | 22 +++--- cashews/backends/redis/__init__.py | 7 +- cashews/backends/redis/backend.py | 17 +++-- cashews/backends/redis/client_side.py | 7 +- cashews/backends/transaction.py | 2 + cashews/commands.py | 1 + cashews/contrib/fastapi.py | 2 +- cashews/decorators/cache/defaults.py | 2 +- cashews/helpers.py | 19 +++-- cashews/key.py | 9 ++- cashews/picklers.py | 32 +++++---- cashews/serialize.py | 95 +++++++------------------ cashews/wrapper/backend_settings.py | 26 ++++--- cashews/wrapper/commands.py | 26 ++++--- cashews/wrapper/disable_control.py | 2 +- cashews/wrapper/transaction.py | 29 ++++---- cashews/wrapper/wrapper.py | 36 +++++----- examples/keys.py | 16 +++-- examples/simple.py | 5 +- pytest.ini | 2 +- tests/conftest.py | 16 ----- tests/test_add_prefix.py | 8 ++- tests/test_cache.py | 2 +- tests/test_client_side_cache.py | 6 +- tests/test_intergations/test_fastapi.py | 5 ++ tests/test_middleware.py | 59 +++++++-------- tests/test_pickle_serializer.py | 24 +++---- tests/test_redis_down.py | 4 +- tests/test_settings_url.py | 6 +- tests/test_tags_feature.py | 3 +- tests/test_wrapper.py | 18 ++--- 33 files changed, 281 insertions(+), 267 deletions(-) diff --git a/cashews/backends/diskcache.py b/cashews/backends/diskcache.py index e7922cd7..f15b1a93 100644 --- a/cashews/backends/diskcache.py +++ b/cashews/backends/diskcache.py @@ -8,14 +8,15 @@ from diskcache import Cache, FanoutCache from cashews._typing import Key, Value -from cashews.serialize import SerializerMixin +from cashews.serialize import DEFAULT_SERIALIZER, Serializer from cashews.utils import Bitarray from .interface import NOT_EXIST, UNLIMITED, Backend -class _DiskCache(Backend): +class DiskCache(Backend): def __init__(self, *args, directory=None, shards=8, **kwargs: Any) -> None: + serializer = kwargs.pop("serializer", DEFAULT_SERIALIZER) self.__is_init = False self._set_locks: dict[str, asyncio.Lock] = {} self._sharded = shards > 1 @@ -23,7 +24,8 @@ def __init__(self, *args, directory=None, shards=8, **kwargs: Any) -> None: self._cache = Cache(directory=directory, **kwargs) else: self._cache = FanoutCache(directory=directory, shards=shards, **kwargs) - super().__init__(**kwargs) + super().__init__(serializer=serializer, **kwargs) + self._serializer: Serializer async def init(self): self.__is_init = True @@ -46,6 +48,7 @@ async def set( expire: float | None = None, exist: bool | None = None, ) -> bool: + value = await self._serializer.encode(self, key=key, value=value, expire=expire) future = self._run_in_executor(self._set, key, value, expire, exist) if exist is not None: # we should have async lock until value real set @@ -69,25 +72,34 @@ async def set_raw(self, key: Key, value: Any, **kwargs: Any): return self._cache.set(key, value, **kwargs) async def get(self, key: Key, default: Value | None = None) -> Value: - return await self._run_in_executor(self._cache.get, key, default) + value = await self._run_in_executor(self._cache.get, key, default) + return await self._serializer.decode(self, key=key, value=value, default=default) async def get_raw(self, key: Key) -> Value: return self._cache.get(key) - async def get_many(self, *keys: Key, default: Value | None = None) -> tuple[Value]: - return await self._run_in_executor(self._get_many, keys, default) + async def get_many(self, *keys: Key, default: Value | None = None) -> tuple[Value | None, ...]: + if not keys: + return () + values = await self._run_in_executor(self._get_many, keys, default) + values = await asyncio.gather( + *[self._serializer.decode(self, key=key, value=value, default=default) for key, value in zip(keys, values)] + ) + return tuple(None if isinstance(value, Bitarray) else value for value in values) def _get_many(self, keys: list[Key], default: Value | None = None): values = [] for key in keys: val = self._cache.get(key, default=default) - if isinstance(val, Bitarray): - val = None values.append(val) return values async def set_many(self, pairs: Mapping[Key, Value], expire: float | None = None): - return await self._run_in_executor(self._set_many, pairs, expire) + _pairs = {} + for key, value in pairs.items(): + value = await self._serializer.encode(self, key=key, value=value, expire=expire) + _pairs[key] = value + return await self._run_in_executor(self._set_many, _pairs, expire) def _set_many(self, pairs: Mapping[Key, Value], expire: float | None = None): for key, value in pairs.items(): @@ -215,6 +227,7 @@ async def is_locked( return await self.exists(key) async def unlock(self, key: Key, value: Value) -> bool: + value = await self._serializer.encode(self, key=key, value=value, expire=None) return await self._run_in_executor(self._unlock, key, value) def _unlock(self, key: Key, value: Value) -> bool: @@ -269,7 +282,3 @@ async def set_pop(self, key: Key, count: int = 100) -> Iterable[str]: async def get_keys_count(self) -> int: return await self._run_in_executor(lambda: len(self._cache)) - - -class DiskCache(SerializerMixin, _DiskCache): - pass diff --git a/cashews/backends/interface.py b/cashews/backends/interface.py index a223c3f2..afa2cd72 100644 --- a/cashews/backends/interface.py +++ b/cashews/backends/interface.py @@ -9,6 +9,7 @@ from cashews.commands import ALL, Command from cashews.exceptions import CacheBackendInteractionError, LockedError +from cashews.serialize import Serializer if TYPE_CHECKING: # pragma: no cover from cashews._typing import Default, Key, OnRemoveCallback, Value @@ -226,8 +227,10 @@ def enable(self, *cmds: Command) -> None: class Backend(ControlMixin, _BackendInterface, metaclass=ABCMeta): - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args, serializer: Serializer | None = None, **kwargs) -> None: super().__init__() + self._id = uuid.uuid4().hex + self._serializer = serializer self._on_remove_callbacks: list[OnRemoveCallback] = [] def on_remove_callback(self, callback: OnRemoveCallback) -> None: diff --git a/cashews/backends/memory.py b/cashews/backends/memory.py index f1e0a561..80779faf 100644 --- a/cashews/backends/memory.py +++ b/cashews/backends/memory.py @@ -8,7 +8,6 @@ from copy import copy from typing import TYPE_CHECKING, Any, AsyncIterator, Iterable, Mapping, overload -from cashews.serialize import SerializerMixin from cashews.utils import Bitarray, get_obj_size from .interface import NOT_EXIST, UNLIMITED, Backend @@ -22,7 +21,7 @@ _missed = object() -class _Memory(Backend): +class Memory(Backend): """ Inmemory backend lru with ttl """ @@ -74,17 +73,22 @@ async def set( ) -> bool: if exist is not None and (key in self.store) is not exist: return False + if self._serializer: + value = await self._serializer.encode(self, key=key, value=value, expire=expire) self._set(key, value, expire) return True async def set_raw(self, key: Key, value: Value, **kwargs: Any) -> None: - self.store[key] = value + self.store[key] = (None, value) async def get(self, key: Key, default: Value | None = None) -> Value: return await self._get(key, default=default) async def get_raw(self, key: Key) -> Value: - return self.store.get(key) + val = self.store.get(key) + if val: + return val[1] + return None async def get_many(self, *keys: Key, default: Value | None = None) -> tuple[Value | None, ...]: values = [] @@ -97,6 +101,8 @@ async def get_many(self, *keys: Key, default: Value | None = None) -> tuple[Valu async def set_many(self, pairs: Mapping[Key, Value], expire: float | None = None): for key, value in pairs.items(): + if self._serializer: + value = await self._serializer.encode(self, key=key, value=value, expire=expire) self._set(key, value, expire) async def scan(self, pattern: str, batch_size: int = 100) -> AsyncIterator[Key]: # type: ignore @@ -200,7 +206,9 @@ async def _get(self, key: Key, default: Default | None = None) -> Value | None: if expire_at and expire_at < time.time(): await self._delete(key) return default - return value + if not self._serializer: + return value + return await self._serializer.decode(self, key=key, value=value, default=default) async def _key_exist(self, key: Key) -> bool: return (await self._get(key, default=_missed)) is not _missed @@ -279,7 +287,3 @@ async def close(self): del self.__remove_expired_stop self.__remove_expired_stop = None self.__is_init = False - - -class Memory(SerializerMixin, _Memory): - pass diff --git a/cashews/backends/redis/__init__.py b/cashews/backends/redis/__init__.py index 455ba21e..f36076c1 100644 --- a/cashews/backends/redis/__init__.py +++ b/cashews/backends/redis/__init__.py @@ -1,10 +1,7 @@ -from cashews.picklers import DEFAULT_PICKLE -from cashews.serialize import SerializerMixin - from .backend import _Redis __all__ = ["Redis"] -class Redis(SerializerMixin, _Redis): - pickle_type = DEFAULT_PICKLE +class Redis(_Redis): + pass diff --git a/cashews/backends/redis/backend.py b/cashews/backends/redis/backend.py index f717e057..6b957e8e 100644 --- a/cashews/backends/redis/backend.py +++ b/cashews/backends/redis/backend.py @@ -8,6 +8,7 @@ from cashews._typing import Key, Value from cashews.backends.interface import Backend +from cashews.serialize import DEFAULT_SERIALIZER, Serializer from .client import Redis, SafePipeline, SafeRedis @@ -76,7 +77,8 @@ def __init__( self._kwargs = kwargs self._address = address self.__is_init = False - super().__init__() + super().__init__(serializer=kwargs.pop("serializer", None)) + self._serializer: Serializer = self._serializer or DEFAULT_SERIALIZER @property def is_init(self) -> bool: @@ -105,6 +107,7 @@ async def set( expire: float | None = None, exist=None, ) -> bool: + value = await self._serializer.encode(self, key=key, value=value, expire=expire) nx = xx = False if exist is True: xx = True @@ -118,6 +121,7 @@ async def set_many(self, pairs: Mapping[Key, Value], expire: float | None = None px = int(expire * 1000) if expire else None async with self._pipeline as pipe: for key, value in pairs.items(): + value = await self._serializer.encode(self, key=key, value=value, expire=expire) await pipe.set(key, value, px=px) await pipe.execute() @@ -211,7 +215,7 @@ async def get_size(self, key: Key) -> int: async def get(self, key: Key, default: Value | None = None) -> Value: value = await self._client.get(key) - return self._transform_value(value, default) + return await self._transform_value(key, value, default) async def get_many(self, *keys: Key, default: Value | None = None) -> tuple[Value | None, ...]: if not keys: @@ -219,15 +223,16 @@ async def get_many(self, *keys: Key, default: Value | None = None) -> tuple[Valu values = await self._client.mget(*keys) if values is None: return tuple([default] * len(keys)) - return tuple(self._transform_value(value, default) for value in values) + return tuple( + await asyncio.gather(*[self._transform_value(key, value, default) for key, value in zip(keys, values)]) + ) - @staticmethod - def _transform_value(value: bytes | None, default: Value | None): + async def _transform_value(self, key: Key, value: bytes | None, default: Value | None): if value is None: return default if value.isdigit(): return int(value) - return value + return await self._serializer.decode(self, key=key, value=value, default=default) async def incr(self, key: Key, value: int = 1, expire: float | None = None) -> int: if not expire: diff --git a/cashews/backends/redis/client_side.py b/cashews/backends/redis/client_side.py index 1fbbc58b..da6636eb 100644 --- a/cashews/backends/redis/client_side.py +++ b/cashews/backends/redis/client_side.py @@ -83,11 +83,12 @@ def __init__( self._expire_for_recently_update = 5 self._listen_started = asyncio.Event() self.__listen_stop = asyncio.Event() - super().__init__(*args, suppress=suppress, **kwargs) + kwargs["suppress"] = suppress + super().__init__(*args, **kwargs) async def init(self): - self._listen_started = asyncio.Event() - self.__listen_stop = asyncio.Event() + self._listen_started.clear() + self.__listen_stop.clear() await self._local_cache.init() await self._recently_update.init() await super().init() diff --git a/cashews/backends/transaction.py b/cashews/backends/transaction.py index 64d1f4e9..71b9aa89 100644 --- a/cashews/backends/transaction.py +++ b/cashews/backends/transaction.py @@ -21,6 +21,7 @@ class TransactionBackend(Backend): "_local_cache", "_to_delete", "__disable", + "_id", ] def __init__(self, backend: Backend): @@ -28,6 +29,7 @@ def __init__(self, backend: Backend): self._local_cache = Memory() self._to_delete: set[Key] = set() super().__init__() + self._id = backend._id def _key_is_delete(self, key: Key) -> bool: if key in self._to_delete: diff --git a/cashews/commands.py b/cashews/commands.py index 43b8d31d..0efd91be 100644 --- a/cashews/commands.py +++ b/cashews/commands.py @@ -13,6 +13,7 @@ class Command(Enum): DELETE_MANY = "delete_many" DELETE_MATCH = "delete_match" + EXISTS = "exists" EXIST = "exists" SCAN = "scan" INCR = "incr" diff --git a/cashews/contrib/fastapi.py b/cashews/contrib/fastapi.py index 976abbcf..b3a09621 100644 --- a/cashews/contrib/fastapi.py +++ b/cashews/contrib/fastapi.py @@ -154,7 +154,7 @@ def set_callback(key: str, result: Any): _data = None else: _key = calls[0][0] - _data = calls[0][1][0]["value"] + _data = calls[0][1]["value"] _etag = await self._set_etag(_key, _data) return self._response_etag(response, _etag, request_etag) diff --git a/cashews/decorators/cache/defaults.py b/cashews/decorators/cache/defaults.py index 476386dd..118cefc7 100644 --- a/cashews/decorators/cache/defaults.py +++ b/cashews/decorators/cache/defaults.py @@ -16,7 +16,7 @@ def __init__(self, previous_level=0, unset_token=None): self._previous_level = previous_level def _set(self, key: Key, **kwargs: Any) -> None: - self._value.append((key, [kwargs])) + self._value.append((key, kwargs)) @property def calls(self): diff --git a/cashews/helpers.py b/cashews/helpers.py index c4fdc4a9..70dfeb9a 100644 --- a/cashews/helpers.py +++ b/cashews/helpers.py @@ -52,11 +52,20 @@ async def _middleware(call: AsyncCallable_T, cmd: Command, backend: Backend, *ar def memory_limit(min_bytes: int = 0, max_bytes: int | None = None) -> Middleware: async def _middleware(call: AsyncCallable_T, cmd: Command, backend: Backend, *args, **kwargs) -> Result_T | None: - if cmd != Command.SET: - return await call(*args, **kwargs) - value_size = get_obj_size(kwargs["value"]) - if max_bytes and value_size > max_bytes or value_size < min_bytes: - return None + if cmd == Command.SET_MANY: + pairs = {} + for key, value in kwargs["pairs"].items(): + value_size = get_obj_size(value) + if max_bytes and value_size > max_bytes or value_size < min_bytes: + continue + pairs[key] = value + if not pairs: + return None + kwargs["pairs"] = pairs + elif cmd == Command.SET: + value_size = get_obj_size(kwargs["value"]) + if max_bytes and value_size > max_bytes or value_size < min_bytes: + return None return await call(*args, **kwargs) return _middleware diff --git a/cashews/key.py b/cashews/key.py index 7ef9048f..23c354c0 100644 --- a/cashews/key.py +++ b/cashews/key.py @@ -106,10 +106,13 @@ def generate_key_template(func: Callable, exclude_parameters: Container = ()) -> class _Star: def __getattr__(self, item): - return _Star() + return self def __getitem__(self, item): - return _Star() + return self + + def __call__(self, *args, **kwargs): + return "*" def _check_key_params(key: KeyOrTemplate, func_params: Iterable[str]): @@ -142,7 +145,7 @@ def _get_func_signature(func: Callable): def _get_call_values(func: Callable, args: Args, kwargs: Kwargs): - if len(args) == 0: + if not args: _kwargs = {**kwargs} for name, parameter in _get_func_signature(func).parameters.items(): if parameter.kind != inspect.Parameter.VAR_KEYWORD and name in _kwargs: diff --git a/cashews/picklers.py b/cashews/picklers.py index 66e1815e..b8b74394 100644 --- a/cashews/picklers.py +++ b/cashews/picklers.py @@ -1,5 +1,6 @@ import json import pickle +from enum import Enum from ._typing import Value from .exceptions import UnsupportedPicklerError @@ -76,29 +77,34 @@ def dumps(cls, value) -> bytes: return json.dumps(value, default=cls.json_serial).encode() -DEFAULT_PICKLE = "default" -NULL_PICKLE = "null" +class PicklerType(Enum): + DEFAULT = "default" + NULL = "null" + JSON = "json" + DILL = "dill" + SQLALCHEMY = "sqlalchemy" + _picklers = { - DEFAULT_PICKLE: Pickler, - "sqlalchemy": SQLAlchemyPickler, - "dill": DillPickler, - NULL_PICKLE: NonPickler, - "json": JsonPickler, + PicklerType.DEFAULT: Pickler, + PicklerType.SQLALCHEMY: SQLAlchemyPickler, + PicklerType.DILL: DillPickler, + PicklerType.NULL: NonPickler, + PicklerType.JSON: JsonPickler, } -def get_pickler(name: str): - if name not in _picklers: +def get_pickler(pickler_type: PicklerType): + if pickler_type not in _picklers: raise UnsupportedPicklerError() - if name == "sqlalchemy" and not _SQLALC_PICKLE: + if pickler_type == PicklerType.SQLALCHEMY and not _SQLALC_PICKLE: raise UnsupportedPicklerError() - if name == "dill" and not _DILL_PICKLE: + if pickler_type == PicklerType.DILL and not _DILL_PICKLE: raise UnsupportedPicklerError() - return _picklers[name] + return _picklers[pickler_type] -DEFAULT_PICKLER = get_pickler(DEFAULT_PICKLE) +DEFAULT_PICKLER = get_pickler(PicklerType.DEFAULT) diff --git a/cashews/serialize.py b/cashews/serialize.py index ba2a02fe..f04ff5fe 100644 --- a/cashews/serialize.py +++ b/cashews/serialize.py @@ -2,10 +2,10 @@ import hashlib import hmac -from typing import TYPE_CHECKING, Any, Mapping +from typing import TYPE_CHECKING from .exceptions import SignIsMissingError, UnSecureDataError -from .picklers import DEFAULT_PICKLE, NULL_PICKLE, Pickler, get_pickler +from .picklers import Pickler, PicklerType, get_pickler if TYPE_CHECKING: # pragma: no cover from ._typing import ICustomDecoder, ICustomEncoder, Key, Value @@ -24,72 +24,6 @@ def simple_sign(key: bytes, value: bytes) -> bytes: return f"{s:x}".encode() -class SerializerMixin: - pickle_type = NULL_PICKLE - - def __init__( - self, - *args, - secret: str | bytes | None = None, - digestmod: str | bytes = b"md5", - check_repr: bool = True, - pickle_type: str | None = None, - **kwargs: Any, - ): - super().__init__(*args, **kwargs) - self._serializer = Serializer(check_repr=check_repr) - if secret: - self._serializer.set_signer(HashSigner(secret, digestmod)) - - self._serializer.set_pickler(self._get_pickler(pickle_type, bool(secret))) - - @classmethod - def _get_pickler(cls, pickle_type: str | None, hash_key: bool) -> Pickler: - pickle_type = pickle_type or cls.pickle_type - if pickle_type is NULL_PICKLE and hash_key: - pickle_type = DEFAULT_PICKLE - return get_pickler(pickle_type) - - async def get(self, key: Key, default: Value | None = None): - raw_value = await super().get(key, default=default) # type: ignore[misc] - return await self._serializer.decode(self, key, raw_value, default=default) # type: ignore[arg-type] - - async def get_many(self, *keys: Key, default: Value | None = None) -> Value: - encoded_values = await super().get_many(*keys, default=default) # type: ignore[misc] - values = [] - for key, value in zip(keys, encoded_values): - deserialized_value = await self._serializer.decode( - backend=self, # type: ignore[arg-type] - key=key, - value=value, - default=default, - ) - values.append(deserialized_value) - return tuple(values) - - async def set( - self, - key: Key, - value: Value, - expire: float | None = None, - exist: bool | None = None, - ): - value = await self._serializer.encode(self, key, value, expire=expire) # type: ignore[arg-type] - return await super().set(key, value, expire=expire, exist=exist) # type: ignore[misc] - - async def set_many(self, pairs: Mapping[Key, Value], expire: float | None = None): - transformed_pairs = {} - for key, value in pairs.items(): - transformed_pairs[key] = await self._serializer.encode(self, key, value, expire) # type: ignore[arg-type] - return await super().set_many(transformed_pairs, expire=expire) # type: ignore[misc] - - def set_raw(self, *args: Any, **kwargs: Any): - return super().set(*args, **kwargs) # type: ignore - - def get_raw(self, *args: Any, **kwargs: Any): - return super().get(*args, **kwargs) # type: ignore - - def _to_bytes(value: str | bytes) -> bytes: if isinstance(value, str): value = value.encode() @@ -152,7 +86,7 @@ class Serializer: def __init__(self, check_repr=False): self._check_repr = check_repr - self._pickler = get_pickler(NULL_PICKLE) + self._pickler = get_pickler(PicklerType.NULL) self._signer = NullSigner() def set_signer(self, signer): @@ -239,3 +173,26 @@ async def bytes_decoder(value: bytes, *args, **kwargs): register_type(bytes, bytes_encoder, bytes_decoder) + + +def get_serializer( + secret: str | bytes | None = None, + digestmod: str | bytes = b"md5", + check_repr: bool = True, + pickle_type: PicklerType | None = None, +) -> Serializer: + _serializer = Serializer(check_repr=check_repr) + if secret: + _serializer.set_signer(HashSigner(secret, digestmod)) + _serializer.set_pickler(_get_pickler(pickle_type or PicklerType.NULL, bool(secret))) + return _serializer + + +def _get_pickler(pickle_type: PicklerType, hash_key: bool) -> Pickler: + if pickle_type is PicklerType.NULL and hash_key: + pickle_type = PicklerType.DEFAULT + return get_pickler(pickle_type) + + +DEFAULT_SERIALIZER = get_serializer(pickle_type=PicklerType.DEFAULT) +NULL_SERIALIZER = get_serializer(pickle_type=PicklerType.NULL) diff --git a/cashews/wrapper/backend_settings.py b/cashews/wrapper/backend_settings.py index 3ff9525a..951c1ca0 100644 --- a/cashews/wrapper/backend_settings.py +++ b/cashews/wrapper/backend_settings.py @@ -6,6 +6,7 @@ from cashews.backends.interface import Backend from cashews.backends.memory import Memory from cashews.exceptions import BackendNotAvailableError +from cashews.picklers import PicklerType if TYPE_CHECKING: # pragma: no cover BackendOrFabric = Union[type[Backend], Callable[..., Backend]] @@ -16,11 +17,16 @@ "rediss": _NO_REDIS_ERROR, "disk": "Disk backend requires `diskcache` to be installed.", } -_BACKENDS: dict[str, tuple[BackendOrFabric, bool]] = {} +_BACKENDS: dict[str, tuple[BackendOrFabric, bool, PicklerType]] = {} -def register_backend(alias: str, backend_class: BackendOrFabric, pass_uri: bool = False) -> None: - _BACKENDS[alias] = (backend_class, pass_uri) +def register_backend( + alias: str, + backend_class: BackendOrFabric, + pass_uri: bool = False, + pickler: PicklerType = PicklerType.NULL, +) -> None: + _BACKENDS[alias] = (backend_class, pass_uri, pickler) register_backend("mem", Memory) @@ -39,8 +45,8 @@ def _redis_fabric(**params) -> Redis | BcastClientSide: return BcastClientSide(**params) return Redis(**params) - register_backend("redis", _redis_fabric, pass_uri=True) - register_backend("rediss", _redis_fabric, pass_uri=True) + register_backend("redis", _redis_fabric, pass_uri=True, pickler=PicklerType.DEFAULT) + register_backend("rediss", _redis_fabric, pass_uri=True, pickler=PicklerType.DEFAULT) try: @@ -50,25 +56,25 @@ def _redis_fabric(**params) -> Redis | BcastClientSide: else: from cashews.backends.diskcache import DiskCache - register_backend("disk", DiskCache) + register_backend("disk", DiskCache, pickler=PicklerType.DEFAULT) -def settings_url_parse(url: str) -> tuple[BackendOrFabric, dict[str, Any]]: +def settings_url_parse(url: str) -> tuple[BackendOrFabric, dict[str, Any], PicklerType]: parse_result = urlparse(url) params: dict[str, Any] = dict(parse_qsl(parse_result.query)) params = _serialize_params(params) alias = parse_result.scheme if alias == "": - return Memory, {"disable": True} + return Memory, {"disable": True}, PicklerType.NULL if alias not in _BACKENDS: error = _CUSTOM_ERRORS.get(alias, f"wrong backend alias {alias}") raise BackendNotAvailableError(error) - backend_class, pass_uri = _BACKENDS[alias] + backend_class, pass_uri, pickler = _BACKENDS[alias] if pass_uri: params["address"] = url.split("?")[0] - return backend_class, params + return backend_class, params, pickler def _serialize_params(params: dict[str, str]) -> dict[str, str | int | bool | float]: diff --git a/cashews/wrapper/commands.py b/cashews/wrapper/commands.py index ad23c791..c218debf 100644 --- a/cashews/wrapper/commands.py +++ b/cashews/wrapper/commands.py @@ -59,19 +59,23 @@ async def get_or_set( else: _default = default await self.set(key, _default, expire=expire) - return default + return _default async def get_raw(self, key: Key) -> Value: return await self._with_middlewares(Command.GET_RAW, key)(key=key) async def scan(self, pattern: str, batch_size: int = 100) -> AsyncIterator[Key]: - backend, middlewares = self._get_backend_and_config(pattern) + backend = self._get_backend(pattern) async def call(pattern, batch_size): - return backend.scan(pattern, batch_size=batch_size) + return backend.scan(pattern=pattern, batch_size=batch_size) - for middleware in middlewares: + for middleware in reversed(self._default_middlewares): + call = partial(middleware, call, Command.SCAN, backend) + + for middleware in self._middlewares[backend._id]: call = partial(middleware, call, Command.SCAN, backend) + async for key in await call(pattern=pattern, batch_size=batch_size): yield key @@ -80,10 +84,14 @@ async def get_match( pattern: str, batch_size: int = 100, ) -> AsyncIterator[tuple[Key, Value]]: - backend, middlewares = self._get_backend_and_config(pattern) + backend = self._get_backend(pattern) + middlewares = self._middlewares[backend._id] async def call(pattern, batch_size): - return backend.get_match(pattern, batch_size=batch_size) + return backend.get_match(pattern=pattern, batch_size=batch_size) + + for middleware in reversed(self._default_middlewares): + call = partial(middleware, call, Command.GET_MATCH, backend) for middleware in middlewares: call = partial(middleware, call, Command.GET_MATCH, backend) @@ -159,7 +167,7 @@ async def get_expire(self, key: Key) -> int: return await self._with_middlewares(Command.GET_EXPIRE, key)(key=key) async def exists(self, key: Key) -> bool: - return await self._with_middlewares(Command.EXIST, key)(key=key) + return await self._with_middlewares(Command.EXISTS, key)(key=key) async def set_lock(self, key: Key, value: Value, expire: TTL) -> bool: return await self._with_middlewares(Command.SET_LOCK, key)(key=key, value=value, expire=ttl_to_seconds(expire)) @@ -176,7 +184,7 @@ async def ping(self, message: bytes | None = None) -> bytes: async def get_keys_count(self) -> int: result = 0 - for backend, _ in self._backends.values(): + for backend in self._backends.values(): count = await self._with_middlewares_for_backend( Command.GET_KEYS_COUNT, backend, self._default_middlewares )() @@ -184,7 +192,7 @@ async def get_keys_count(self) -> int: return result async def clear(self) -> None: - for backend, _ in self._backends.values(): + for backend in self._backends.values(): await self._with_middlewares_for_backend(Command.CLEAR, backend, self._default_middlewares)() async def is_locked( diff --git a/cashews/wrapper/disable_control.py b/cashews/wrapper/disable_control.py index a5181f07..81b00c52 100644 --- a/cashews/wrapper/disable_control.py +++ b/cashews/wrapper/disable_control.py @@ -50,4 +50,4 @@ def is_enable(self, *cmds: Command, prefix: str = "") -> bool: @property def is_full_disable(self) -> bool: - return all(backend.is_full_disable for backend, _ in self._backends.values()) + return all(backend.is_full_disable for backend in self._backends.values()) diff --git a/cashews/wrapper/transaction.py b/cashews/wrapper/transaction.py index b615e647..db8812fe 100644 --- a/cashews/wrapper/transaction.py +++ b/cashews/wrapper/transaction.py @@ -1,6 +1,6 @@ from __future__ import annotations -from contextvars import ContextVar +from contextvars import ContextVar, Token from enum import Enum from functools import wraps from typing import TYPE_CHECKING @@ -11,7 +11,7 @@ from .wrapper import Wrapper if TYPE_CHECKING: # pragma: no cover - from cashews._typing import DecoratedFunc, Middleware + from cashews._typing import DecoratedFunc _transaction: ContextVar[Transaction | None] = ContextVar("transaction", default=None) @@ -32,12 +32,12 @@ def set_transaction_timeout(self, timeout: int) -> None: def set_transaction_mode(self, mode: TransactionMode) -> None: self.transaction_mode = mode - def _get_backend_and_config(self, key: str) -> tuple[Backend, tuple[Middleware, ...]]: - backend, config = super()._get_backend_and_config(key) + def _get_backend(self, key: str) -> Backend: + backend = super()._get_backend(key) tx: Transaction | None = _transaction.get() - if not tx: - return backend, config - return tx.wrap(backend), config + if tx: + return tx.wrap(backend) + return backend def transaction( self, mode: TransactionMode | None = None, timeout: float | None = None @@ -48,12 +48,13 @@ def transaction( class TransactionContextDecorator: - __slots__ = ["_mode", "_timeout", "_inner"] + __slots__ = ["_mode", "_timeout", "_inner", "_return_token"] def __init__(self, mode: TransactionMode | None = None, timeout: float | None = None): self._mode = mode self._timeout = timeout self._inner = False + self._return_token: Token | None = None @property def current_tx(self) -> Transaction | None: @@ -67,11 +68,11 @@ async def __aenter__(self) -> Transaction: def start(self) -> Transaction: tx = Transaction(self._mode, self._timeout) - _transaction.set(tx) + self._return_token = _transaction.set(tx) return tx def close(self): - _transaction.set(None) + _transaction.reset(self._return_token) async def __aexit__(self, exc_type, exc_value, exc_tb) -> None: if not self.current_tx or self._inner: @@ -106,12 +107,12 @@ class Transaction: def __init__(self, mode: TransactionMode | None = None, timeout: float | None = None): self._mode = mode self._timeout = timeout - self._backends: dict[int, TransactionBackend] = {} + self._backends: dict[str, TransactionBackend] = {} def wrap(self, backend: Backend) -> Backend: - if id(backend) not in self._backends: - self._backends[id(backend)] = self._get_tx_backend(backend) - return self._backends[id(backend)] + if backend._id not in self._backends: + self._backends[backend._id] = self._get_tx_backend(backend) + return self._backends[backend._id] def _get_tx_backend(self, backend: Backend) -> TransactionBackend: if self._mode == TransactionMode.FAST: diff --git a/cashews/wrapper/wrapper.py b/cashews/wrapper/wrapper.py index c39a27ba..02d17b30 100644 --- a/cashews/wrapper/wrapper.py +++ b/cashews/wrapper/wrapper.py @@ -7,6 +7,8 @@ from cashews.backends.interface import Backend from cashews.commands import Command from cashews.exceptions import NotConfiguredError +from cashews.picklers import PicklerType +from cashews.serialize import get_serializer from .auto_init import create_auto_init from .backend_settings import settings_url_parse @@ -19,7 +21,8 @@ class Wrapper: default_prefix = "" def __init__(self, name: str = ""): - self._backends: dict[str, tuple[Backend, tuple[Middleware, ...]]] = {} + self._backends: dict[str, Backend] = {} + self._middlewares: dict[str, tuple[Middleware, ...]] = {} self._sorted_prefixes: tuple[str, ...] = () self._default_middlewares: list[Middleware] = [ create_auto_init(), @@ -31,19 +34,16 @@ def __init__(self, name: str = ""): def add_middleware(self, middleware: Middleware) -> None: self._default_middlewares.append(middleware) - def _get_backend_and_config(self, key: Key) -> tuple[Backend, tuple[Middleware, ...]]: + def _get_backend(self, key: Key) -> Backend: for prefix in self._sorted_prefixes: if key.startswith(prefix): return self._backends[prefix] self._check_setup() raise NotConfiguredError("Backend for given key not configured") - def _get_backend(self, key: Key) -> Backend: - backend, _ = self._get_backend_and_config(key) - return backend - def _with_middlewares(self, cmd: Command, key: Key): - backend, middlewares = self._get_backend_and_config(key) + backend = self._get_backend(key) + middlewares = [*self._default_middlewares, *self._middlewares[backend._id]] return self._with_middlewares_for_backend(cmd, backend, middlewares) def _with_middlewares_for_backend(self, cmd: Command, backend, middlewares): @@ -59,12 +59,18 @@ def setup( prefix: str = default_prefix, **kwargs, ) -> Backend: - backend_class, params = settings_url_parse(settings_url) + backend_class, params, pickle_type = settings_url_parse(settings_url) params.update(kwargs) disable = params.pop("disable") if "disable" in params else not params.pop("enable", True) - backend = backend_class(**params) + serializer = get_serializer( + secret=params.pop("secret", None), + digestmod=params.pop("digestmod", b"md5"), + check_repr=params.pop("check_repr", True), + pickle_type=PicklerType(params.pop("pickle_type", pickle_type)), + ) + backend = backend_class(**params, serializer=serializer) if disable: backend.disable() self._add_backend(backend, middlewares, prefix) @@ -78,22 +84,20 @@ def _check_setup(self) -> None: raise NotConfiguredError("run `cache.setup(...)` before using cache") def _add_backend(self, backend: Backend, middlewares=(), prefix: str = default_prefix) -> None: - self._backends[prefix] = ( - backend, - middlewares + tuple(self._default_middlewares), - ) + self._backends[prefix] = backend + self._middlewares[backend._id] = middlewares self._sorted_prefixes = tuple(sorted(self._backends.keys(), reverse=True)) async def init(self, *args, **kwargs) -> None: if args or kwargs: self.setup(*args, **kwargs) - for backend, _ in self._backends.values(): + for backend in self._backends.values(): await backend.init() @property def is_init(self) -> bool: - return all(backend.is_init for backend, _ in self._backends.values()) + return all(backend.is_init for backend in self._backends.values()) async def close(self) -> None: - for backend, _ in self._backends.values(): + for backend in self._backends.values(): await backend.close() diff --git a/examples/keys.py b/examples/keys.py index 4674a91e..7e7b587a 100644 --- a/examples/keys.py +++ b/examples/keys.py @@ -6,7 +6,7 @@ @cache(ttl="10m", prefix="auto") -async def auto_key(foo, **kwargs): +async def auto_key(foo, *ar, test="val", **all): return @@ -35,7 +35,10 @@ def _human(value, upper=False): for char in value: if res: res += "-" - res += INT_TO_STR_MAP.get(char) + if char in INT_TO_STR_MAP: + res += INT_TO_STR_MAP.get(char) + else: + res += char if upper: return res.upper() return res @@ -48,8 +51,10 @@ async def key_with_format(foo): async def main(): await _call(auto_key, "fooval", key="test") - await _call(manual_key, "fooval", key="test") - await _call(key_with_format, 521) + await _call(auto_key, "fooval", test="my", key="test") + await _call(auto_key, foo="fooval", test="my", key="test") + # await _call(manual_key, "fooval", key="test") + # await _call(key_with_format, 521) async def _call(function, *args, **kwargs): @@ -57,11 +62,12 @@ async def _call(function, *args, **kwargs): with cache.detect as detector: await function(*args, **kwargs) key = list(detector.calls.keys())[-1] - + template = detector.calls[key]["template"] print( f""" function "{function.__name__}" called with args={args} and kwargs={kwargs} the key: {key} + the template: {template} """ ) diff --git a/examples/simple.py b/examples/simple.py index a1faf497..ab60fc5b 100644 --- a/examples/simple.py +++ b/examples/simple.py @@ -6,7 +6,7 @@ cache.setup( "redis://0.0.0.0/2", client_name=None, - hash_key="test", + secret="test", digestmod="md5", middlewares=(add_prefix("test:"), all_keys_lower()), ) @@ -23,11 +23,12 @@ async def basic(): await cache.set_many({"key2": "test", "key3": Decimal("10.1")}, expire="1m") print("Get: ", await cache.get("key1")) # -> Any + print("Get: ", await cache.get("key2")) # -> Any async for key in cache.scan("key*"): print("Scan:", key) # -> Any - async for key, value in cache.get_match("key*"): + async for key, value in cache.get_match("*"): print("Get match:", key, value) # -> Any print("Get many:", await cache.get_many("key2", "key3")) # -> Any diff --git a/pytest.ini b/pytest.ini index 7459e6ba..06a3ac74 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,5 +1,5 @@ [pytest] -addopts = --disable-pytest-warnings -vv +addopts = --disable-pytest-warnings -vv --showlocals testpaths = tests markers = redis: mark test as requiring `redis` library diff --git a/tests/conftest.py b/tests/conftest.py index ce2963ee..29555abb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,6 @@ import random from typing import TYPE_CHECKING from unittest.mock import Mock -from uuid import uuid4 import pytest @@ -52,7 +51,6 @@ def factory(backend_cls: type[Backend], *args, **kwargs): "memory", "transactional", pytest.param("redis", marks=pytest.mark.redis), - pytest.param("redis_hash", marks=pytest.mark.redis), pytest.param("redis_cs", marks=pytest.mark.redis), pytest.param("diskcache", marks=pytest.mark.diskcache), ], @@ -68,31 +66,17 @@ async def _backend(request, redis_dsn, backend_factory): backend = backend_factory( Redis, redis_dsn, - secret=None, max_connections=20, suppress=False, socket_timeout=1, wait_for_connection_timeout=1, ) - elif request.param == "redis_hash": - from cashews.backends.redis import Redis - - backend = backend_factory( - Redis, - redis_dsn, - secret=uuid4().hex, - max_connections=20, - suppress=False, - socket_timeout=10, - wait_for_connection_timeout=1, - ) elif request.param == "redis_cs": from cashews.backends.redis.client_side import BcastClientSide backend = backend_factory( BcastClientSide, redis_dsn, - secret=None, max_connections=5, suppress=False, socket_timeout=0.1, diff --git a/tests/test_add_prefix.py b/tests/test_add_prefix.py index 48d54569..8873dbe2 100644 --- a/tests/test_add_prefix.py +++ b/tests/test_add_prefix.py @@ -1,3 +1,5 @@ +from unittest.mock import ANY + import pytest from cashews import Cache @@ -6,7 +8,7 @@ @pytest.fixture(autouse=True) def _add_prefix(cache: Cache, target): - cache._add_backend(target, (add_prefix("prefix!"),)) + cache.add_middleware(add_prefix("prefix!")) async def test_add_prefix_get(cache: Cache, target): @@ -18,7 +20,7 @@ async def test_add_prefix_set(cache: Cache, target): await cache.set(key="key", value="value") target.set.assert_called_once_with( key="prefix!key", - value="value", + value=ANY, exist=None, expire=None, ) @@ -36,7 +38,7 @@ async def test_add_prefix_get_many(cache: Cache, target): async def test_add_prefix_set_many(cache: Cache, target): await cache.set_many({"key": "value"}) - target.set_many.assert_called_once_with(pairs={"prefix!key": "value"}, expire=None) + target.set_many.assert_called_once_with(pairs={"prefix!key": ANY}, expire=None) async def test_add_prefix_delete(cache: Cache, target): diff --git a/tests/test_cache.py b/tests/test_cache.py index 0bcbafc7..3ce55744 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -177,7 +177,7 @@ async def func(resp=b"ok"): async def test_early_cache_no_background(cache: Cache): mock = Mock() - @cache.early(ttl=EXPIRE, key="key", background=False) + @cache.early(ttl=EXPIRE, early_ttl=0.01, key="key", background=False) async def func(resp=b"ok"): mock() return resp diff --git a/tests/test_client_side_cache.py b/tests/test_client_side_cache.py index 94bf8883..0c5b9bbc 100644 --- a/tests/test_client_side_cache.py +++ b/tests/test_client_side_cache.py @@ -17,7 +17,7 @@ def _create_cache(redis_dsn, backend_factory): from cashews.backends.redis.client_side import BcastClientSide async def call(local_cache=None): - backend = backend_factory(BcastClientSide, redis_dsn, secret=None, local_cache=local_cache) + backend = backend_factory(BcastClientSide, redis_dsn, local_cache=local_cache) await backend.init() await backend.clear() return backend @@ -58,12 +58,12 @@ async def test_set_none_bcast(create_cache): assert await caches_local.exists("key") assert await caches.get("key") is None - await cachef.set("key", None, expire=10000) + await cachef.set("key", "val", expire=10000) await asyncio.sleep(0.01) # skip init signal about invalidation assert await cachef.exists("key") assert await cachef_local.exists("key") - assert await caches.get("key") is None + assert await caches.get("key") == "val" assert await caches.exists("key") assert await caches_local.exists("key") diff --git a/tests/test_intergations/test_fastapi.py b/tests/test_intergations/test_fastapi.py index 6b55bd8c..8116747f 100644 --- a/tests/test_intergations/test_fastapi.py +++ b/tests/test_intergations/test_fastapi.py @@ -1,5 +1,6 @@ from contextlib import asynccontextmanager, contextmanager from random import random +from unittest.mock import Mock import pytest @@ -95,7 +96,10 @@ def test_cache(client): def test_cache_stream(client, app, cache): from starlette.responses import StreamingResponse + call = Mock() + def iterator(): + call() for i in range(10): yield f"{i}" @@ -110,6 +114,7 @@ async def stream(): assert response.content == b"0123456789" response = client.get("/stream") + call.assert_called_once() assert response.status_code == 201 assert response.headers["X-Test"] == "TRUE" assert response.content == b"0123456789" diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 3471b8d7..ee3f6598 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -1,39 +1,40 @@ +from unittest import mock + +import pytest + from cashews import Cache -from cashews.helpers import all_keys_lower, memory_limit +from cashews.helpers import memory_limit -async def test_all_keys_lower(cache: Cache, target): - cache._add_backend(target, (all_keys_lower(),)) - await cache.get(key="KEY") - target.get.assert_called_once_with(key="key", default=None) +@pytest.mark.parametrize( + ("min_bytes", "max_bytes", "size", "called"), + ( + (10, 80, 11, True), + (10, 80, 10, True), + (10, 80, 80, False), + (1, 11, 10, False), + (10, 1, 80, False), + ), +) +async def test_memory_limit_set(cache: Cache, target, min_bytes, max_bytes, size, called): + cache.add_middleware(memory_limit(min_bytes=min_bytes, max_bytes=max_bytes)) - await cache.set(key="KEY", value="value") - target.set.assert_called_once_with( - key="key", - value="value", - exist=None, - expire=None, - ) - await cache.set_many({"KEY": "value"}) - target.set_many.assert_called_once_with( - pairs={"key": "value"}, - expire=None, - ) - await cache.ping() - target.ping.assert_called_once_with(message=b"PING") + await cache.set(key="key", value="v" * size) + if called: + target.set.assert_called_once_with(key="key", value=mock.ANY, expire=None, exist=None) + else: + target.set.assert_not_called() -async def test_memory_limit(cache: Cache, target): - cache._add_backend(target, (memory_limit(min_bytes=52, max_bytes=75),)) - await cache.set(key="key", value="v") - target.set.assert_not_called() +async def test_memory_limit_set_many(cache: Cache, target): + cache.add_middleware(memory_limit(min_bytes=52, max_bytes=75)) - await cache.set(key="key", value="v" * 35) - target.set.assert_not_called() + await cache.set_many({"key": "v" * 35}) + target.set_many.assert_not_called() - await cache.set(key="key", value="v" * 15) - target.set.assert_called_once() + await cache.set_many({"key": "v" * 35, "key2": "v"}) + target.set_many.assert_not_called() - await cache.ping() - target.ping.assert_called_once_with(message=b"PING") + await cache.set_many({"key": "v" * 35, "key2": "v", "key3": "v" * 15}) + target.set_many.assert_called_once_with(pairs={"key3": mock.ANY}, expire=None) diff --git a/tests/test_pickle_serializer.py b/tests/test_pickle_serializer.py index 75e34f7f..6a2be8a0 100644 --- a/tests/test_pickle_serializer.py +++ b/tests/test_pickle_serializer.py @@ -8,7 +8,8 @@ from hypothesis import strategies as st from cashews.backends.memory import Memory -from cashews.serialize import UnSecureDataError +from cashews.picklers import PicklerType +from cashews.serialize import UnSecureDataError, get_serializer @dataclasses.dataclass() @@ -37,13 +38,13 @@ async def _cache(request, redis_dsn): if pickle_type == "redis": from cashews.backends.redis import Redis - redis = Redis(redis_dsn, secret="test", suppress=False, digestmod=digestmod) + redis = Redis(redis_dsn, suppress=False, serializer=get_serializer(secret=b"test", digestmod=digestmod)) await redis.init() await redis.clear() yield redis await redis.close() else: - yield Memory(secret=b"test", digestmod=digestmod, pickle_type=pickle_type) + yield Memory(serializer=get_serializer(secret=b"test", digestmod=digestmod)) @pytest.mark.parametrize( @@ -217,10 +218,7 @@ async def test_replace_values(cache): async def test_pickle_error_value(cache): - await cache.set_raw( - "key", - cache._serializer._signer.sign("key", b"no_pickle_data"), - ) + await cache.set_raw("key", b"nopickledata") assert await cache.get("key", default="default") == "default" @@ -264,29 +262,29 @@ async def test_get_set_raw(cache): @settings(max_examples=500) @example(key="_key:_!@#$%^&*()", value='_value:_!@_#$%^&:*(?)".,4ะน') async def test_no_hash(key, value): - cache = Memory() + cache = Memory(serializer=get_serializer()) await cache.set(key, value) assert await cache.get(key) == value async def test_cache_from_hash_to_no_hash(): val = Decimal("10.2") - cache = Memory(secret="test") + cache = Memory(serializer=get_serializer(secret="test")) await cache.set("key", val) assert await cache.get("key") == val - cache_no_hash = Memory() + cache_no_hash = Memory(serializer=get_serializer()) cache_no_hash.store = cache.store assert await cache_no_hash.get("key", default="default") == "default" async def test_cache_from_no_hash_to_hash(): val = Decimal("10.2") - cache = Memory() + cache = Memory(serializer=get_serializer()) await cache.set("key", val) assert await cache.get("key") == val - cache_hash = Memory(secret="test") + cache_hash = Memory(serializer=get_serializer(secret="test")) cache_hash.store = cache.store assert await cache_hash.get("key") == val @@ -303,6 +301,6 @@ async def test_cache_from_no_hash_to_hash(): ), ) async def test_json_serialize(value): - cache = Memory(secret=b"test", pickle_type="json") + cache = Memory(serializer=get_serializer(secret=b"test", pickle_type=PicklerType.JSON)) await cache.set("key", value) assert await cache.get("key") == value diff --git a/tests/test_redis_down.py b/tests/test_redis_down.py index feb9c7fc..38902ddb 100644 --- a/tests/test_redis_down.py +++ b/tests/test_redis_down.py @@ -20,7 +20,7 @@ async def test_safe_redis(redis_backend): await redis.init() assert await redis.set("test", "test") is False - assert await redis.set_raw("test", "test") is False + assert await redis.set_raw("test", "test") is None assert await redis.set_lock("test", "test", 1) is False assert await redis.unlock("test", "test") is None @@ -37,7 +37,7 @@ async def test_safe_redis(redis_backend): assert await redis.get_expire("test") == 0 assert await redis.incr("test") is None assert await redis.slice_incr("test", 1, 2, 10) is None - assert await redis.get_size("test") == 0 + async for _ in redis.scan("*"): raise AssertionError() diff --git a/tests/test_settings_url.py b/tests/test_settings_url.py index 5f5cd865..5ac1cad5 100644 --- a/tests/test_settings_url.py +++ b/tests/test_settings_url.py @@ -16,7 +16,7 @@ ), ) def test_url(url, params): - backend_class, _params = settings_url_parse(url) + backend_class, _params, _ = settings_url_parse(url) assert backend_class is Memory assert params == _params @@ -85,7 +85,7 @@ def test_url_but_backend_dependency_is_not_installed(url, error): def test_url_with_redis_as_backend(url, params): from cashews.backends.redis import Redis - backend_class, _params = settings_url_parse(url) + backend_class, _params, _ = settings_url_parse(url) assert isinstance(backend_class(**params), Redis) assert params == _params @@ -105,6 +105,6 @@ def test_url_with_redis_as_backend(url, params): def test_url_with_diskcache_as_backend(url, params): from cashews.backends.diskcache import DiskCache - backend_class, _params = settings_url_parse(url) + backend_class, _params, _ = settings_url_parse(url) assert backend_class is DiskCache assert params == _params diff --git a/tests/test_tags_feature.py b/tests/test_tags_feature.py index a1404b93..2a2b768f 100644 --- a/tests/test_tags_feature.py +++ b/tests/test_tags_feature.py @@ -218,14 +218,13 @@ async def func(a): async def test_delete_tags_separate_backend(cache: Cache, redis_dsn: str): tag_backend = cache.setup_tags_backend(redis_dsn) tag_backend.set_pop = AsyncMock(side_effect=[["key", "key2"], []]) - tag_backend.init = AsyncMock(wraps=tag_backend.init) cache.register_tag(tag="tag", key_template="key") + await cache.init() await cache.delete_tags("tag") tag_backend.set_pop.assert_awaited_with(key="_tag:tag", count=100) - tag_backend.init.assert_awaited_once() await tag_backend.close() diff --git a/tests/test_wrapper.py b/tests/test_wrapper.py index 4a486e80..c29092bf 100644 --- a/tests/test_wrapper.py +++ b/tests/test_wrapper.py @@ -1,5 +1,5 @@ import asyncio -from unittest.mock import Mock, PropertyMock +from unittest.mock import ANY, Mock, PropertyMock import pytest @@ -51,7 +51,9 @@ async def _set(): type(target).is_init = PropertyMock(side_effect=lambda: init) target.init.side_effect = set_init - cache._backends[""] = (target, (create_auto_init(),)) + cache._backends[""] = target + cache._middlewares[target._id] = [create_auto_init()] + await asyncio.gather(cache.ping(), cache.ping(), cache.get("test")) target.init.assert_called_once() @@ -60,13 +62,13 @@ async def test_smoke_cmds(cache: Cache, target: Mock): await cache.set(key="key", value={"any": True}, expire=60, exist=None) target.set.assert_called_once_with( key="key", - value={"any": True}, + value=ANY, expire=60, exist=None, ) - await cache.set_raw(key="key2", value="value", expire=60) - target.set_raw.assert_called_once_with(key="key2", value="value", expire=60) + await cache.set_raw(key="key2", value="value") + target.set_raw.assert_called_once_with(key="key2", value="value") await cache.get("key") # -> Any target.get.assert_called_once_with(key="key", default=None) @@ -75,7 +77,7 @@ async def test_smoke_cmds(cache: Cache, target: Mock): target.get_raw.assert_called_once_with(key="key") await cache.set_many({"key1": "value1", "key2": "value2"}, expire=60) - target.set_many.assert_called_once_with(pairs={"key1": "value1", "key2": "value2"}, expire=60) + target.set_many.assert_called_once_with(pairs={"key1": ANY, "key2": ANY}, expire=60) await cache.get_many("key1", "key2") target.get_many.assert_called_once_with("key1", "key2", default=None) @@ -124,10 +126,10 @@ async def test_smoke_cmds(cache: Cache, target: Mock): await cache.set("key", "value") assert [key async for key in cache.scan("key*")] == ["key"] - target.scan.assert_called_once_with("key*", batch_size=100) + target.scan.assert_called_once_with(pattern="key*", batch_size=100) assert [key_value async for key_value in cache.get_match("key*")] == [("key", "value")] - target.get_match.assert_called_once_with("key*", batch_size=100) + target.get_match.assert_called_once_with(pattern="key*", batch_size=100) await cache.get_size("key") target.get_size.assert_called_once_with(key="key")