Skip to content

Commit d298c62

Browse files
committed
Update session.py
1 parent 7182a7c commit d298c62

1 file changed

Lines changed: 44 additions & 53 deletions

File tree

pyrogram/session/session.py

Lines changed: 44 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
)
3333
from pyrogram.raw.all import layer
3434
from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalts
35-
from .internals import MsgId, MsgFactory
35+
from .internals import MsgFactory
3636

3737
log = logging.getLogger(__name__)
3838

@@ -85,9 +85,9 @@ def __init__(
8585
self.ping_task = None
8686
self.ping_task_event = asyncio.Event()
8787

88-
self.network_task = None
88+
self.recv_task = None
8989

90-
self.is_connected = asyncio.Event()
90+
self.is_started = asyncio.Event()
9191

9292
self.loop = asyncio.get_event_loop()
9393

@@ -104,7 +104,7 @@ async def start(self):
104104
try:
105105
await self.connection.connect()
106106

107-
self.network_task = self.loop.create_task(self.network_worker())
107+
self.recv_task = self.loop.create_task(self.recv_worker())
108108

109109
await self.send(raw.functions.Ping(ping_id=0), timeout=self.START_TIMEOUT)
110110

@@ -128,27 +128,26 @@ async def start(self):
128128

129129
self.ping_task = self.loop.create_task(self.ping_worker())
130130

131-
log.info(f"Session initialized: Layer {layer}")
132-
log.info(f"Device: {self.client.device_model} - {self.client.app_version}")
133-
log.info(f"System: {self.client.system_version} ({self.client.lang_code.upper()})")
134-
131+
log.info("Session initialized: Layer %s", layer)
132+
log.info("Device: %s - %s", self.client.device_model, self.client.app_version)
133+
log.info("System: %s (%s)", self.client.system_version, self.client.lang_code)
135134
except AuthKeyDuplicated as e:
136135
await self.stop()
137136
raise e
138-
except (OSError, TimeoutError, RPCError):
137+
except (OSError, RPCError):
139138
await self.stop()
140139
except Exception as e:
141140
await self.stop()
142141
raise e
143142
else:
144143
break
145144

146-
self.is_connected.set()
145+
self.is_started.set()
147146

148147
log.info("Session started")
149148

150149
async def stop(self):
151-
self.is_connected.clear()
150+
self.is_started.clear()
152151

153152
self.ping_task_event.set()
154153

@@ -159,17 +158,14 @@ async def stop(self):
159158

160159
await self.connection.close()
161160

162-
if self.network_task:
163-
await self.network_task
164-
165-
for i in self.results.values():
166-
i.event.set()
161+
if self.recv_task:
162+
await self.recv_task
167163

168164
if not self.is_media and callable(self.client.disconnect_handler):
169165
try:
170166
await self.client.disconnect_handler(self.client)
171167
except Exception as e:
172-
log.error(e, exc_info=True)
168+
log.exception(e)
173169

174170
log.info("Session stopped")
175171

@@ -189,7 +185,7 @@ async def handle_packet(self, packet):
189185
self.stored_msg_ids
190186
)
191187
except SecurityCheckMismatch as e:
192-
log.info(f"Discarding packet: {e}")
188+
log.info("Discarding packet: %s", e)
193189
await self.connection.close()
194190
return
195191

@@ -199,10 +195,7 @@ async def handle_packet(self, packet):
199195
else [data]
200196
)
201197

202-
# Call log.debug twice because calling it once by appending "data" to the previous string (i.e. f"Kind: {data}")
203-
# will cause "data" to be evaluated as string every time instead of only when debug is actually enabled.
204-
log.debug("Received:")
205-
log.debug(data)
198+
log.debug("Received: %s", data)
206199

207200
for msg in messages:
208201
if msg.seq_no % 2 != 0:
@@ -235,11 +228,11 @@ async def handle_packet(self, packet):
235228
self.results[msg_id].event.set()
236229

237230
if len(self.pending_acks) >= self.ACKS_THRESHOLD:
238-
log.debug(f"Send {len(self.pending_acks)} acks")
231+
log.debug("Sending %s acks", len(self.pending_acks))
239232

240233
try:
241234
await self.send(raw.types.MsgsAck(msg_ids=list(self.pending_acks)), False)
242-
except (OSError, TimeoutError):
235+
except OSError:
243236
pass
244237
else:
245238
self.pending_acks.clear()
@@ -261,22 +254,22 @@ async def ping_worker(self):
261254
ping_id=0, disconnect_delay=self.WAIT_TIMEOUT + 10
262255
), False
263256
)
264-
except (OSError, TimeoutError, RPCError):
257+
except (OSError, RPCError):
265258
pass
266259

267260
log.info("PingTask stopped")
268261

269-
async def network_worker(self):
262+
async def recv_worker(self):
270263
log.info("NetworkTask started")
271264

272265
while True:
273266
packet = await self.connection.recv()
274267

275268
if packet is None or len(packet) == 4:
276269
if packet:
277-
log.warning(f'Server sent "{Int.read(BytesIO(packet))}"')
270+
log.warning('Server sent "%s"', Int.read(BytesIO(packet)))
278271

279-
if self.is_connected.is_set():
272+
if self.is_started.is_set():
280273
self.loop.create_task(self.restart())
281274

282275
break
@@ -289,13 +282,7 @@ async def send(self, data: TLObject, wait_response: bool = True, timeout: float
289282
message = self.msg_factory(data)
290283
msg_id = message.msg_id
291284

292-
if wait_response:
293-
self.results[msg_id] = Result()
294-
295-
# Call log.debug twice because calling it once by appending "data" to the previous string (i.e. f"Kind: {data}")
296-
# will cause "data" to be evaluated as string every time instead of only when debug is actually enabled.
297-
log.debug(f"Sent:")
298-
log.debug(message)
285+
log.debug("Sent: %s", message)
299286

300287
payload = await self.loop.run_in_executor(
301288
pyrogram.crypto_executor,
@@ -307,34 +294,35 @@ async def send(self, data: TLObject, wait_response: bool = True, timeout: float
307294
self.auth_key_id
308295
)
309296

310-
try:
311-
await self.connection.send(payload)
312-
except OSError as e:
313-
self.results.pop(msg_id, None)
314-
raise e
297+
await self.connection.send(payload)
315298

316299
if wait_response:
300+
self.results[msg_id] = Result()
301+
317302
try:
318303
await asyncio.wait_for(self.results[msg_id].event.wait(), timeout)
319304
except asyncio.TimeoutError:
320305
pass
321-
finally:
322-
result = self.results.pop(msg_id).value
306+
307+
result = self.results.pop(msg_id).value
323308

324309
if result is None:
325-
raise TimeoutError
326-
elif isinstance(result, raw.types.RpcError):
310+
raise TimeoutError("Response timed out")
311+
312+
if isinstance(result, raw.types.RpcError):
327313
if isinstance(data, (raw.functions.InvokeWithoutUpdates, raw.functions.InvokeWithTakeout)):
328314
data = data.query
329315

330316
RPCError.raise_it(result, type(data))
331-
elif isinstance(result, raw.types.BadMsgNotification):
317+
318+
if isinstance(result, raw.types.BadMsgNotification):
332319
raise BadMsgNotification(result.error_code)
333-
elif isinstance(result, raw.types.BadServerSalt):
320+
321+
if isinstance(result, raw.types.BadServerSalt):
334322
self.salt = result.new_server_salt
335323
return await self.send(data, wait_response, timeout)
336-
else:
337-
return result
324+
325+
return result
338326

339327
async def invoke(
340328
self,
@@ -344,7 +332,7 @@ async def invoke(
344332
sleep_threshold: float = SLEEP_THRESHOLD
345333
):
346334
try:
347-
await asyncio.wait_for(self.is_connected.wait(), self.WAIT_TIMEOUT)
335+
await asyncio.wait_for(self.is_started.wait(), self.WAIT_TIMEOUT)
348336
except asyncio.TimeoutError:
349337
pass
350338

@@ -364,16 +352,19 @@ async def invoke(
364352
if amount > sleep_threshold >= 0:
365353
raise
366354

367-
log.warning(f'[{self.client.name}] Waiting for {amount} seconds before continuing '
368-
f'(required by "{query_name}")')
355+
log.warning('[%s] Waiting for %s seconds before continuing (required by "%s")',
356+
self.client.name, amount, query_name)
369357

370358
await asyncio.sleep(amount)
371-
except (OSError, TimeoutError, InternalServerError, ServiceUnavailable) as e:
359+
except (OSError, InternalServerError, ServiceUnavailable) as e:
372360
if retries == 0:
373361
raise e from None
374362

375363
(log.warning if retries < 2 else log.info)(
376-
f'[{Session.MAX_RETRIES - retries + 1}] Retrying "{query_name}" due to {str(e) or repr(e)}')
364+
'[%s] Retrying "%s" due to: %s',
365+
Session.MAX_RETRIES - retries + 1,
366+
query_name, str(e) or repr(e)
367+
)
377368

378369
await asyncio.sleep(0.5)
379370

0 commit comments

Comments
 (0)