Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 22 additions & 17 deletions Lib/asyncio/base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -1219,26 +1219,31 @@ async def _create_connection_transport(
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None, context=None):

sock.setblocking(False)
context = context if context is not None else contextvars.copy_context()

protocol = protocol_factory()
waiter = self.create_future()
if ssl:
sslcontext = None if isinstance(ssl, bool) else ssl
transport = self._make_ssl_transport(
sock, protocol, sslcontext, waiter,
server_side=server_side, server_hostname=server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout,
context=context)
else:
transport = self._make_socket_transport(sock, protocol, waiter, context=context)

# gh-153133: close the socket if the transport is never created.
transport = None
try:
sock.setblocking(False)
context = context if context is not None else contextvars.copy_context()

protocol = protocol_factory()
waiter = self.create_future()
if ssl:
sslcontext = None if isinstance(ssl, bool) else ssl
transport = self._make_ssl_transport(
sock, protocol, sslcontext, waiter,
server_side=server_side, server_hostname=server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout,
context=context)
else:
transport = self._make_socket_transport(sock, protocol, waiter, context=context)

await waiter
except:
transport.close()
if transport is None:
sock.close()
else:
transport.close()
raise

return transport, protocol
Expand Down
41 changes: 41 additions & 0 deletions Lib/test/test_asyncio/test_base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -1282,6 +1282,47 @@ def getaddrinfo(*args, **kw):
self.loop.run_until_complete(coro)
self.assertTrue(sock.close.called)

def test_create_connection_sock_transport_error_closes_sock(self):
# gh-153133: a user-provided socket is closed if the transport is
# never created.
sock = mock.Mock()
sock.type = socket.SOCK_STREAM

def factory():
raise ZeroDivisionError

coro = self.loop.create_connection(factory, sock=sock)
with self.assertRaises(ZeroDivisionError):
self.loop.run_until_complete(coro)
self.assertTrue(sock.close.called)

@patch_socket
def test_create_connection_transport_error_closes_sock(self, m_socket):
# gh-153133: an internally created socket is closed if the transport
# is never created.
sock = mock.Mock()
m_socket.socket.return_value = sock

def getaddrinfo(*args, **kw):
fut = self.loop.create_future()
addr = (socket.AF_INET, socket.SOCK_STREAM, 0, '',
('127.0.0.1', 80))
fut.set_result([addr])
return fut
self.loop.getaddrinfo = getaddrinfo

async def sock_connect(sock, address):
return None

def factory():
raise ZeroDivisionError

with mock.patch.object(self.loop, 'sock_connect', sock_connect):
coro = self.loop.create_connection(factory, '127.0.0.1', 80)
with self.assertRaises(ZeroDivisionError):
self.loop.run_until_complete(coro)
self.assertTrue(sock.close.called)

@patch_socket
def test_create_connection_happy_eyeballs_empty_exceptions(self, m_socket):
# See gh-135836: Fix IndexError when Happy Eyeballs algorithm
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix a socket leak in :meth:`asyncio.loop.create_connection` when the
transport cannot be created.
Loading