Skip to content

Commit

Permalink
Added statesv2
Browse files Browse the repository at this point in the history
  • Loading branch information
coder2020official committed Jun 23, 2024
1 parent ab2dca8 commit fce1c3d
Show file tree
Hide file tree
Showing 4 changed files with 263 additions and 74 deletions.
113 changes: 99 additions & 14 deletions telebot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import threading
import time
import traceback
from typing import Any, Callable, List, Optional, Union
from typing import Any, Callable, List, Optional, Union, Dict

# these imports are used to avoid circular import error
import telebot.util
Expand Down Expand Up @@ -168,7 +168,8 @@ def __init__(
disable_notification: Optional[bool]=None,
protect_content: Optional[bool]=None,
allow_sending_without_reply: Optional[bool]=None,
colorful_logs: Optional[bool]=False
colorful_logs: Optional[bool]=False,
token_check: Optional[bool]=True
):

# update-related
Expand All @@ -186,6 +187,11 @@ def __init__(
self.webhook_listener = None
self._user = None

# token check
if token_check:
self._user = self.get_me()
self.bot_id = self._user.id

# logs-related
if colorful_logs:
try:
Expand Down Expand Up @@ -280,6 +286,8 @@ def __init__(
self.threaded = threaded
if self.threaded:
self.worker_pool = util.ThreadPool(self, num_threads=num_threads)



@property
def user(self) -> types.User:
Expand Down Expand Up @@ -6572,7 +6580,9 @@ def setup_middleware(self, middleware: BaseMiddleware):
self.middlewares.append(middleware)


def set_state(self, user_id: int, state: Union[int, str, State], chat_id: Optional[int]=None) -> None:
def set_state(self, user_id: int, state: Union[int, str, State], chat_id: Optional[int]=None,
business_connection_id: Optional[str]=None, message_thread_id: Optional[int]=None,
bot_id: Optional[int]=None) -> None:
"""
Sets a new state of a user.
Expand All @@ -6591,14 +6601,29 @@ def set_state(self, user_id: int, state: Union[int, str, State], chat_id: Option
:param chat_id: Chat's identifier
:type chat_id: :obj:`int`
:param bot_id: Bot's identifier
:type bot_id: :obj:`int`
:param business_connection_id: Business identifier
:type business_connection_id: :obj:`str`
:param message_thread_id: Identifier of the message thread
:type message_thread_id: :obj:`int`
:return: None
"""
if chat_id is None:
chat_id = user_id
self.current_states.set_state(chat_id, user_id, state)
if bot_id is None:
bot_id = self.bot_id
self.current_states.set_state(
chat_id=chat_id, user_id=user_id, state=state, bot_id=bot_id,
business_connection_id=business_connection_id, message_thread_id=message_thread_id)


def reset_data(self, user_id: int, chat_id: Optional[int]=None):
def reset_data(self, user_id: int, chat_id: Optional[int]=None,
business_connection_id: Optional[str]=None,
message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> None:
"""
Reset data for a user in chat.
Expand All @@ -6608,14 +6633,27 @@ def reset_data(self, user_id: int, chat_id: Optional[int]=None):
:param chat_id: Chat's identifier
:type chat_id: :obj:`int`
:param bot_id: Bot's identifier
:type bot_id: :obj:`int`
:param business_connection_id: Business identifier
:type business_connection_id: :obj:`str`
:param message_thread_id: Identifier of the message thread
:type message_thread_id: :obj:`int`
:return: None
"""
if chat_id is None:
chat_id = user_id
self.current_states.reset_data(chat_id, user_id)
if bot_id is None:
bot_id = self.bot_id
self.current_states.reset_data(chat_id=chat_id, user_id=user_id, bot_id=bot_id,
business_connection_id=business_connection_id, message_thread_id=message_thread_id)


def delete_state(self, user_id: int, chat_id: Optional[int]=None) -> None:
def delete_state(self, user_id: int, chat_id: Optional[int]=None, business_connection_id: Optional[str]=None,
message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> None:
"""
Delete the current state of a user.
Expand All @@ -6629,10 +6667,14 @@ def delete_state(self, user_id: int, chat_id: Optional[int]=None) -> None:
"""
if chat_id is None:
chat_id = user_id
self.current_states.delete_state(chat_id, user_id)
if bot_id is None:
bot_id = self.bot_id
self.current_states.delete_state(chat_id=chat_id, user_id=user_id, bot_id=bot_id,
business_connection_id=business_connection_id, message_thread_id=message_thread_id)


def retrieve_data(self, user_id: int, chat_id: Optional[int]=None) -> Optional[Any]:
def retrieve_data(self, user_id: int, chat_id: Optional[int]=None, business_connection_id: Optional[str]=None,
message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> Optional[Dict[str, Any]]:
"""
Returns context manager with data for a user in chat.
Expand All @@ -6642,15 +6684,30 @@ def retrieve_data(self, user_id: int, chat_id: Optional[int]=None) -> Optional[A
:param chat_id: Chat's unique identifier, defaults to user_id
:type chat_id: int, optional
:param bot_id: Bot's identifier
:type bot_id: int, optional
:param business_connection_id: Business identifier
:type business_connection_id: str, optional
:param message_thread_id: Identifier of the message thread
:type message_thread_id: int, optional
:return: Context manager with data for a user in chat
:rtype: Optional[Any]
"""
if chat_id is None:
chat_id = user_id
return self.current_states.get_interactive_data(chat_id, user_id)
if bot_id is None:
bot_id = self.bot_id
return self.current_states.get_interactive_data(chat_id=chat_id, user_id=user_id, bot_id=bot_id,
business_connection_id=business_connection_id,
message_thread_id=message_thread_id)


def get_state(self, user_id: int, chat_id: Optional[int]=None) -> Optional[Union[int, str, State]]:
def get_state(self, user_id: int, chat_id: Optional[int]=None,
business_connection_id: Optional[str]=None,
message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> Union[int, str]:
"""
Gets current state of a user.
Not recommended to use this method. But it is ok for debugging.
Expand All @@ -6661,15 +6718,31 @@ def get_state(self, user_id: int, chat_id: Optional[int]=None) -> Optional[Union
:param chat_id: Chat's identifier
:type chat_id: :obj:`int`
:param bot_id: Bot's identifier
:type bot_id: :obj:`int`
:param business_connection_id: Business identifier
:type business_connection_id: :obj:`str`
:param message_thread_id: Identifier of the message thread
:type message_thread_id: :obj:`int`
:return: state of a user
:rtype: :obj:`int` or :obj:`str` or :class:`telebot.types.State`
"""
if chat_id is None:
chat_id = user_id
return self.current_states.get_state(chat_id, user_id)
if bot_id is None:
bot_id = self.bot_id
return self.current_states.get_state(chat_id=chat_id, user_id=user_id, bot_id=bot_id,
business_connection_id=business_connection_id, message_thread_id=message_thread_id)


def add_data(self, user_id: int, chat_id: Optional[int]=None, **kwargs):
def add_data(self, user_id: int, chat_id: Optional[int]=None,
business_connection_id: Optional[str]=None,
message_thread_id: Optional[int]=None,
bot_id: Optional[int]=None,
**kwargs) -> None:
"""
Add data to states.
Expand All @@ -6679,13 +6752,25 @@ def add_data(self, user_id: int, chat_id: Optional[int]=None, **kwargs):
:param chat_id: Chat's identifier
:type chat_id: :obj:`int`
:param bot_id: Bot's identifier
:type bot_id: :obj:`int`
:param business_connection_id: Business identifier
:type business_connection_id: :obj:`str`
:param message_thread_id: Identifier of the message thread
:type message_thread_id: :obj:`int`
:param kwargs: Data to add
:return: None
"""
if chat_id is None:
chat_id = user_id
if bot_id is None:
bot_id = self.bot_id
for key, value in kwargs.items():
self.current_states.set_data(chat_id, user_id, key, value)
self.current_states.set_data(chat_id=chat_id, user_id=user_id, key=key, value=value, bot_id=bot_id,
business_connection_id=business_connection_id, message_thread_id=message_thread_id)


def register_next_step_handler_by_chat_id(
Expand Down
16 changes: 12 additions & 4 deletions telebot/custom_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@




class SimpleCustomFilter(ABC):
"""
Simple Custom Filter base class.
Expand Down Expand Up @@ -417,8 +418,6 @@ def check(self, message, text):
user_id = message.from_user.id
message = message.message




if isinstance(text, list):
new_text = []
Expand All @@ -430,15 +429,24 @@ def check(self, message, text):
text = text.name

if message.chat.type in ['group', 'supergroup']:
group_state = self.bot.current_states.get_state(chat_id, user_id)
group_state = self.bot.current_states.get_state(chat_id=chat_id, user_id=user_id, business_connection_id=message.business_connection_id, bot_id=self.bot._user.id,
message_thread_id=message.message_thread_id)
if group_state is None and not message.is_topic_message: # needed for general topic and group messages
group_state = self.bot.current_states.get_state(chat_id=chat_id, user_id=user_id, business_connection_id=message.business_connection_id, bot_id=self.bot._user.id)

if group_state == text:
return True
elif type(text) is list and group_state in text:
return True


else:
user_state = self.bot.current_states.get_state(chat_id, user_id)
user_state = self.bot.current_states.get_state(
chat_id=chat_id,
user_id=user_id,
business_connection_id=message.business_connection_id,
bot_id=self.bot._user.id
)
if user_state == text:
return True
elif type(text) is list and user_state in text:
Expand Down
40 changes: 37 additions & 3 deletions telebot/storage/base_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,22 +47,56 @@ def get_interactive_data(self, chat_id, user_id):

def save(self, chat_id, user_id, data):
raise NotImplementedError

def convert_params_to_key(
self,
chat_id: int,
user_id: int,
prefix: str,
separator: str,
business_connection_id: str=None,
message_thread_id: int=None,
bot_id: int=None
) -> str:
"""
Convert parameters to a key.
"""
params = [prefix]
if bot_id:
params.append(str(bot_id))
if business_connection_id:
params.append(business_connection_id)
if message_thread_id:
params.append(str(message_thread_id))
params.append(str(chat_id))
params.append(str(user_id))

return separator.join(params)






class StateContext:
"""
Class for data.
"""
def __init__(self , obj, chat_id, user_id) -> None:
def __init__(self , obj, chat_id, user_id, business_connection_id=None, message_thread_id=None, bot_id=None, ):
self.obj = obj
self.data = copy.deepcopy(obj.get_data(chat_id, user_id))
res = obj.get_data(chat_id=chat_id, user_id=user_id, business_connection_id=business_connection_id,
message_thread_id=message_thread_id, bot_id=bot_id)
self.data = copy.deepcopy(res)
self.chat_id = chat_id
self.user_id = user_id
self.bot_id = bot_id
self.business_connection_id = business_connection_id
self.message_thread_id = message_thread_id



def __enter__(self):
return self.data

def __exit__(self, exc_type, exc_val, exc_tb):
return self.obj.save(self.chat_id, self.user_id, self.data)
return self.obj.save(self.chat_id, self.user_id, self.data, self.business_connection_id, self.message_thread_id, self.bot_id)
Loading

0 comments on commit fce1c3d

Please sign in to comment.