Skip to content
This repository was archived by the owner on Dec 23, 2024. It is now read-only.

Commit 9bf742a

Browse files
committed
Introduce back some previously reverted changes
1 parent 03d60cd commit 9bf742a

File tree

11 files changed

+98
-88
lines changed

11 files changed

+98
-88
lines changed

pyrogram/client.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import shutil
2727
import sys
2828
from concurrent.futures.thread import ThreadPoolExecutor
29+
from datetime import datetime, timedelta
2930
from hashlib import sha256
3031
from importlib import import_module
3132
from io import StringIO, BytesIO
@@ -185,6 +186,9 @@ class Client(Methods):
185186
WORKERS = min(32, (os.cpu_count() or 0) + 4) # os.cpu_count() can be None
186187
WORKDIR = PARENT_DIR
187188

189+
# Interval of seconds in which the updates watchdog will kick in
190+
UPDATES_WATCHDOG_INTERVAL = 5 * 60
191+
188192
mimetypes = MimeTypes()
189193
mimetypes.readfp(StringIO(mime_types))
190194

@@ -273,6 +277,13 @@ def __init__(
273277

274278
self.message_cache = Cache(10000)
275279

280+
# Sometimes, for some reason, the server will stop sending updates and will only respond to pings.
281+
# This watchdog will invoke updates.GetState in order to wake up the server and enable it sending updates again
282+
# after some idle time has been detected.
283+
self.updates_watchdog_task = None
284+
self.updates_watchdog_event = asyncio.Event()
285+
self.last_update_time = datetime.now()
286+
276287
self.loop = asyncio.get_event_loop()
277288

278289
def __enter__(self):
@@ -293,6 +304,18 @@ async def __aexit__(self, *args):
293304
except ConnectionError:
294305
pass
295306

307+
async def updates_watchdog(self):
308+
while True:
309+
try:
310+
await asyncio.wait_for(self.updates_watchdog_event.wait(), self.UPDATES_WATCHDOG_INTERVAL)
311+
except asyncio.TimeoutError:
312+
pass
313+
else:
314+
break
315+
316+
if datetime.now() - self.last_update_time > timedelta(seconds=self.UPDATES_WATCHDOG_INTERVAL):
317+
await self.invoke(raw.functions.updates.GetState())
318+
296319
async def authorize(self) -> User:
297320
if self.bot_token:
298321
return await self.sign_in_bot(self.bot_token)
@@ -485,6 +508,8 @@ async def fetch_peers(self, peers: List[Union[raw.types.User, raw.types.Chat, ra
485508
return is_min
486509

487510
async def handle_updates(self, updates):
511+
self.last_update_time = datetime.now()
512+
488513
if isinstance(updates, (raw.types.Updates, raw.types.UpdatesCombined)):
489514
is_min = any((
490515
await self.fetch_peers(updates.users),

pyrogram/connection/connection.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ async def connect(self):
4848
await self.protocol.connect(self.address)
4949
except OSError as e:
5050
log.warning("Unable to connect due to network issues: %s", e)
51-
self.protocol.close()
51+
await self.protocol.close()
5252
await asyncio.sleep(1)
5353
else:
5454
log.info("Connected! %s DC%s%s - IPv%s",
@@ -59,17 +59,14 @@ async def connect(self):
5959
break
6060
else:
6161
log.warning("Connection failed! Trying again...")
62-
raise TimeoutError
62+
raise ConnectionError
6363

64-
def close(self):
65-
self.protocol.close()
64+
async def close(self):
65+
await self.protocol.close()
6666
log.info("Disconnected")
6767

6868
async def send(self, data: bytes):
69-
try:
70-
await self.protocol.send(data)
71-
except Exception as e:
72-
raise OSError(e)
69+
await self.protocol.send(data)
7370

7471
async def recv(self) -> Optional[bytes]:
7572
return await self.protocol.recv()

pyrogram/connection/transport/tcp/tcp.py

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,6 @@
2020
import ipaddress
2121
import logging
2222
import socket
23-
import time
24-
from concurrent.futures import ThreadPoolExecutor
25-
2623
import socks
2724

2825
log = logging.getLogger(__name__)
@@ -34,8 +31,8 @@ class TCP:
3431
def __init__(self, ipv6: bool, proxy: dict):
3532
self.socket = None
3633

37-
self.reader = None # type: asyncio.StreamReader
38-
self.writer = None # type: asyncio.StreamWriter
34+
self.reader = None
35+
self.writer = None
3936

4037
self.lock = asyncio.Lock()
4138
self.loop = asyncio.get_event_loop()
@@ -63,39 +60,37 @@ def __init__(self, ipv6: bool, proxy: dict):
6360

6461
log.info("Using proxy %s", hostname)
6562
else:
66-
self.socket = socks.socksocket(
63+
self.socket = socket.socket(
6764
socket.AF_INET6 if ipv6
6865
else socket.AF_INET
6966
)
7067

71-
self.socket.settimeout(TCP.TIMEOUT)
68+
self.socket.setblocking(False)
7269

7370
async def connect(self, address: tuple):
74-
# The socket used by the whole logic is blocking and thus it blocks when connecting.
75-
# Offload the task to a thread executor to avoid blocking the main event loop.
76-
with ThreadPoolExecutor(1) as executor:
77-
await self.loop.run_in_executor(executor, self.socket.connect, address)
71+
try:
72+
await asyncio.wait_for(asyncio.get_event_loop().sock_connect(self.socket, address), TCP.TIMEOUT)
73+
except asyncio.TimeoutError: # Re-raise as TimeoutError. asyncio.TimeoutError is deprecated in 3.11
74+
raise TimeoutError("Connection timed out")
7875

7976
self.reader, self.writer = await asyncio.open_connection(sock=self.socket)
8077

81-
def close(self):
78+
async def close(self):
8279
try:
83-
self.writer.close()
84-
except AttributeError:
85-
try:
86-
self.socket.shutdown(socket.SHUT_RDWR)
87-
except OSError:
88-
pass
89-
finally:
90-
# A tiny sleep placed here helps avoiding .recv(n) hanging until the timeout.
91-
# This is a workaround that seems to fix the occasional delayed stop of a client.
92-
time.sleep(0.001)
93-
self.socket.close()
80+
if self.writer is not None:
81+
self.writer.close()
82+
await asyncio.wait_for(self.writer.wait_closed(), TCP.TIMEOUT)
83+
except Exception as e:
84+
log.warning("Close exception: %s %s", type(e).__name__, e)
9485

9586
async def send(self, data: bytes):
9687
async with self.lock:
97-
self.writer.write(data)
98-
await self.writer.drain()
88+
try:
89+
if self.writer is not None:
90+
self.writer.write(data)
91+
await self.writer.drain()
92+
except Exception as e:
93+
log.warning("Send exception: %s %s", type(e).__name__, e)
9994

10095
async def recv(self, length: int = 0):
10196
data = b""

pyrogram/methods/auth/initialize.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# You should have received a copy of the GNU Lesser General Public License
1717
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
1818

19+
import asyncio
1920
import logging
2021

2122
import pyrogram
@@ -46,4 +47,6 @@ async def initialize(
4647

4748
await self.dispatcher.start()
4849

50+
self.updates_watchdog_task = asyncio.create_task(self.updates_watchdog())
51+
4952
self.is_initialized = True

pyrogram/methods/auth/terminate.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,11 @@ async def terminate(
5151

5252
self.media_sessions.clear()
5353

54+
self.updates_watchdog_event.set()
55+
56+
if self.updates_watchdog_task is not None:
57+
await self.updates_watchdog_task
58+
59+
self.updates_watchdog_event.clear()
60+
5461
self.is_initialized = False

pyrogram/session/auth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,4 +278,4 @@ async def create(self):
278278
else:
279279
return auth_key
280280
finally:
281-
self.connection.close()
281+
await self.connection.close()

pyrogram/session/internals/msg_id.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ class MsgId:
2727
offset = 0
2828

2929
def __new__(cls) -> int:
30-
now = time.time()
30+
now = int(time.time())
3131
cls.offset = (cls.offset + 4) if now == cls.last_time else 0
32-
msg_id = int(now * 2 ** 32) + cls.offset
32+
msg_id = (now * 2 ** 32) + cls.offset
3333
cls.last_time = now
3434

3535
return msg_id

pyrogram/session/internals/seq_no.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,15 @@
1616
# You should have received a copy of the GNU Lesser General Public License
1717
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
1818

19-
from threading import Lock
20-
2119

2220
class SeqNo:
2321
def __init__(self):
2422
self.content_related_messages_sent = 0
25-
self.lock = Lock()
2623

2724
def __call__(self, is_content_related: bool) -> int:
28-
with self.lock:
29-
seq_no = (self.content_related_messages_sent * 2) + (1 if is_content_related else 0)
25+
seq_no = (self.content_related_messages_sent * 2) + (1 if is_content_related else 0)
3026

31-
if is_content_related:
32-
self.content_related_messages_sent += 1
27+
if is_content_related:
28+
self.content_related_messages_sent += 1
3329

34-
return seq_no
30+
return seq_no

pyrogram/session/session.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,11 @@ def __init__(self):
4444

4545

4646
class Session:
47-
START_TIMEOUT = 1
47+
START_TIMEOUT = 5
4848
WAIT_TIMEOUT = 15
4949
SLEEP_THRESHOLD = 10
50-
MAX_RETRIES = 5
51-
ACKS_THRESHOLD = 8
50+
MAX_RETRIES = 10
51+
ACKS_THRESHOLD = 10
5252
PING_INTERVAL = 5
5353

5454
def __init__(
@@ -156,14 +156,11 @@ async def stop(self):
156156

157157
self.ping_task_event.clear()
158158

159-
self.connection.close()
159+
await self.connection.close()
160160

161161
if self.recv_task:
162162
await self.recv_task
163163

164-
for i in self.results.values():
165-
i.event.set()
166-
167164
if not self.is_media and callable(self.client.disconnect_handler):
168165
try:
169166
await self.client.disconnect_handler(self.client)
@@ -189,6 +186,7 @@ async def handle_packet(self, packet):
189186
)
190187
except SecurityCheckMismatch as e:
191188
log.warning("Discarding packet: %s", e)
189+
await self.connection.close()
192190
return
193191

194192
messages = (
@@ -284,9 +282,6 @@ async def send(self, data: TLObject, wait_response: bool = True, timeout: float
284282
message = self.msg_factory(data)
285283
msg_id = message.msg_id
286284

287-
if wait_response:
288-
self.results[msg_id] = Result()
289-
290285
log.debug("Sent: %s", message)
291286

292287
payload = await self.loop.run_in_executor(
@@ -299,34 +294,35 @@ async def send(self, data: TLObject, wait_response: bool = True, timeout: float
299294
self.auth_key_id
300295
)
301296

302-
try:
303-
await self.connection.send(payload)
304-
except OSError as e:
305-
self.results.pop(msg_id, None)
306-
raise e
297+
await self.connection.send(payload)
307298

308299
if wait_response:
300+
self.results[msg_id] = Result()
301+
309302
try:
310303
await asyncio.wait_for(self.results[msg_id].event.wait(), timeout)
311304
except asyncio.TimeoutError:
312305
pass
313-
finally:
314-
result = self.results.pop(msg_id).value
306+
307+
result = self.results.pop(msg_id).value
315308

316309
if result is None:
317310
raise TimeoutError("Request timed out")
318-
elif isinstance(result, raw.types.RpcError):
311+
312+
if isinstance(result, raw.types.RpcError):
319313
if isinstance(data, (raw.functions.InvokeWithoutUpdates, raw.functions.InvokeWithTakeout)):
320314
data = data.query
321315

322316
RPCError.raise_it(result, type(data))
323-
elif isinstance(result, raw.types.BadMsgNotification):
317+
318+
if isinstance(result, raw.types.BadMsgNotification):
324319
raise BadMsgNotification(result.error_code)
325-
elif isinstance(result, raw.types.BadServerSalt):
320+
321+
if isinstance(result, raw.types.BadServerSalt):
326322
self.salt = result.new_server_salt
327323
return await self.send(data, wait_response, timeout)
328-
else:
329-
return result
324+
325+
return result
330326

331327
async def invoke(
332328
self,

pyrogram/storage/file_storage.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,13 @@ def update(self):
3838
version = self.version()
3939

4040
if version == 1:
41-
with self.lock, self.conn:
41+
with self.conn:
4242
self.conn.execute("DELETE FROM peers")
4343

4444
version += 1
4545

4646
if version == 2:
47-
with self.lock, self.conn:
47+
with self.conn:
4848
self.conn.execute("ALTER TABLE sessions ADD api_id INTEGER")
4949

5050
version += 1
@@ -63,10 +63,7 @@ async def open(self):
6363
self.update()
6464

6565
with self.conn:
66-
try: # Python 3.6.0 (exactly this version) is bugged and won't successfully execute the vacuum
67-
self.conn.execute("VACUUM")
68-
except sqlite3.OperationalError:
69-
pass
66+
self.conn.execute("VACUUM")
7067

7168
async def delete(self):
7269
os.remove(self.database)

0 commit comments

Comments
 (0)