Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add tests
  • Loading branch information
kumaraditya303 committed Mar 21, 2026
commit b344e60dab74dfa71f7cefb3f542602cfb3c9294
5 changes: 3 additions & 2 deletions Lib/asyncio/selector_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,11 @@ def _accept_connection(
raise # The event loop will catch, log and ignore it.
else:
extra = {'peername': addr}
conn_context = context.copy() if context is not None else None
accept = self._accept_connection2(
protocol_factory, conn, extra, sslcontext, server,
ssl_handshake_timeout, ssl_shutdown_timeout, context=context)
self.create_task(accept, context=context)
ssl_handshake_timeout, ssl_shutdown_timeout, context=conn_context)
self.create_task(accept, context=conn_context)

async def _accept_connection2(
self, protocol_factory, conn, extra,
Expand Down
264 changes: 264 additions & 0 deletions Lib/test/test_asyncio/test_server_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@

import asyncio
import contextvars
import unittest

from unittest import TestCase

def tearDownModule():
asyncio.events._set_event_loop_policy(None)

class ServerContextvarsTestCase:
loop_factory = None # To be defined in subclasses

def run_coro(self, coro):
return asyncio.run(coro, loop_factory=self.loop_factory)

def test_start_server1(self):
# Test that asyncio.start_server captures the context at the time of server creation
async def test():
var = contextvars.ContextVar("var", default="default")

async def handle_client(reader, writer):
value = var.get()
writer.write(value.encode())
await writer.drain()
writer.close()

server = await asyncio.start_server(handle_client, '127.0.0.1', 0)
# change the value
var.set("after_server")

async def client(addr):
reader, writer = await asyncio.open_connection(*addr)
data = await reader.read(100)
writer.close()
await writer.wait_closed()
return data.decode()

async with server:
addr = server.sockets[0].getsockname()
self.assertEqual(await client(addr), "default")

self.assertEqual(var.get(), "after_server")

self.run_coro(test())

def test_start_server2(self):
# Test that mutations to the context in one handler don't affect other handlers or the server's context
async def test():
var = contextvars.ContextVar("var", default="default")

async def handle_client(reader, writer):
value = var.get()
writer.write(value.encode())
var.set("in_handler")
await writer.drain()
writer.close()

server = await asyncio.start_server(handle_client, '127.0.0.1', 0)
var.set("after_server")

async def client(addr):
reader, writer = await asyncio.open_connection(*addr)
data = await reader.read(100)
writer.close()
await writer.wait_closed()
return data.decode()

async with server:
addr = server.sockets[0].getsockname()
self.assertEqual(await client(addr), "default")
self.assertEqual(await client(addr), "default")
self.assertEqual(await client(addr), "default")

self.assertEqual(var.get(), "after_server")

self.run_coro(test())

def test_start_server3(self):
# Test that mutations to context in concurrent handlers don't affect each other or the server's context
async def test():
var = contextvars.ContextVar("var", default="default")
var.set("before_server")

async def handle_client(reader, writer):
writer.write(var.get().encode())
await writer.drain()
writer.close()

server = await asyncio.start_server(handle_client, '127.0.0.1', 0)
var.set("after_server")

async def client(addr):
reader, writer = await asyncio.open_connection(*addr)
data = await reader.read(100)
self.assertEqual(data.decode(), "before_server")
writer.close()
await writer.wait_closed()

async with server:
addr = server.sockets[0].getsockname()
async with asyncio.TaskGroup() as tg:
for _ in range(100):
tg.create_task(client(addr))

self.assertEqual(var.get(), "after_server")

self.run_coro(test())

def test_create_server1(self):
# Test that loop.create_server captures the context at the time of server creation
# and that mutations to the context in protocol callbacks don't affect the server's context
async def test():
var = contextvars.ContextVar("var", default="default")

class EchoProtocol(asyncio.Protocol):
def connection_made(self, transport):
self.transport = transport
value = var.get()
var.set("in_handler")
self.transport.write(value.encode())
self.transport.close()

server = await asyncio.get_running_loop().create_server(
lambda: EchoProtocol(), '127.0.0.1', 0)
var.set("after_server")

async def client(addr):
reader, writer = await asyncio.open_connection(*addr)
data = await reader.read(100)
self.assertEqual(data.decode(), "default")
writer.close()
await writer.wait_closed()

async with server:
addr = server.sockets[0].getsockname()
await client(addr)

self.assertEqual(var.get(), "after_server")

self.run_coro(test())

def test_create_server2(self):
# Test that mutations to context in one protocol instance don't affect other instances or the server's context
async def test():
var = contextvars.ContextVar("var", default="default")

class EchoProtocol(asyncio.Protocol):
def __init__(self):
super().__init__()
assert var.get() == "default", var.get()
def connection_made(self, transport):
self.transport = transport
value = var.get()
var.set("in_handler")
self.transport.write(value.encode())
self.transport.close()

server = await asyncio.get_running_loop().create_server(
lambda: EchoProtocol(), '127.0.0.1', 0)

var.set("after_server")

async def client(addr, expected):
reader, writer = await asyncio.open_connection(*addr)
data = await reader.read(100)
self.assertEqual(data.decode(), expected)
writer.close()
await writer.wait_closed()

async with server:
addr = server.sockets[0].getsockname()
await client(addr, "default")
await client(addr, "default")

self.assertEqual(var.get(), "after_server")

self.run_coro(test())

def test_gh140947(self):
# See https://github.com/python/cpython/issues/140947

cvar1 = contextvars.ContextVar("cvar1")
cvar2 = contextvars.ContextVar("cvar2")
cvar3 = contextvars.ContextVar("cvar3")
results = {}

def capture_context(meth):
result = []
for k,v in contextvars.copy_context().items():
result.append((k.name, v))
results[meth] = sorted(result)

class DemoProtocol(asyncio.Protocol):
def __init__(self, on_conn_lost):
self.transport = None
self.on_conn_lost = on_conn_lost
self.tasks = set()

def connection_made(self, transport):
capture_context("connection_made")
self.transport = transport

def data_received(self, data):
capture_context("data_received")

task = asyncio.create_task(self.asgi())
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)

self.transport.pause_reading()

def connection_lost(self, exc):
capture_context("connection_lost")
if not self.on_conn_lost.done():
self.on_conn_lost.set_result(True)

async def asgi(self):
capture_context("asgi start")

cvar1.set(True)

# make sure that we only resume after the pause
# otherwise the resume does nothing
while not self.transport._paused:
await asyncio.sleep(0.1)

cvar2.set(True)

self.transport.resume_reading()

cvar3.set(True)

capture_context("asgi end")


async def main():
loop = asyncio.get_running_loop()
on_conn_lost = loop.create_future()

host, port = "127.0.0.1", 8888

async with await loop.create_server(lambda: DemoProtocol(on_conn_lost), host, port):
reader, writer = await asyncio.open_connection(host, port)
writer.write(b"anything")
await writer.drain()
writer.close()
await writer.wait_closed()
await on_conn_lost
self.run_coro(main())
self.assertDictEqual(results, {
"connection_made": [],
"data_received": [],
"asgi start": [],
"asgi end": [("cvar1", True), ("cvar2", True), ("cvar3", True)],
"connection_lost": [],
})


class AsyncioEventLoopTests(TestCase, ServerContextvarsTestCase):
loop_factory = staticmethod(asyncio.new_event_loop)

if __name__ == "__main__":
unittest.main()
Loading