Skip to content

Commit

Permalink
Dynamic session ReStart + restart optimizations (#56)
Browse files Browse the repository at this point in the history
Co-authored-by: Marvin <eymarv07@gmail.com>
  • Loading branch information
SpEcHiDe and eyMarv authored Jul 16, 2024
1 parent 39fe826 commit 2f95181
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 48 deletions.
1 change: 1 addition & 0 deletions docs/source/releases/changes-in-this-fork.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ If you found any issue or have any suggestions, feel free to make `an issue <htt
| Scheme layer used: 184 |
+------------------------+

- Dynamic session ReStart + restart optimizations (`#56 <https://github.com/TelegramPlayGround/pyrogram/pull/56>`__)
- Added the :meth:`~pyrogram.Client.delete_account`, :meth:`~pyrogram.Client.transfer_chat_ownership`, :meth:`~pyrogram.Client.update_status` (`#49 <https://github.com/TelegramPlayGround/pyrogram/pull/49>`__, `#51 <https://github.com/TelegramPlayGround/pyrogram/pull/51>`__)
- Added the class :obj:`~pyrogram.types.RefundedPayment`, containing information about a refunded payment.
- Added the field ``refunded_payment`` to the class :obj:`~pyrogram.types.Message`, describing a service message about a refunded payment.
Expand Down
218 changes: 170 additions & 48 deletions pyrogram/session/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import bisect
import logging
import os
from datetime import datetime, timedelta
from time import time
from hashlib import sha1
from io import BytesIO
from typing import Optional
Expand Down Expand Up @@ -49,14 +49,15 @@ def __init__(self):


class Session:
START_TIMEOUT = 2
START_TIMEOUT = 5
WAIT_TIMEOUT = 15
REART_TIMEOUT = 5
SLEEP_THRESHOLD = 10
MAX_RETRIES = 10
ACKS_THRESHOLD = 10
PING_INTERVAL = 5
STORED_MSG_IDS_MAX_SIZE = 1000 * 2
RECONNECT_THRESHOLD = timedelta(seconds=10)
RECONNECT_THRESHOLD = 13

TRANSPORT_ERRORS = {
404: "auth key not found",
Expand Down Expand Up @@ -110,10 +111,17 @@ def __init__(

self.loop = asyncio.get_event_loop()

self.instant_stop = False # set internally
self.last_reconnect_attempt = None
self.currently_restarting = False
self.currently_stopping = False

async def start(self):
while True:
if self.instant_stop:
log.info("session init stopped")
return # stop instantly

self.connection = self.client.connection_factory(
dc_id=self.dc_id,
test_mode=self.test_mode,
Expand Down Expand Up @@ -173,46 +181,98 @@ async def start(self):

log.info("Session started")

async def stop(self):
self.is_started.clear()

self.stored_msg_ids.clear()

self.ping_task_event.set()

if self.ping_task is not None:
await self.ping_task

self.ping_task_event.clear()
async def stop(self, restart: bool = False):
if self.currently_stopping:
return # don't stop twice
if self.instant_stop:
log.info("session stop process stopped")
return # stop doing anything instantly, client is manually handling

await self.connection.close()

if self.recv_task:
await self.recv_task
try:
self.currently_stopping = True
self.is_started.clear()
self.stored_msg_ids.clear()

if restart:
self.instant_stop = True # tell all funcs that we want to stop

self.ping_task_event.set()
for _ in range(2):
try:
if self.ping_task is not None:
await asyncio.wait_for(
self.ping_task, timeout=self.REART_TIMEOUT
)
break
except TimeoutError:
self.ping_task.cancel()
continue # next stage
self.ping_task_event.clear()

if not self.is_media and callable(self.client.disconnect_handler):
try:
await self.client.disconnect_handler(self.client)
await asyncio.wait_for(
self.connection.close(), timeout=self.REART_TIMEOUT
)
except Exception as e:
log.exception(e)

log.info("Session stopped")
for _ in range(2):
try:
if self.recv_task:
await asyncio.wait_for(
self.recv_task, timeout=self.REART_TIMEOUT
)
break
except TimeoutError:
self.recv_task.cancel()
continue # next stage

if not self.is_media and callable(self.client.disconnect_handler):
try:
await self.client.disconnect_handler(self.client)
except Exception as e:
log.exception(e)

log.info("session stopped")
finally:
self.currently_stopping = False
if restart:
self.instant_stop = False # reset

async def restart(self):
now = datetime.now()
if (
self.last_reconnect_attempt
and now - self.last_reconnect_attempt < self.RECONNECT_THRESHOLD
):
log.info("Reconnecting too frequently, sleeping for a while")
await asyncio.sleep(5)
if self.currently_restarting:
return # don't restart twice
if self.instant_stop:
return # stop instantly

self.last_reconnect_attempt = now
try:
self.currently_restarting = True
now = time()
if (
self.last_reconnect_attempt
and (now - self.last_reconnect_attempt) < self.RECONNECT_THRESHOLD
):
to_wait = int(
self.RECONNECT_THRESHOLD - (now - self.last_reconnect_attempt)
)
log.warning(
"[pyrogram] Client [%s] is reconnecting too frequently, sleeping for %s seconds",
self.client.name,
to_wait
)
await asyncio.sleep(to_wait)

await self.stop()
await self.start()
self.last_reconnect_attempt = now
await self.stop(restart=True)
await self.start()
finally:
self.currently_restarting = False

async def handle_packet(self, packet):
if self.instant_stop:
log.info("Stopped packet handler")
return # stop instantly

data = await self.loop.run_in_executor(
pyrogram.crypto_executor,
mtproto.unpack,
Expand Down Expand Up @@ -298,9 +358,17 @@ async def handle_packet(self, packet):
self.pending_acks.clear()

async def ping_worker(self):
if self.instant_stop:
log.info("PingTask force stopped")
return # stop instantly

log.info("PingTask started")

while True:
if self.instant_stop:
log.info("PingTask force stopped (loop)")
return # stop instantly

try:
await asyncio.wait_for(self.ping_task_event.wait(), self.PING_INTERVAL)
except asyncio.TimeoutError:
Expand All @@ -326,15 +394,27 @@ async def recv_worker(self):
log.info("NetworkTask started")

while True:
if self.instant_stop:
log.info("NetworkTask force stopped (loop)")
return # stop instantly

packet = await self.connection.recv()

if packet is None or len(packet) == 4:
if packet:
error_code = -Int.read(BytesIO(packet))

if error_code == 404:
raise Unauthorized(
"Auth key not found in the system. You must delete your session file "
"and log in again with your phone number or bot token."
)

log.warning(
"Server sent transport error: %s (%s)",
error_code, Session.TRANSPORT_ERRORS.get(error_code, "unknown error")
"[%s] Server sent transport error: %s (%s)",
self.client.name,
error_code,
Session.TRANSPORT_ERRORS.get(error_code, "unknown error"),
)

if self.is_started.is_set():
Expand All @@ -346,7 +426,15 @@ async def recv_worker(self):

log.info("NetworkTask stopped")

async def send(self, data: TLObject, wait_response: bool = True, timeout: float = WAIT_TIMEOUT):
async def send(
self,
data: TLObject,
wait_response: bool = True,
timeout: float = WAIT_TIMEOUT,
):
if self.instant_stop:
return # stop instantly

message = self.msg_factory(data)
msg_id = message.msg_id

Expand Down Expand Up @@ -404,10 +492,8 @@ async def invoke(
timeout: float = WAIT_TIMEOUT,
sleep_threshold: float = SLEEP_THRESHOLD
):
try:
await asyncio.wait_for(self.is_started.wait(), self.WAIT_TIMEOUT)
except asyncio.TimeoutError:
pass
if self.instant_stop:
return # stop instantly

if isinstance(query, Session.CUR_ALWD_INNR_QRYS):
inner_query = query.query
Expand All @@ -417,6 +503,19 @@ async def invoke(
query_name = ".".join(inner_query.QUALNAME.split(".")[1:])

while retries > 0:
if self.instant_stop:
return # stop instantly

# sleep until the restart is performed
if self.currently_restarting:
while self.currently_restarting:
if self.instant_stop:
return # stop instantly
await asyncio.sleep(1)

if not self.is_started.is_set():
await self.is_started.wait()

try:
return await self.send(query, timeout=timeout)
except (FloodWait, FloodPremiumWait) as e:
Expand All @@ -425,11 +524,21 @@ async def invoke(
if amount > sleep_threshold >= 0:
raise

log.warning('[%s] Waiting for %s seconds before continuing (required by "%s")',
self.client.name, amount, query_name)
log.warning(
'[%s] Waiting for %s seconds before continuing (required by "%s")',
self.client.name,
amount,
query_name,
)

await asyncio.sleep(amount)
except (OSError, InternalServerError, ServiceUnavailable) as e:
except (
OSError,
RuntimeError,
InternalServerError,
ServiceUnavailable,
TimeoutError,
) as e:
retries -= 1
if (
retries == 0 or
Expand All @@ -443,13 +552,26 @@ async def invoke(
):
raise e from None

(log.warning if retries < 2 else log.info)(
'[%s] Retrying "%s" due to: %s',
Session.MAX_RETRIES - retries,
query_name,
str(e) or repr(e)
)
if (isinstance(e, (OSError, RuntimeError)) and "handler" in str(e)) or (
isinstance(e, TimeoutError)
):
(log.warning if retries < 2 else log.info)(
'[%s] [%s] reconnecting session requesting "%s", due to: %s',
self.client.name,
Session.MAX_RETRIES - retries,
query_name,
str(e) or repr(e),
)
self.loop.create_task(self.restart())
else:
(log.warning if retries < 2 else log.info)(
'[%s] [%s] Retrying "%s" due to: %s',
self.client.name,
Session.MAX_RETRIES - retries,
query_name,
str(e) or repr(e),
)

await asyncio.sleep(3)
await asyncio.sleep(1)

raise TimeoutError("Exceeded maximum number of retries")

0 comments on commit 2f95181

Please sign in to comment.