Skip to content

Commit dc6c816

Browse files
committed
Revert some of the last changes
1 parent 0d11240 commit dc6c816

4 files changed

Lines changed: 47 additions & 60 deletions

File tree

pyrogram/connection/connection.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727

2828

2929
class Connection:
30+
MAX_RETRIES = 3
31+
3032
MODES = {
3133
0: TCPFull,
3234
1: TCPAbridged,
@@ -35,7 +37,7 @@ class Connection:
3537
4: TCPIntermediateO
3638
}
3739

38-
def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict, media: bool = False, mode: int = 1):
40+
def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict, media: bool = False, mode: int = 3):
3941
self.dc_id = dc_id
4042
self.test_mode = test_mode
4143
self.ipv6 = ipv6
@@ -45,18 +47,17 @@ def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict, media:
4547
self.mode = self.MODES.get(mode, TCPAbridged)
4648

4749
self.protocol = None # type: TCP
48-
self.is_connected = asyncio.Event()
4950

5051
async def connect(self):
51-
while True:
52+
for i in range(Connection.MAX_RETRIES):
5253
self.protocol = self.mode(self.ipv6, self.proxy)
5354

5455
try:
5556
log.info("Connecting...")
5657
await self.protocol.connect(self.address)
5758
except OSError as e:
58-
log.warning(f"Connection failed due to network issues: {e}")
59-
await self.protocol.close()
59+
log.warning(f"Unable to connect due to network issues: {e}")
60+
self.protocol.close()
6061
await asyncio.sleep(1)
6162
else:
6263
log.info("Connected! {} DC{}{} - IPv{} - {}".format(
@@ -67,21 +68,19 @@ async def connect(self):
6768
self.mode.__name__,
6869
))
6970
break
71+
else:
72+
log.warning("Connection failed! Trying again...")
73+
raise TimeoutError
7074

71-
self.is_connected.set()
72-
73-
async def close(self):
74-
await self.protocol.close()
75-
self.is_connected.clear()
75+
def close(self):
76+
self.protocol.close()
7677
log.info("Disconnected")
7778

78-
async def reconnect(self):
79-
await self.close()
80-
await self.connect()
81-
8279
async def send(self, data: bytes):
83-
await self.is_connected.wait()
84-
await self.protocol.send(data)
80+
try:
81+
await self.protocol.send(data)
82+
except Exception:
83+
raise OSError
8584

8685
async def recv(self) -> Optional[bytes]:
8786
return await self.protocol.recv()

pyrogram/connection/transport/tcp/tcp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ async def connect(self, address: tuple):
8282
self.socket.connect(address)
8383
self.reader, self.writer = await asyncio.open_connection(sock=self.socket)
8484

85-
async def close(self):
85+
def close(self):
8686
try:
8787
self.writer.close()
8888
except AttributeError:

pyrogram/session/auth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,4 +258,4 @@ async def create(self):
258258
else:
259259
return auth_key
260260
finally:
261-
await self.connection.close()
261+
self.connection.close()

pyrogram/session/session.py

Lines changed: 30 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def __init__(
8383
self.stored_msg_ids = []
8484

8585
self.ping_task = None
86+
self.ping_task_event = asyncio.Event()
8687

8788
self.network_task = None
8889

@@ -149,23 +150,17 @@ async def start(self):
149150
async def stop(self):
150151
self.is_connected.clear()
151152

152-
if self.ping_task:
153-
self.ping_task.cancel()
153+
self.ping_task_event.set()
154154

155-
try:
156-
await self.ping_task
157-
except asyncio.CancelledError:
158-
pass
155+
if self.ping_task is not None:
156+
await self.ping_task
159157

160-
if self.network_task:
161-
self.network_task.cancel()
158+
self.ping_task_event.clear()
162159

163-
try:
164-
await self.network_task
165-
except asyncio.CancelledError:
166-
pass
160+
self.connection.close()
167161

168-
await self.connection.close()
162+
if self.network_task:
163+
await self.network_task
169164

170165
for i in self.results.values():
171166
i.event.set()
@@ -194,7 +189,7 @@ async def handle_packet(self, packet):
194189
self.stored_msg_ids
195190
)
196191
except SecurityCheckMismatch:
197-
await self.connection.close()
192+
self.connection.close()
198193
return
199194

200195
messages = (
@@ -252,53 +247,46 @@ async def handle_packet(self, packet):
252247
self.pending_acks.clear()
253248

254249
async def ping_worker(self):
250+
log.info("PingTask started")
251+
255252
while True:
253+
try:
254+
await asyncio.wait_for(self.ping_task_event.wait(), self.PING_INTERVAL)
255+
except asyncio.TimeoutError:
256+
pass
257+
else:
258+
break
259+
256260
try:
257261
await self._send(
258262
raw.functions.PingDelayDisconnect(
259-
ping_id=0,
260-
disconnect_delay=self.WAIT_TIMEOUT + 10
263+
ping_id=0, disconnect_delay=self.WAIT_TIMEOUT + 10
261264
), False
262265
)
263266
except (OSError, TimeoutError, RPCError):
264267
pass
265268

266-
await asyncio.sleep(self.PING_INTERVAL)
269+
log.info("PingTask stopped")
267270

268271
async def network_worker(self):
272+
log.info("NetworkTask started")
273+
269274
while True:
270275
packet = await self.connection.recv()
271276

272-
if not packet:
273-
await self.connection.reconnect()
277+
if packet is None or len(packet) == 4:
278+
if packet:
279+
log.warning(f'Server sent "{Int.read(BytesIO(packet))}"')
274280

275-
try:
276-
await self._send(
277-
raw.functions.InvokeWithLayer(
278-
layer=layer,
279-
query=raw.functions.InitConnection(
280-
api_id=self.client.api_id,
281-
app_version=self.client.app_version,
282-
device_model=self.client.device_model,
283-
system_version=self.client.system_version,
284-
system_lang_code=self.client.lang_code,
285-
lang_code=self.client.lang_code,
286-
lang_pack="",
287-
query=raw.functions.help.GetConfig(),
288-
)
289-
),
290-
wait_response=False
291-
)
292-
except (OSError, TimeoutError, RPCError):
293-
pass
281+
if self.is_connected.is_set():
282+
self.loop.create_task(self.restart())
294283

295-
continue
296-
297-
if len(packet) == 4:
298-
log.warning(f'Server sent "{Int.read(BytesIO(packet))}"')
284+
break
299285

300286
self.loop.create_task(self.handle_packet(packet))
301287

288+
log.info("NetworkTask stopped")
289+
302290
async def _send(self, data: TLObject, wait_response: bool = True, timeout: float = WAIT_TIMEOUT):
303291
message = self.msg_factory(data)
304292
msg_id = message.msg_id

0 commit comments

Comments
 (0)