From b485b6ae0f47821f37feedb18cd1b8404c5da34b Mon Sep 17 00:00:00 2001 From: Jianke LIN Date: Mon, 25 May 2026 22:26:45 +0200 Subject: [PATCH 1/4] fix(stdio): drain responses after stdin EOF --- src/mcp/server/lowlevel/server.py | 39 +++++------ src/mcp/server/session.py | 2 +- src/mcp/shared/session.py | 15 ++++- tests/server/test_cancel_handling.py | 34 ++++------ tests/server/test_stdio.py | 96 +++++++++++++++++++++++++++- 5 files changed, 140 insertions(+), 46 deletions(-) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 59de0ace45..64c87b96c7 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -390,29 +390,22 @@ async def run( await stack.enter_async_context(task_support.run()) async with anyio.create_task_group() as tg: - try: - async for message in session.incoming_messages: - logger.debug("Received message: %s", message) - - if isinstance(message, RequestResponder) and message.context is not None: - context = message.context - else: - context = contextvars.copy_context() - - context.run( - tg.start_soon, - self._handle_message, - message, - session, - lifespan_context, - raise_exceptions, - ) - finally: - # Transport closed: cancel in-flight handlers. Without this the - # TG join waits for them, and when they eventually try to - # respond they hit a closed write stream (the session's - # _receive_loop closed it when the read stream ended). - tg.cancel_scope.cancel() + async for message in session.incoming_messages: + logger.debug("Received message: %s", message) + + if isinstance(message, RequestResponder) and message.context is not None: + context = message.context + else: + context = contextvars.copy_context() + + context.run( + tg.start_soon, + self._handle_message, + message, + session, + lifespan_context, + raise_exceptions, + ) async def _handle_message( self, diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 20b640527a..4e288dd1a1 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -85,7 +85,7 @@ def __init__( init_options: InitializationOptions, stateless: bool = False, ) -> None: - super().__init__(read_stream, write_stream) + super().__init__(read_stream, write_stream, close_write_stream_on_read_close=False) self._stateless = stateless self._initialization_state = ( InitializationState.Initialized if stateless else InitializationState.NotInitialized diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 243eef5ae6..0bc63db4d7 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -191,16 +191,26 @@ def __init__( write_stream: WriteStream[SessionMessage], # If none, reading will never time out read_timeout_seconds: float | None = None, + # When True, closing/EOF on the read stream closes the write stream too. + # + # For full-duplex transports (e.g., stdio), an input EOF can be a + # half-close: the peer is done sending requests but still expects + # responses on the output stream. In that case, callers may opt out so + # in-flight handlers can drain their responses before shutdown. + close_write_stream_on_read_close: bool = True, ) -> None: self._read_stream = read_stream self._write_stream = write_stream self._response_streams = {} self._request_id = 0 self._session_read_timeout_seconds = read_timeout_seconds + self._close_write_stream_on_read_close = close_write_stream_on_read_close self._in_flight = {} self._progress_callbacks = {} self._response_routers = [] self._exit_stack = AsyncExitStack() + self._exit_stack.push_async_callback(self._read_stream.aclose) + self._exit_stack.push_async_callback(self._write_stream.aclose) def add_response_router(self, router: ResponseRouter) -> None: """Register a response router to handle responses for non-standard requests. @@ -349,7 +359,10 @@ def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]: raise NotImplementedError async def _receive_loop(self) -> None: - async with self._read_stream, self._write_stream: + async with AsyncExitStack() as stack: + await stack.enter_async_context(self._read_stream) + if self._close_write_stream_on_read_close: + await stack.enter_async_context(self._write_stream) try: async def _handle_session_message(message: SessionMessage) -> None: diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py index cff5a37c15..2d663a25cd 100644 --- a/tests/server/test_cancel_handling.py +++ b/tests/server/test_cancel_handling.py @@ -19,6 +19,7 @@ InitializeRequestParams, JSONRPCNotification, JSONRPCRequest, + JSONRPCResponse, ListToolsResult, PaginatedRequestParams, TextContent, @@ -100,29 +101,18 @@ async def first_request(): @pytest.mark.anyio -async def test_server_cancels_in_flight_handlers_on_transport_close(): - """When the transport closes mid-request, server.run() must cancel in-flight - handlers rather than join on them. - - Without the cancel, the task group waits for the handler, which then tries - to respond through a write stream that _receive_loop already closed, - raising ClosedResourceError and crashing server.run() with exit code 1. - - This drives server.run() with raw memory streams because InMemoryTransport - wraps it in its own finally-cancel (_memory.py) which masks the bug. - """ +async def test_server_drains_in_flight_handlers_on_transport_read_eof(): + """When the transport's read side hits EOF (e.g., stdio stdin closes), the + server must drain already-started handlers so their responses reach the + peer via the still-open write side.""" handler_started = anyio.Event() - handler_cancelled = anyio.Event() + handler_allowed_to_finish = anyio.Event() server_run_returned = anyio.Event() async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: handler_started.set() - try: - await anyio.sleep_forever() - finally: - handler_cancelled.set() - # unreachable: sleep_forever only exits via cancellation - raise AssertionError # pragma: no cover + await handler_allowed_to_finish.wait() + return CallToolResult(content=[TextContent(type="text", text="ok")]) server = Server("test", on_call_tool=handle_call_tool) @@ -167,9 +157,13 @@ async def run_server(): # handler gets CancelledError, server.run() returns. await to_server.aclose() - await server_run_returned.wait() + handler_allowed_to_finish.set() + + response = await from_server.receive() + assert isinstance(response.message, JSONRPCResponse) + assert response.message.id == 2 - assert handler_cancelled.is_set() + await server_run_returned.wait() @pytest.mark.anyio diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index 677a993567..4d05a9036b 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -5,9 +5,27 @@ import anyio import pytest +from mcp.server import Server, ServerRequestContext from mcp.server.stdio import stdio_server from mcp.shared.message import SessionMessage -from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter +from mcp.types import ( + LATEST_PROTOCOL_VERSION, + CallToolRequestParams, + CallToolResult, + ClientCapabilities, + Implementation, + InitializeRequestParams, + JSONRPCError, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + ListToolsResult, + PaginatedRequestParams, + TextContent, + Tool, + jsonrpc_message_adapter, +) @pytest.mark.anyio @@ -92,3 +110,79 @@ async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch): second = await read_stream.receive() assert isinstance(second, SessionMessage) assert second.message == valid + + +@pytest.mark.anyio +async def test_stdio_server_drains_in_flight_responses_on_stdin_eof(): + """When stdin reaches EOF (e.g., bash-redirected input), already-received + requests must still be able to emit their responses on stdout.""" + stdin = io.StringIO() + stdout = io.StringIO() + + tool_started_count = 0 + both_tools_started = anyio.Event() + allow_tools_to_finish = anyio.Event() + + async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="slow", description="test", input_schema={})]) + + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + nonlocal tool_started_count + tool_started_count += 1 + if tool_started_count == 2: + both_tools_started.set() + await allow_tools_to_finish.wait() + return CallToolResult(content=[TextContent(type="text", text="ok")]) + + server = Server("test", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool) + + init_req = JSONRPCRequest( + jsonrpc="2.0", + id=0, + method="initialize", + params=InitializeRequestParams( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ClientCapabilities(), + client_info=Implementation(name="test", version="1.0"), + ).model_dump(by_alias=True, mode="json", exclude_none=True), + ) + initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized") + call_1 = JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="tools/call", + params=CallToolRequestParams(name="slow", arguments={}).model_dump(by_alias=True, mode="json"), + ) + call_2 = JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="tools/call", + params=CallToolRequestParams(name="slow", arguments={}).model_dump(by_alias=True, mode="json"), + ) + + for message in (init_req, initialized, call_1, call_2): + stdin.write(message.model_dump_json(by_alias=True, exclude_none=True) + "\n") + stdin.seek(0) + + async with stdio_server(stdin=anyio.AsyncFile(stdin), stdout=anyio.AsyncFile(stdout)) as ( + read_stream, + write_stream, + ): + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + tg.start_soon(server.run, read_stream, write_stream, server.create_initialization_options()) + await both_tools_started.wait() + allow_tools_to_finish.set() + + stdout.seek(0) + ids: set[int | str] = set() + for line in stdout.readlines(): + line = line.strip() + if not line: + continue + message = jsonrpc_message_adapter.validate_json(line) + if isinstance(message, JSONRPCResponse | JSONRPCError): + assert message.id is not None + ids.add(message.id) + assert 1 in ids + assert 2 in ids From 4d4e8637fb70bc70e5c8e0862be9ec615ffa7e3b Mon Sep 17 00:00:00 2001 From: Jianke LIN Date: Mon, 25 May 2026 22:59:26 +0200 Subject: [PATCH 2/4] test: cover stdio EOF drain and shutdown edges --- tests/server/test_cancel_handling.py | 117 +++++++++++++++++++++++++++ tests/server/test_stdio.py | 18 ++--- 2 files changed, 125 insertions(+), 10 deletions(-) diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py index 2d663a25cd..9544c91fb8 100644 --- a/tests/server/test_cancel_handling.py +++ b/tests/server/test_cancel_handling.py @@ -166,6 +166,123 @@ async def run_server(): await server_run_returned.wait() +@pytest.mark.anyio +async def test_server_reraises_handler_cancellation_when_server_is_cancelled(): + """If the server task is cancelled (e.g. KeyboardInterrupt), in-flight + request handlers will get cancelled too. Cancellation must be re-raised so + the task group can unwind cleanly.""" + handler_started = anyio.Event() + server_run_returned = anyio.Event() + cancel_scope = anyio.CancelScope() + + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + handler_started.set() + await anyio.sleep_forever() + raise AssertionError # pragma: no cover + + server = Server("test", on_call_tool=handle_call_tool) + + to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10) + server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10) + + async def run_server(): + try: + with cancel_scope: + await server.run(server_read, server_write, server.create_initialization_options()) + finally: + server_run_returned.set() + + init_req = JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=InitializeRequestParams( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ClientCapabilities(), + client_info=Implementation(name="test", version="1.0"), + ).model_dump(by_alias=True, mode="json", exclude_none=True), + ) + initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized") + call_req = JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="tools/call", + params=CallToolRequestParams(name="slow", arguments={}).model_dump(by_alias=True, mode="json"), + ) + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server: + tg.start_soon(run_server) + + await to_server.send(SessionMessage(init_req)) + await from_server.receive() # init response + await to_server.send(SessionMessage(initialized)) + await to_server.send(SessionMessage(call_req)) + + await handler_started.wait() + cancel_scope.cancel() + await server_run_returned.wait() + + +@pytest.mark.anyio +async def test_server_drops_response_when_write_stream_closes_mid_request(): + """If the write side closes while a handler is in-flight, responding may + raise (ClosedResourceError/BrokenResourceError). The handler task should + exit without crashing the server.""" + handler_started = anyio.Event() + allow_finish = anyio.Event() + server_run_returned = anyio.Event() + + async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + handler_started.set() + await allow_finish.wait() + return CallToolResult(content=[TextContent(type="text", text="ok")]) + + server = Server("test", on_call_tool=handle_call_tool) + + to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10) + server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10) + + async def run_server(): + await server.run(server_read, server_write, server.create_initialization_options()) + server_run_returned.set() + + init_req = JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=InitializeRequestParams( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ClientCapabilities(), + client_info=Implementation(name="test", version="1.0"), + ).model_dump(by_alias=True, mode="json", exclude_none=True), + ) + initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized") + call_req = JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="tools/call", + params=CallToolRequestParams(name="slow", arguments={}).model_dump(by_alias=True, mode="json"), + ) + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server: + tg.start_soon(run_server) + + await to_server.send(SessionMessage(init_req)) + await from_server.receive() # init response + await to_server.send(SessionMessage(initialized)) + await to_server.send(SessionMessage(call_req)) + + await handler_started.wait() + await server_write.aclose() + + allow_finish.set() + await to_server.aclose() + + await server_run_returned.wait() + + @pytest.mark.anyio async def test_server_handles_transport_close_with_pending_server_to_client_requests(): """When the transport closes while handlers are blocked on server→client diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index 4d05a9036b..565ed4ad17 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -15,7 +15,6 @@ ClientCapabilities, Implementation, InitializeRequestParams, - JSONRPCError, JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, @@ -147,6 +146,7 @@ async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestPar ).model_dump(by_alias=True, mode="json", exclude_none=True), ) initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized") + list_tools = JSONRPCRequest(jsonrpc="2.0", id=10, method="tools/list") call_1 = JSONRPCRequest( jsonrpc="2.0", id=1, @@ -160,7 +160,7 @@ async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestPar params=CallToolRequestParams(name="slow", arguments={}).model_dump(by_alias=True, mode="json"), ) - for message in (init_req, initialized, call_1, call_2): + for message in (init_req, initialized, list_tools, call_1, call_2): stdin.write(message.model_dump_json(by_alias=True, exclude_none=True) + "\n") stdin.seek(0) @@ -175,14 +175,12 @@ async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestPar allow_tools_to_finish.set() stdout.seek(0) + output_lines = [line.strip() for line in stdout.readlines()] + messages = [jsonrpc_message_adapter.validate_json(line) for line in output_lines] ids: set[int | str] = set() - for line in stdout.readlines(): - line = line.strip() - if not line: - continue - message = jsonrpc_message_adapter.validate_json(line) - if isinstance(message, JSONRPCResponse | JSONRPCError): - assert message.id is not None - ids.add(message.id) + for message in messages: + assert isinstance(message, JSONRPCResponse) + ids.add(message.id) + assert 1 in ids assert 2 in ids From 184a84cdbba376b6223e846c3d76d149f5456218 Mon Sep 17 00:00:00 2001 From: Jianke LIN Date: Mon, 25 May 2026 23:05:59 +0200 Subject: [PATCH 3/4] test: ignore coverage branch arc on 3.14 --- tests/server/test_stdio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index 565ed4ad17..27dc99a1be 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -169,7 +169,7 @@ async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestPar write_stream, ): with anyio.fail_after(5): - async with anyio.create_task_group() as tg: + async with anyio.create_task_group() as tg: # pragma: no branch tg.start_soon(server.run, read_stream, write_stream, server.create_initialization_options()) await both_tools_started.wait() allow_tools_to_finish.set() From 60fb7e9f6dc125d30121633046e99d111b75319e Mon Sep 17 00:00:00 2001 From: Jianke LIN Date: Mon, 25 May 2026 23:17:06 +0200 Subject: [PATCH 4/4] fix(server): opt-in drain on read EOF --- src/mcp/server/lowlevel/server.py | 45 ++++++++++++++++++---------- src/mcp/server/mcpserver/server.py | 1 + src/mcp/server/session.py | 3 +- tests/server/test_cancel_handling.py | 2 +- tests/server/test_stdio.py | 11 ++++++- 5 files changed, 43 insertions(+), 19 deletions(-) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 64c87b96c7..66d1c8ed1b 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -371,6 +371,10 @@ async def run( # the initialization lifecycle, but can do so with any available node # rather than requiring initialization for each connection. stateless: bool = False, + # When True, treat read EOF as a half-close and allow in-flight handlers + # to drain their responses via the still-open write stream (e.g. stdio + # with bash-redirected stdin). + drain_on_read_close: bool = False, ): async with AsyncExitStack() as stack: lifespan_context = await stack.enter_async_context(self.lifespan(self)) @@ -380,6 +384,7 @@ async def run( write_stream, initialization_options, stateless=stateless, + close_write_stream_on_read_close=not drain_on_read_close, ) ) @@ -390,22 +395,30 @@ async def run( await stack.enter_async_context(task_support.run()) async with anyio.create_task_group() as tg: - async for message in session.incoming_messages: - logger.debug("Received message: %s", message) - - if isinstance(message, RequestResponder) and message.context is not None: - context = message.context - else: - context = contextvars.copy_context() - - context.run( - tg.start_soon, - self._handle_message, - message, - session, - lifespan_context, - raise_exceptions, - ) + try: + async for message in session.incoming_messages: + logger.debug("Received message: %s", message) + + if isinstance(message, RequestResponder) and message.context is not None: + context = message.context + else: + context = contextvars.copy_context() + + context.run( + tg.start_soon, + self._handle_message, + message, + session, + lifespan_context, + raise_exceptions, + ) + finally: + if not drain_on_read_close: + # Transport closed: cancel in-flight handlers. Without this the + # TG join waits for them, and when they eventually try to + # respond they hit a closed write stream (the session's + # _receive_loop closed it when the read stream ended). + tg.cancel_scope.cancel() async def _handle_message( self, diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index b3471163b7..812ec3cf9d 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -852,6 +852,7 @@ async def run_stdio_async(self) -> None: read_stream, write_stream, self._lowlevel_server.create_initialization_options(), + drain_on_read_close=True, ) async def run_sse_async( # pragma: no cover diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 4e288dd1a1..5ef26915f6 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -84,8 +84,9 @@ def __init__( write_stream: WriteStream[SessionMessage], init_options: InitializationOptions, stateless: bool = False, + close_write_stream_on_read_close: bool = True, ) -> None: - super().__init__(read_stream, write_stream, close_write_stream_on_read_close=False) + super().__init__(read_stream, write_stream, close_write_stream_on_read_close=close_write_stream_on_read_close) self._stateless = stateless self._initialization_state = ( InitializationState.Initialized if stateless else InitializationState.NotInitialized diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py index 9544c91fb8..a988b30be6 100644 --- a/tests/server/test_cancel_handling.py +++ b/tests/server/test_cancel_handling.py @@ -120,7 +120,7 @@ async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestPar server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10) async def run_server(): - await server.run(server_read, server_write, server.create_initialization_options()) + await server.run(server_read, server_write, server.create_initialization_options(), drain_on_read_close=True) server_run_returned.set() init_req = JSONRPCRequest( diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index 27dc99a1be..5a2364f7e8 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -170,7 +170,16 @@ async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestPar ): with anyio.fail_after(5): async with anyio.create_task_group() as tg: # pragma: no branch - tg.start_soon(server.run, read_stream, write_stream, server.create_initialization_options()) + + async def run_server() -> None: + await server.run( + read_stream, + write_stream, + server.create_initialization_options(), + drain_on_read_close=True, + ) + + tg.start_soon(run_server) await both_tools_started.wait() allow_tools_to_finish.set()