Skip to content

Commit d2f3cd9

Browse files
authored
Implement non-blocking TCP connection (
1 parent bedca34 commit d2f3cd9

File tree

1 file changed

+24
-16
lines changed

1 file changed

+24
-16
lines changed

pyrogram/connection/connection.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818

1919
import asyncio
2020
import logging
21-
from typing import Optional
21+
from typing import Optional, Type
2222

23-
from .transport import TCP, TCPAbridgedO
23+
from .transport import TCP, TCPAbridged
2424
from ..session.internals import DataCenter
2525

2626
log = logging.getLogger(__name__)
@@ -29,26 +29,37 @@
2929
class Connection:
3030
MAX_CONNECTION_ATTEMPTS = 3
3131

32-
def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict, media: bool = False):
32+
def __init__(
33+
self,
34+
dc_id: int,
35+
test_mode: bool,
36+
ipv6: bool,
37+
alt_port: bool,
38+
proxy: dict,
39+
media: bool = False,
40+
protocol_factory: Type[TCP] = TCPAbridged
41+
) -> None:
3342
self.dc_id = dc_id
3443
self.test_mode = test_mode
3544
self.ipv6 = ipv6
45+
self.alt_port = alt_port
3646
self.proxy = proxy
3747
self.media = media
48+
self.protocol_factory = protocol_factory
3849

39-
self.address = DataCenter(dc_id, test_mode, ipv6, media)
40-
self.protocol: TCP = None
50+
self.address = DataCenter(dc_id, test_mode, ipv6, alt_port, media)
51+
self.protocol: Optional[TCP] = None
4152

42-
async def connect(self):
53+
async def connect(self) -> None:
4354
for i in range(Connection.MAX_CONNECTION_ATTEMPTS):
44-
self.protocol = TCPAbridgedO(self.ipv6, self.proxy)
55+
self.protocol = self.protocol_factory(ipv6=self.ipv6, proxy=self.proxy)
4556

4657
try:
4758
log.info("Connecting...")
4859
await self.protocol.connect(self.address)
4960
except OSError as e:
5061
log.warning("Unable to connect due to network issues: %s", e)
51-
self.protocol.close()
62+
await self.protocol.close()
5263
await asyncio.sleep(1)
5364
else:
5465
log.info("Connected! %s DC%s%s - IPv%s",
@@ -59,17 +70,14 @@ async def connect(self):
5970
break
6071
else:
6172
log.warning("Connection failed! Trying again...")
62-
raise TimeoutError
73+
raise ConnectionError
6374

64-
def close(self):
65-
self.protocol.close()
75+
async def close(self) -> None:
76+
await self.protocol.close()
6677
log.info("Disconnected")
6778

68-
async def send(self, data: bytes):
69-
try:
70-
await self.protocol.send(data)
71-
except Exception as e:
72-
raise OSError(e)
79+
async def send(self, data: bytes) -> None:
80+
await self.protocol.send(data)
7381

7482
async def recv(self) -> Optional[bytes]:
7583
return await self.protocol.recv()

0 commit comments

Comments
 (0)