diff --git a/docs/source/releases/changes-in-this-fork.rst b/docs/source/releases/changes-in-this-fork.rst index 918747447..7044bdc32 100644 --- a/docs/source/releases/changes-in-this-fork.rst +++ b/docs/source/releases/changes-in-this-fork.rst @@ -14,6 +14,7 @@ If you found any issue or have any suggestions, feel free to make `an issue `__) - Added the :meth:`~pyrogram.Client.delete_account`, :meth:`~pyrogram.Client.transfer_chat_ownership`, :meth:`~pyrogram.Client.update_status` (`#49 `__, `#51 `__) - Added the class :obj:`~pyrogram.types.RefundedPayment`, containing information about a refunded payment. - Added the field ``refunded_payment`` to the class :obj:`~pyrogram.types.Message`, describing a service message about a refunded payment. diff --git a/pyrogram/session/session.py b/pyrogram/session/session.py index 88c8b4469..cfd122c15 100644 --- a/pyrogram/session/session.py +++ b/pyrogram/session/session.py @@ -20,7 +20,7 @@ import bisect import logging import os -from datetime import datetime, timedelta +from time import time from hashlib import sha1 from io import BytesIO from typing import Optional @@ -49,14 +49,15 @@ def __init__(self): class Session: - START_TIMEOUT = 2 + START_TIMEOUT = 5 WAIT_TIMEOUT = 15 + REART_TIMEOUT = 5 SLEEP_THRESHOLD = 10 MAX_RETRIES = 10 ACKS_THRESHOLD = 10 PING_INTERVAL = 5 STORED_MSG_IDS_MAX_SIZE = 1000 * 2 - RECONNECT_THRESHOLD = timedelta(seconds=10) + RECONNECT_THRESHOLD = 13 TRANSPORT_ERRORS = { 404: "auth key not found", @@ -110,10 +111,17 @@ def __init__( self.loop = asyncio.get_event_loop() + self.instant_stop = False # set internally self.last_reconnect_attempt = None + self.currently_restarting = False + self.currently_stopping = False async def start(self): while True: + if self.instant_stop: + log.info("session init stopped") + return # stop instantly + self.connection = self.client.connection_factory( dc_id=self.dc_id, test_mode=self.test_mode, @@ -173,46 +181,98 @@ async def start(self): log.info("Session started") - async def stop(self): - self.is_started.clear() - - self.stored_msg_ids.clear() - - self.ping_task_event.set() - - if self.ping_task is not None: - await self.ping_task - - self.ping_task_event.clear() + async def stop(self, restart: bool = False): + if self.currently_stopping: + return # don't stop twice + if self.instant_stop: + log.info("session stop process stopped") + return # stop doing anything instantly, client is manually handling - await self.connection.close() - - if self.recv_task: - await self.recv_task + try: + self.currently_stopping = True + self.is_started.clear() + self.stored_msg_ids.clear() + + if restart: + self.instant_stop = True # tell all funcs that we want to stop + + self.ping_task_event.set() + for _ in range(2): + try: + if self.ping_task is not None: + await asyncio.wait_for( + self.ping_task, timeout=self.REART_TIMEOUT + ) + break + except TimeoutError: + self.ping_task.cancel() + continue # next stage + self.ping_task_event.clear() - if not self.is_media and callable(self.client.disconnect_handler): try: - await self.client.disconnect_handler(self.client) + await asyncio.wait_for( + self.connection.close(), timeout=self.REART_TIMEOUT + ) except Exception as e: log.exception(e) - log.info("Session stopped") + for _ in range(2): + try: + if self.recv_task: + await asyncio.wait_for( + self.recv_task, timeout=self.REART_TIMEOUT + ) + break + except TimeoutError: + self.recv_task.cancel() + continue # next stage + + if not self.is_media and callable(self.client.disconnect_handler): + try: + await self.client.disconnect_handler(self.client) + except Exception as e: + log.exception(e) + + log.info("session stopped") + finally: + self.currently_stopping = False + if restart: + self.instant_stop = False # reset async def restart(self): - now = datetime.now() - if ( - self.last_reconnect_attempt - and now - self.last_reconnect_attempt < self.RECONNECT_THRESHOLD - ): - log.info("Reconnecting too frequently, sleeping for a while") - await asyncio.sleep(5) + if self.currently_restarting: + return # don't restart twice + if self.instant_stop: + return # stop instantly - self.last_reconnect_attempt = now + try: + self.currently_restarting = True + now = time() + if ( + self.last_reconnect_attempt + and (now - self.last_reconnect_attempt) < self.RECONNECT_THRESHOLD + ): + to_wait = int( + self.RECONNECT_THRESHOLD - (now - self.last_reconnect_attempt) + ) + log.warning( + "[pyrogram] Client [%s] is reconnecting too frequently, sleeping for %s seconds", + self.client.name, + to_wait + ) + await asyncio.sleep(to_wait) - await self.stop() - await self.start() + self.last_reconnect_attempt = now + await self.stop(restart=True) + await self.start() + finally: + self.currently_restarting = False async def handle_packet(self, packet): + if self.instant_stop: + log.info("Stopped packet handler") + return # stop instantly + data = await self.loop.run_in_executor( pyrogram.crypto_executor, mtproto.unpack, @@ -298,9 +358,17 @@ async def handle_packet(self, packet): self.pending_acks.clear() async def ping_worker(self): + if self.instant_stop: + log.info("PingTask force stopped") + return # stop instantly + log.info("PingTask started") while True: + if self.instant_stop: + log.info("PingTask force stopped (loop)") + return # stop instantly + try: await asyncio.wait_for(self.ping_task_event.wait(), self.PING_INTERVAL) except asyncio.TimeoutError: @@ -326,15 +394,27 @@ async def recv_worker(self): log.info("NetworkTask started") while True: + if self.instant_stop: + log.info("NetworkTask force stopped (loop)") + return # stop instantly + packet = await self.connection.recv() if packet is None or len(packet) == 4: if packet: error_code = -Int.read(BytesIO(packet)) + if error_code == 404: + raise Unauthorized( + "Auth key not found in the system. You must delete your session file " + "and log in again with your phone number or bot token." + ) + log.warning( - "Server sent transport error: %s (%s)", - error_code, Session.TRANSPORT_ERRORS.get(error_code, "unknown error") + "[%s] Server sent transport error: %s (%s)", + self.client.name, + error_code, + Session.TRANSPORT_ERRORS.get(error_code, "unknown error"), ) if self.is_started.is_set(): @@ -346,7 +426,15 @@ async def recv_worker(self): log.info("NetworkTask stopped") - async def send(self, data: TLObject, wait_response: bool = True, timeout: float = WAIT_TIMEOUT): + async def send( + self, + data: TLObject, + wait_response: bool = True, + timeout: float = WAIT_TIMEOUT, + ): + if self.instant_stop: + return # stop instantly + message = self.msg_factory(data) msg_id = message.msg_id @@ -404,10 +492,8 @@ async def invoke( timeout: float = WAIT_TIMEOUT, sleep_threshold: float = SLEEP_THRESHOLD ): - try: - await asyncio.wait_for(self.is_started.wait(), self.WAIT_TIMEOUT) - except asyncio.TimeoutError: - pass + if self.instant_stop: + return # stop instantly if isinstance(query, Session.CUR_ALWD_INNR_QRYS): inner_query = query.query @@ -417,6 +503,19 @@ async def invoke( query_name = ".".join(inner_query.QUALNAME.split(".")[1:]) while retries > 0: + if self.instant_stop: + return # stop instantly + + # sleep until the restart is performed + if self.currently_restarting: + while self.currently_restarting: + if self.instant_stop: + return # stop instantly + await asyncio.sleep(1) + + if not self.is_started.is_set(): + await self.is_started.wait() + try: return await self.send(query, timeout=timeout) except (FloodWait, FloodPremiumWait) as e: @@ -425,11 +524,21 @@ async def invoke( if amount > sleep_threshold >= 0: raise - log.warning('[%s] Waiting for %s seconds before continuing (required by "%s")', - self.client.name, amount, query_name) + log.warning( + '[%s] Waiting for %s seconds before continuing (required by "%s")', + self.client.name, + amount, + query_name, + ) await asyncio.sleep(amount) - except (OSError, InternalServerError, ServiceUnavailable) as e: + except ( + OSError, + RuntimeError, + InternalServerError, + ServiceUnavailable, + TimeoutError, + ) as e: retries -= 1 if ( retries == 0 or @@ -443,13 +552,26 @@ async def invoke( ): raise e from None - (log.warning if retries < 2 else log.info)( - '[%s] Retrying "%s" due to: %s', - Session.MAX_RETRIES - retries, - query_name, - str(e) or repr(e) - ) + if (isinstance(e, (OSError, RuntimeError)) and "handler" in str(e)) or ( + isinstance(e, TimeoutError) + ): + (log.warning if retries < 2 else log.info)( + '[%s] [%s] reconnecting session requesting "%s", due to: %s', + self.client.name, + Session.MAX_RETRIES - retries, + query_name, + str(e) or repr(e), + ) + self.loop.create_task(self.restart()) + else: + (log.warning if retries < 2 else log.info)( + '[%s] [%s] Retrying "%s" due to: %s', + self.client.name, + Session.MAX_RETRIES - retries, + query_name, + str(e) or repr(e), + ) - await asyncio.sleep(3) + await asyncio.sleep(1) raise TimeoutError("Exceeded maximum number of retries")