Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix ssl part
  • Loading branch information
kumaraditya303 committed Mar 21, 2026
commit 09df55ed3ef1ead05d86cc8e52da145336775c9d
6 changes: 4 additions & 2 deletions Lib/asyncio/base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,8 @@ def _make_ssl_transport(
extra=None, server=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
call_connection_made=True):
call_connection_made=True,
context=None):
"""Create SSL transport."""
raise NotImplementedError

Expand Down Expand Up @@ -1228,7 +1229,8 @@ async def _create_connection_transport(
sock, protocol, sslcontext, waiter,
server_side=server_side, server_hostname=server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout)
ssl_shutdown_timeout=ssl_shutdown_timeout,
context=context)
else:
transport = self._make_socket_transport(sock, protocol, waiter, context=context)

Expand Down
8 changes: 5 additions & 3 deletions Lib/asyncio/selector_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,17 @@ def _make_ssl_transport(
extra=None, server=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT,
ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT,
context=None,
):
self._ensure_fd_no_transport(rawsock)
ssl_protocol = sslproto.SSLProtocol(
self, protocol, sslcontext, waiter,
server_side, server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout
ssl_shutdown_timeout=ssl_shutdown_timeout,
)
_SelectorSocketTransport(self, rawsock, ssl_protocol,
extra=extra, server=server)
extra=extra, server=server, context=context)
return ssl_protocol._app_transport

def _make_datagram_transport(self, sock, protocol,
Expand Down Expand Up @@ -230,7 +231,8 @@ async def _accept_connection2(
conn, protocol, sslcontext, waiter=waiter,
server_side=True, extra=extra, server=server,
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout)
ssl_shutdown_timeout=ssl_shutdown_timeout,
context=context)
else:
transport = self._make_socket_transport(
conn, protocol, waiter=waiter, extra=extra,
Expand Down
82 changes: 59 additions & 23 deletions Lib/test/test_asyncio/test_server_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,20 @@

from unittest import TestCase

try:
import ssl
except ImportError:
ssl = None

from test.test_asyncio import utils as test_utils

def tearDownModule():
asyncio.events._set_event_loop_policy(None)

class ServerContextvarsTestCase:
loop_factory = None # To be defined in subclasses
server_ssl_context = None # To be defined in subclasses for SSL tests
client_ssl_context = None # To be defined in subclasses for SSL tests

def run_coro(self, coro):
return asyncio.run(coro, loop_factory=self.loop_factory)
Expand All @@ -25,12 +34,14 @@ async def handle_client(reader, writer):
await writer.drain()
writer.close()

server = await asyncio.start_server(handle_client, '127.0.0.1', 0)
server = await asyncio.start_server(handle_client, '127.0.0.1', 0,
ssl=self.server_ssl_context)
# change the value
var.set("after_server")

async def client(addr):
reader, writer = await asyncio.open_connection(*addr)
reader, writer = await asyncio.open_connection(*addr,
ssl=self.client_ssl_context)
data = await reader.read(100)
writer.close()
await writer.wait_closed()
Expand All @@ -56,11 +67,13 @@ async def handle_client(reader, writer):
await writer.drain()
writer.close()

server = await asyncio.start_server(handle_client, '127.0.0.1', 0)
server = await asyncio.start_server(handle_client, '127.0.0.1', 0,
ssl=self.server_ssl_context)
var.set("after_server")

async def client(addr):
reader, writer = await asyncio.open_connection(*addr)
reader, writer = await asyncio.open_connection(*addr,
ssl=self.client_ssl_context)
data = await reader.read(100)
writer.close()
await writer.wait_closed()
Expand All @@ -87,11 +100,13 @@ async def handle_client(reader, writer):
await writer.drain()
writer.close()

server = await asyncio.start_server(handle_client, '127.0.0.1', 0)
server = await asyncio.start_server(handle_client, '127.0.0.1', 0,
ssl=self.server_ssl_context)
var.set("after_server")

async def client(addr):
reader, writer = await asyncio.open_connection(*addr)
reader, writer = await asyncio.open_connection(*addr,
ssl=self.client_ssl_context)
data = await reader.read(100)
self.assertEqual(data.decode(), "before_server")
writer.close()
Expand Down Expand Up @@ -122,11 +137,13 @@ def connection_made(self, transport):
self.transport.close()

server = await asyncio.get_running_loop().create_server(
lambda: EchoProtocol(), '127.0.0.1', 0)
lambda: EchoProtocol(), '127.0.0.1', 0,
ssl=self.server_ssl_context)
var.set("after_server")

async def client(addr):
reader, writer = await asyncio.open_connection(*addr)
reader, writer = await asyncio.open_connection(*addr,
ssl=self.client_ssl_context)
data = await reader.read(100)
self.assertEqual(data.decode(), "default")
writer.close()
Expand Down Expand Up @@ -157,12 +174,14 @@ def connection_made(self, transport):
self.transport.close()

server = await asyncio.get_running_loop().create_server(
lambda: EchoProtocol(), '127.0.0.1', 0)
lambda: EchoProtocol(), '127.0.0.1', 0,
ssl=self.server_ssl_context)

var.set("after_server")

async def client(addr, expected):
reader, writer = await asyncio.open_connection(*addr)
reader, writer = await asyncio.open_connection(*addr,
ssl=self.client_ssl_context)
data = await reader.read(100)
self.assertEqual(data.decode(), expected)
writer.close()
Expand All @@ -184,6 +203,7 @@ def test_gh140947(self):
cvar2 = contextvars.ContextVar("cvar2")
cvar3 = contextvars.ContextVar("cvar3")
results = {}
is_ssl = self.server_ssl_context is not None

def capture_context(meth):
result = []
Expand Down Expand Up @@ -218,36 +238,37 @@ def connection_lost(self, exc):

async def asgi(self):
capture_context("asgi start")

cvar1.set(True)

# make sure that we only resume after the pause
# otherwise the resume does nothing
while not self.transport._paused:
await asyncio.sleep(0.1)

if is_ssl:
while not self.transport._ssl_protocol._app_reading_paused:
await asyncio.sleep(0.01)
else:
while not self.transport._paused:
await asyncio.sleep(0.01)
cvar2.set(True)

self.transport.resume_reading()

cvar3.set(True)

capture_context("asgi end")


async def main():
loop = asyncio.get_running_loop()
on_conn_lost = loop.create_future()

host, port = "127.0.0.1", 8888

async with await loop.create_server(lambda: DemoProtocol(on_conn_lost), host, port):
reader, writer = await asyncio.open_connection(host, port)
server = await loop.create_server(
lambda: DemoProtocol(on_conn_lost), '127.0.0.1', 0,
ssl=self.server_ssl_context)
async with server:
addr = server.sockets[0].getsockname()
reader, writer = await asyncio.open_connection(*addr,
ssl=self.client_ssl_context)
writer.write(b"anything")
await writer.drain()
writer.close()
await writer.wait_closed()
await on_conn_lost

self.run_coro(main())
self.assertDictEqual(results, {
"connection_made": [],
Expand All @@ -261,12 +282,27 @@ async def main():
class AsyncioEventLoopTests(TestCase, ServerContextvarsTestCase):
loop_factory = staticmethod(asyncio.new_event_loop)

@unittest.skipUnless(ssl, "SSL not available")
class AsyncioEventLoopSSLTests(AsyncioEventLoopTests):
server_ssl_context = test_utils.simple_server_sslcontext()
client_ssl_context = test_utils.simple_client_sslcontext()

if sys.platform == "win32":
class AsyncioProactorEventLoopTests(TestCase, ServerContextvarsTestCase):
loop_factory = asyncio.ProactorEventLoop

class AsyncioSelectorEventLoopTests(TestCase, ServerContextvarsTestCase):
loop_factory = asyncio.SelectorEventLoop

@unittest.skipUnless(ssl, "SSL not available")
class AsyncioProactorEventLoopSSLTests(AsyncioProactorEventLoopTests):
server_ssl_context = test_utils.simple_server_sslcontext()
client_ssl_context = test_utils.simple_client_sslcontext()

@unittest.skipUnless(ssl, "SSL not available")
class AsyncioSelectorEventLoopSSLTests(AsyncioSelectorEventLoopTests):
server_ssl_context = test_utils.simple_server_sslcontext()
client_ssl_context = test_utils.simple_client_sslcontext()

if __name__ == "__main__":
unittest.main()
Loading