1818
1919import asyncio
2020import logging
21- from typing import Optional
21+ from typing import Optional , Type
2222
23- from .transport import TCP , TCPAbridgedO
23+ from .transport import TCP , TCPAbridged
2424from ..session .internals import DataCenter
2525
2626log = logging .getLogger (__name__ )
2929class Connection :
3030 MAX_CONNECTION_ATTEMPTS = 3
3131
32- def __init__ (self , dc_id : int , test_mode : bool , ipv6 : bool , proxy : dict , media : bool = False ):
32+ def __init__ (
33+ self ,
34+ dc_id : int ,
35+ test_mode : bool ,
36+ ipv6 : bool ,
37+ alt_port : bool ,
38+ proxy : dict ,
39+ media : bool = False ,
40+ protocol_factory : Type [TCP ] = TCPAbridged
41+ ) -> None :
3342 self .dc_id = dc_id
3443 self .test_mode = test_mode
3544 self .ipv6 = ipv6
45+ self .alt_port = alt_port
3646 self .proxy = proxy
3747 self .media = media
48+ self .protocol_factory = protocol_factory
3849
39- self .address = DataCenter (dc_id , test_mode , ipv6 , media )
40- self .protocol : TCP = None
50+ self .address = DataCenter (dc_id , test_mode , ipv6 , alt_port , media )
51+ self .protocol : Optional [ TCP ] = None
4152
42- async def connect (self ):
53+ async def connect (self ) -> None :
4354 for i in range (Connection .MAX_CONNECTION_ATTEMPTS ):
44- self .protocol = TCPAbridgedO ( self .ipv6 , self .proxy )
55+ self .protocol = self . protocol_factory ( ipv6 = self .ipv6 , proxy = self .proxy )
4556
4657 try :
4758 log .info ("Connecting..." )
4859 await self .protocol .connect (self .address )
4960 except OSError as e :
5061 log .warning ("Unable to connect due to network issues: %s" , e )
51- self .protocol .close ()
62+ await self .protocol .close ()
5263 await asyncio .sleep (1 )
5364 else :
5465 log .info ("Connected! %s DC%s%s - IPv%s" ,
@@ -59,17 +70,14 @@ async def connect(self):
5970 break
6071 else :
6172 log .warning ("Connection failed! Trying again..." )
62- raise TimeoutError
73+ raise ConnectionError
6374
64- def close (self ):
65- self .protocol .close ()
75+ async def close (self ) -> None :
76+ await self .protocol .close ()
6677 log .info ("Disconnected" )
6778
68- async def send (self , data : bytes ):
69- try :
70- await self .protocol .send (data )
71- except Exception as e :
72- raise OSError (e )
79+ async def send (self , data : bytes ) -> None :
80+ await self .protocol .send (data )
7381
7482 async def recv (self ) -> Optional [bytes ]:
7583 return await self .protocol .recv ()
0 commit comments