diff --git a/telebot/asyncio_storage/memory_storage.py b/telebot/asyncio_storage/memory_storage.py index e65ed74d4..661cc35e9 100644 --- a/telebot/asyncio_storage/memory_storage.py +++ b/telebot/asyncio_storage/memory_storage.py @@ -71,7 +71,7 @@ async def set_data( ) if self.data.get(_key) is None: - return False + raise RuntimeError(f"MemoryStorage: key {_key} does not exist.") self.data[_key]["data"][key] = value return True @@ -85,7 +85,7 @@ async def get_data( chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id ) - return self.data.get(_key, {}).get("data", None) + return self.data.get(_key, {}).get("data", {}) async def reset_data( self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, diff --git a/telebot/asyncio_storage/pickle_storage.py b/telebot/asyncio_storage/pickle_storage.py index 0c7da7eb1..9a8c9eead 100644 --- a/telebot/asyncio_storage/pickle_storage.py +++ b/telebot/asyncio_storage/pickle_storage.py @@ -98,7 +98,7 @@ async def set_data(self, chat_id: int, user_id: int, key: str, value: Union[str, state_data = data.get(_key, {}) state_data["data"][key] = value if _key not in data: - data[_key] = {"state": None, "data": state_data} + raise RuntimeError(f"StatePickleStorage: key {_key} does not exist.") else: data[_key]["data"][key] = value await self._write_to_file(data) diff --git a/telebot/asyncio_storage/redis_storage.py b/telebot/asyncio_storage/redis_storage.py index a6b19780d..b07bd4159 100644 --- a/telebot/asyncio_storage/redis_storage.py +++ b/telebot/asyncio_storage/redis_storage.py @@ -86,7 +86,7 @@ async def set_data(self, pipe, chat_id: int, user_id: int, key: str, value: Unio data = await pipe.execute() data = data[0] if data is None: - await pipe.hset(_key, "data", json.dumps({key: value})) + raise RuntimeError(f"StateRedisStorage: key {_key} does not exist.") else: data = json.loads(data) data[key] = value diff --git a/telebot/storage/memory_storage.py b/telebot/storage/memory_storage.py index 82142430a..11acdc119 100644 --- a/telebot/storage/memory_storage.py +++ b/telebot/storage/memory_storage.py @@ -71,7 +71,7 @@ def set_data( ) if self.data.get(_key) is None: - return False + raise RuntimeError(f"StateMemoryStorage: key {_key} does not exist.") self.data[_key]["data"][key] = value return True @@ -85,7 +85,7 @@ def get_data( chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id ) - return self.data.get(_key, {}).get("data", None) + return self.data.get(_key, {}).get("data", {}) def reset_data( self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, diff --git a/telebot/storage/pickle_storage.py b/telebot/storage/pickle_storage.py index 32a9653c7..b449c17f4 100644 --- a/telebot/storage/pickle_storage.py +++ b/telebot/storage/pickle_storage.py @@ -1,12 +1,18 @@ import os import pickle import threading -from typing import Optional, Union +from typing import Optional, Union, Callable from telebot.storage.base_storage import StateStorageBase, StateDataContext +def with_lock(func: Callable) -> Callable: + def wrapper(self, *args, **kwargs): + with self.lock: + return func(self, *args, **kwargs) + return wrapper + class StatePickleStorage(StateStorageBase): - def __init__(self, file_path: str="./.state-save/states.pkl", - prefix='telebot', separator: Optional[str]=":") -> None: + def __init__(self, file_path: str = "./.state-save/states.pkl", + prefix='telebot', separator: Optional[str] = ":") -> None: self.file_path = file_path self.prefix = prefix self.separator = separator @@ -32,98 +38,98 @@ def create_dir(self): with open(self.file_path,'wb') as file: pickle.dump({}, file) + @with_lock def set_state(self, chat_id: int, user_id: int, state: str, - business_connection_id: Optional[str]=None, message_thread_id: Optional[int]=None, - bot_id: Optional[int]=None) -> bool: + business_connection_id: Optional[str] = None, message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None) -> bool: _key = self._get_key( chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id ) - with self.lock: - data = self._read_from_file() - if _key not in data: - data[_key] = {"state": state, "data": {}} - else: - data[_key]["state"] = state - self._write_to_file(data) + data = self._read_from_file() + if _key not in data: + data[_key] = {"state": state, "data": {}} + else: + data[_key]["state"] = state + self._write_to_file(data) return True - def get_state(self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, - message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> Union[str, None]: + @with_lock + def get_state(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> Union[str, None]: _key = self._get_key( chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id ) - with self.lock: - data = self._read_from_file() - return data.get(_key, {}).get("state") + data = self._read_from_file() + return data.get(_key, {}).get("state") - def delete_state(self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, - message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> bool: + @with_lock + def delete_state(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> bool: _key = self._get_key( chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id ) - with self.lock: - data = self._read_from_file() - if _key in data: - del data[_key] - self._write_to_file(data) - return True - return False + data = self._read_from_file() + if _key in data: + del data[_key] + self._write_to_file(data) + return True + return False + @with_lock def set_data(self, chat_id: int, user_id: int, key: str, value: Union[str, int, float, dict], - business_connection_id: Optional[str]=None, message_thread_id: Optional[int]=None, - bot_id: Optional[int]=None) -> bool: + business_connection_id: Optional[str] = None, message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None) -> bool: _key = self._get_key( chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id ) - with self.lock: - data = self._read_from_file() - state_data = data.get(_key, {}) - state_data["data"][key] = value - if _key not in data: - data[_key] = {"state": None, "data": state_data} - else: - data[_key]["data"][key] = value - self._write_to_file(data) + data = self._read_from_file() + state_data = data.get(_key, {}) + state_data["data"][key] = value + + if _key not in data: + raise RuntimeError(f"PickleStorage: key {_key} does not exist.") + + self._write_to_file(data) return True - def get_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, - message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> dict: + @with_lock + def get_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> dict: _key = self._get_key( chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id ) - with self.lock: - data = self._read_from_file() - return data.get(_key, {}).get("data", {}) + data = self._read_from_file() + return data.get(_key, {}).get("data", {}) - def reset_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, - message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> bool: + @with_lock + def reset_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> bool: _key = self._get_key( chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id ) - with self.lock: - data = self._read_from_file() - if _key in data: - data[_key]["data"] = {} - self._write_to_file(data) - return True - return False + data = self._read_from_file() + if _key in data: + data[_key]["data"] = {} + self._write_to_file(data) + return True + return False - def get_interactive_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, - message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> Optional[dict]: + def get_interactive_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> Optional[dict]: return StateDataContext( self, chat_id=chat_id, user_id=user_id, business_connection_id=business_connection_id, message_thread_id=message_thread_id, bot_id=bot_id ) - def save(self, chat_id: int, user_id: int, data: dict, business_connection_id: Optional[str]=None, - message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> bool: + @with_lock + def save(self, chat_id: int, user_id: int, data: dict, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> bool: _key = self._get_key( chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id ) - with self.lock: - data = self._read_from_file() - data[_key]["data"] = data - self._write_to_file(data) + data = self._read_from_file() + data[_key]["data"] = data + self._write_to_file(data) return True def __str__(self) -> str: diff --git a/telebot/storage/redis_storage.py b/telebot/storage/redis_storage.py index d41da05c8..f21d50fe1 100644 --- a/telebot/storage/redis_storage.py +++ b/telebot/storage/redis_storage.py @@ -85,7 +85,7 @@ def set_data_action(pipe): data = pipe.hget(_key, "data") data = data.execute()[0] if data is None: - pipe.hset(_key, "data", json.dumps({key: value})) + raise RuntimeError(f"RedisStorage: key {_key} does not exist.") else: data = json.loads(data) data[key] = value