Skip to content

Commit bff583e

Browse files
committed
Revert some of the latest changes
1 parent a81b8a2 commit bff583e

10 files changed

Lines changed: 86 additions & 111 deletions

File tree

pyrogram/client.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import shutil
2727
import sys
2828
from concurrent.futures.thread import ThreadPoolExecutor
29-
from datetime import datetime, timedelta
3029
from hashlib import sha256
3130
from importlib import import_module
3231
from io import StringIO, BytesIO
@@ -186,9 +185,6 @@ class Client(Methods):
186185
WORKERS = min(32, (os.cpu_count() or 0) + 4) # os.cpu_count() can be None
187186
WORKDIR = PARENT_DIR
188187

189-
# Interval of seconds in which the updates watchdog will kick in
190-
UPDATES_WATCHDOG_INTERVAL = 5 * 60
191-
192188
mimetypes = MimeTypes()
193189
mimetypes.readfp(StringIO(mime_types))
194190

@@ -277,13 +273,6 @@ def __init__(
277273

278274
self.message_cache = Cache(10000)
279275

280-
# Sometimes, for some reason, the server will stop sending updates and will only respond to pings.
281-
# This watchdog will invoke updates.GetState in order to wake up the server and enable it sending updates again
282-
# after some idle time has been detected.
283-
self.updates_watchdog_task = None
284-
self.updates_watchdog_event = asyncio.Event()
285-
self.last_update_time = datetime.now()
286-
287276
self.loop = asyncio.get_event_loop()
288277

289278
def __enter__(self):
@@ -304,18 +293,6 @@ async def __aexit__(self, *args):
304293
except ConnectionError:
305294
pass
306295

307-
async def updates_watchdog(self):
308-
while True:
309-
try:
310-
await asyncio.wait_for(self.updates_watchdog_event.wait(), self.UPDATES_WATCHDOG_INTERVAL)
311-
except asyncio.TimeoutError:
312-
pass
313-
else:
314-
break
315-
316-
if datetime.now() - self.last_update_time > timedelta(seconds=self.UPDATES_WATCHDOG_INTERVAL):
317-
await self.invoke(raw.functions.updates.GetState())
318-
319296
async def authorize(self) -> User:
320297
if self.bot_token:
321298
return await self.sign_in_bot(self.bot_token)
@@ -508,8 +485,6 @@ async def fetch_peers(self, peers: List[Union[raw.types.User, raw.types.Chat, ra
508485
return is_min
509486

510487
async def handle_updates(self, updates):
511-
self.last_update_time = datetime.now()
512-
513488
if isinstance(updates, (raw.types.Updates, raw.types.UpdatesCombined)):
514489
is_min = any((
515490
await self.fetch_peers(updates.users),

pyrogram/connection/connection.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ async def connect(self):
4848
await self.protocol.connect(self.address)
4949
except OSError as e:
5050
log.warning("Unable to connect due to network issues: %s", e)
51-
await self.protocol.close()
51+
self.protocol.close()
5252
await asyncio.sleep(1)
5353
else:
5454
log.info("Connected! %s DC%s%s - IPv%s",
@@ -59,14 +59,17 @@ async def connect(self):
5959
break
6060
else:
6161
log.warning("Connection failed! Trying again...")
62-
raise ConnectionError
62+
raise TimeoutError
6363

64-
async def close(self):
65-
await self.protocol.close()
64+
def close(self):
65+
self.protocol.close()
6666
log.info("Disconnected")
6767

6868
async def send(self, data: bytes):
69-
await self.protocol.send(data)
69+
try:
70+
await self.protocol.send(data)
71+
except Exception as e:
72+
raise OSError(e)
7073

7174
async def recv(self) -> Optional[bytes]:
7275
return await self.protocol.recv()

pyrogram/connection/transport/tcp/tcp.py

Lines changed: 27 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
import ipaddress
2121
import logging
2222
import socket
23+
import time
24+
from concurrent.futures import ThreadPoolExecutor
25+
2326
import socks
2427

2528
log = logging.getLogger(__name__)
@@ -31,12 +34,10 @@ class TCP:
3134
def __init__(self, ipv6: bool, proxy: dict):
3235
self.socket = None
3336

34-
self.reader = None
35-
self.writer = None
36-
37-
self.send_queue = asyncio.Queue()
38-
self.send_task = None
37+
self.reader = None # type: asyncio.StreamReader
38+
self.writer = None # type: asyncio.StreamWriter
3939

40+
self.lock = asyncio.Lock()
4041
self.loop = asyncio.get_event_loop()
4142

4243
if proxy:
@@ -62,50 +63,39 @@ def __init__(self, ipv6: bool, proxy: dict):
6263

6364
log.info("Using proxy %s", hostname)
6465
else:
65-
self.socket = socket.socket(
66+
self.socket = socks.socksocket(
6667
socket.AF_INET6 if ipv6
6768
else socket.AF_INET
6869
)
6970

70-
self.socket.setblocking(False)
71+
self.socket.settimeout(TCP.TIMEOUT)
7172

7273
async def connect(self, address: tuple):
73-
try:
74-
await asyncio.wait_for(asyncio.get_event_loop().sock_connect(self.socket, address), TCP.TIMEOUT)
75-
except asyncio.TimeoutError: # Re-raise as TimeoutError. asyncio.TimeoutError is deprecated in 3.11
76-
raise TimeoutError("Connection timed out")
74+
# The socket used by the whole logic is blocking and thus it blocks when connecting.
75+
# Offload the task to a thread executor to avoid blocking the main event loop.
76+
with ThreadPoolExecutor(1) as executor:
77+
await self.loop.run_in_executor(executor, self.socket.connect, address)
7778

7879
self.reader, self.writer = await asyncio.open_connection(sock=self.socket)
79-
self.send_task = asyncio.create_task(self.send_worker())
80-
81-
async def close(self):
82-
await self.send_queue.put(None)
83-
84-
if self.send_task is not None:
85-
await self.send_task
8680

81+
def close(self):
8782
try:
88-
if self.writer is not None:
89-
self.writer.close()
90-
await asyncio.wait_for(self.writer.wait_closed(), TCP.TIMEOUT)
91-
except Exception as e:
92-
log.info("Close exception: %s %s", type(e).__name__, e)
83+
self.writer.close()
84+
except AttributeError:
85+
try:
86+
self.socket.shutdown(socket.SHUT_RDWR)
87+
except OSError:
88+
pass
89+
finally:
90+
# A tiny sleep placed here helps avoiding .recv(n) hanging until the timeout.
91+
# This is a workaround that seems to fix the occasional delayed stop of a client.
92+
time.sleep(0.001)
93+
self.socket.close()
9394

9495
async def send(self, data: bytes):
95-
await self.send_queue.put(data)
96-
97-
async def send_worker(self):
98-
while True:
99-
data = await self.send_queue.get()
100-
101-
if data is None:
102-
break
103-
104-
try:
105-
self.writer.write(data)
106-
await self.writer.drain()
107-
except Exception as e:
108-
log.info("Send exception: %s %s", type(e).__name__, e)
96+
async with self.lock:
97+
self.writer.write(data)
98+
await self.writer.drain()
10999

110100
async def recv(self, length: int = 0):
111101
data = b""

pyrogram/methods/auth/initialize.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
# You should have received a copy of the GNU Lesser General Public License
1717
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
1818

19-
import asyncio
2019
import logging
2120

2221
import pyrogram
@@ -47,6 +46,4 @@ async def initialize(
4746

4847
await self.dispatcher.start()
4948

50-
self.updates_watchdog_task = asyncio.create_task(self.updates_watchdog())
51-
5249
self.is_initialized = True

pyrogram/methods/auth/terminate.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,4 @@ async def terminate(
5151

5252
self.media_sessions.clear()
5353

54-
self.updates_watchdog_event.set()
55-
56-
if self.updates_watchdog_task is not None:
57-
await self.updates_watchdog_task
58-
59-
self.updates_watchdog_event.clear()
60-
6154
self.is_initialized = False

pyrogram/session/auth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,4 +278,4 @@ async def create(self):
278278
else:
279279
return auth_key
280280
finally:
281-
await self.connection.close()
281+
self.connection.close()

pyrogram/session/internals/seq_no.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,19 @@
1616
# You should have received a copy of the GNU Lesser General Public License
1717
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
1818

19+
from threading import Lock
20+
1921

2022
class SeqNo:
2123
def __init__(self):
2224
self.content_related_messages_sent = 0
25+
self.lock = Lock()
2326

2427
def __call__(self, is_content_related: bool) -> int:
25-
seq_no = (self.content_related_messages_sent * 2) + (1 if is_content_related else 0)
28+
with self.lock:
29+
seq_no = (self.content_related_messages_sent * 2) + (1 if is_content_related else 0)
2630

27-
if is_content_related:
28-
self.content_related_messages_sent += 1
31+
if is_content_related:
32+
self.content_related_messages_sent += 1
2933

30-
return seq_no
34+
return seq_no

pyrogram/session/session.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -156,11 +156,14 @@ async def stop(self):
156156

157157
self.ping_task_event.clear()
158158

159-
await self.connection.close()
159+
self.connection.close()
160160

161161
if self.recv_task:
162162
await self.recv_task
163163

164+
for i in self.results.values():
165+
i.event.set()
166+
164167
if not self.is_media and callable(self.client.disconnect_handler):
165168
try:
166169
await self.client.disconnect_handler(self.client)
@@ -185,8 +188,7 @@ async def handle_packet(self, packet):
185188
self.stored_msg_ids
186189
)
187190
except SecurityCheckMismatch as e:
188-
log.info("Discarding packet: %s", e)
189-
await self.connection.close()
191+
log.warning("Discarding packet: %s", e)
190192
return
191193

192194
messages = (
@@ -282,6 +284,9 @@ async def send(self, data: TLObject, wait_response: bool = True, timeout: float
282284
message = self.msg_factory(data)
283285
msg_id = message.msg_id
284286

287+
if wait_response:
288+
self.results[msg_id] = Result()
289+
285290
log.debug("Sent: %s", message)
286291

287292
payload = await self.loop.run_in_executor(
@@ -294,35 +299,34 @@ async def send(self, data: TLObject, wait_response: bool = True, timeout: float
294299
self.auth_key_id
295300
)
296301

297-
await self.connection.send(payload)
302+
try:
303+
await self.connection.send(payload)
304+
except OSError as e:
305+
self.results.pop(msg_id, None)
306+
raise e
298307

299308
if wait_response:
300-
self.results[msg_id] = Result()
301-
302309
try:
303310
await asyncio.wait_for(self.results[msg_id].event.wait(), timeout)
304311
except asyncio.TimeoutError:
305312
pass
306-
307-
result = self.results.pop(msg_id).value
313+
finally:
314+
result = self.results.pop(msg_id).value
308315

309316
if result is None:
310317
raise TimeoutError("Request timed out")
311-
312-
if isinstance(result, raw.types.RpcError):
318+
elif isinstance(result, raw.types.RpcError):
313319
if isinstance(data, (raw.functions.InvokeWithoutUpdates, raw.functions.InvokeWithTakeout)):
314320
data = data.query
315321

316322
RPCError.raise_it(result, type(data))
317-
318-
if isinstance(result, raw.types.BadMsgNotification):
323+
elif isinstance(result, raw.types.BadMsgNotification):
319324
raise BadMsgNotification(result.error_code)
320-
321-
if isinstance(result, raw.types.BadServerSalt):
325+
elif isinstance(result, raw.types.BadServerSalt):
322326
self.salt = result.new_server_salt
323327
return await self.send(data, wait_response, timeout)
324-
325-
return result
328+
else:
329+
return result
326330

327331
async def invoke(
328332
self,

pyrogram/storage/file_storage.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,13 @@ def update(self):
3838
version = self.version()
3939

4040
if version == 1:
41-
with self.conn:
41+
with self.lock, self.conn:
4242
self.conn.execute("DELETE FROM peers")
4343

4444
version += 1
4545

4646
if version == 2:
47-
with self.conn:
47+
with self.lock, self.conn:
4848
self.conn.execute("ALTER TABLE sessions ADD api_id INTEGER")
4949

5050
version += 1
@@ -63,7 +63,10 @@ async def open(self):
6363
self.update()
6464

6565
with self.conn:
66-
self.conn.execute("VACUUM")
66+
try: # Python 3.6.0 (exactly this version) is bugged and won't successfully execute the vacuum
67+
self.conn.execute("VACUUM")
68+
except sqlite3.OperationalError:
69+
pass
6770

6871
async def delete(self):
6972
os.remove(self.database)

0 commit comments

Comments
 (0)