@@ -56,6 +56,13 @@ def __init__(self, ipv6: bool, proxy: Proxy) -> None:
5656
5757 self .lock = asyncio .Lock ()
5858 self .loop = asyncio .get_running_loop ()
59+ self ._closed = True
60+
61+ @property
62+ def closed (self ) -> bool :
63+ return (
64+ self ._closed or self .writer is None or self .writer .is_closing () or self .reader is None
65+ )
5966
6067 async def _connect_via_proxy (self , destination : tuple [str , int ]) -> None :
6168 scheme = self .proxy .get ("scheme" )
@@ -108,45 +115,56 @@ async def _connect(self, destination: tuple[str, int]) -> None:
108115 async def connect (self , address : tuple [str , int ]) -> None :
109116 try :
110117 await asyncio .wait_for (self ._connect (address ), TCP .TIMEOUT )
118+ self ._closed = False
111119 except (
112120 asyncio .TimeoutError
113121 ): # Re-raise as TimeoutError. asyncio.TimeoutError is deprecated in 3.11
122+ self ._closed = True
114123 raise TimeoutError ("Connection timed out" )
115124
116125 async def close (self ) -> None :
117126 if self .writer is None :
127+ self ._closed = True
118128 return
119129
120130 try :
121131 self .writer .close ()
122132 await asyncio .wait_for (self .writer .wait_closed (), TCP .TIMEOUT )
123133 except Exception as e :
124134 log .info ("Close exception: %s %s" , type (e ).__name__ , e )
135+ finally :
136+ self ._closed = True
125137
126138 async def send (self , data : bytes ) -> None :
127- if self .writer is None :
128- return
139+ if self .writer is None or self . _closed :
140+ raise OSError ( "Connection is closed" )
129141
130142 async with self .lock :
131143 try :
132144 self .writer .write (data )
133145 await self .writer .drain ()
134146 except Exception as e :
135147 log .info ("Send exception: %s %s" , type (e ).__name__ , e )
148+ self ._closed = True
136149 raise OSError (e ) from e
137150
138151 async def recv (self , length : int = 0 ) -> bytes | None :
152+ if self ._closed or self .reader is None :
153+ return None
154+
139155 data = b""
140156
141157 while len (data ) < length :
142158 try :
143159 chunk = await asyncio .wait_for (self .reader .read (length - len (data )), TCP .TIMEOUT )
144160 except (OSError , asyncio .TimeoutError ):
161+ self ._closed = True
145162 return None
146163 else :
147164 if chunk :
148165 data += chunk
149166 else :
167+ self ._closed = True
150168 return None
151169
152170 return data
0 commit comments