From f1c0e22c976c1f97a356b89ffaf53f3f8b52016e Mon Sep 17 00:00:00 2001 From: Asti1982 <65121113+Asti1982@users.noreply.github.com> Date: Tue, 26 May 2026 06:10:58 +0200 Subject: [PATCH] Drain stdio responses after redirected stdin EOF --- src/mcp/server/lowlevel/server.py | 22 ++++-- src/mcp/server/mcpserver/server.py | 8 ++- src/mcp/server/session.py | 7 +- src/mcp/shared/session.py | 13 +++- tests/issues/test_2678_stdio_eof_drain.py | 83 +++++++++++++++++++++++ 5 files changed, 121 insertions(+), 12 deletions(-) create mode 100644 tests/issues/test_2678_stdio_eof_drain.py diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 59de0ace45..761bca1116 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -39,7 +39,7 @@ async def main(): import contextvars import logging import warnings -from collections.abc import AsyncIterator, Awaitable, Callable +from collections.abc import AsyncGenerator, Awaitable, Callable from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager from importlib.metadata import version as importlib_version from typing import Any, Generic, cast @@ -85,7 +85,7 @@ def __init__(self, prompts_changed: bool = False, resources_changed: bool = Fals @asynccontextmanager -async def lifespan(_: Server[LifespanResultT]) -> AsyncIterator[dict[str, Any]]: +async def lifespan(_: Server[LifespanResultT]) -> AsyncGenerator[dict[str, Any]]: """Default lifespan context manager that does nothing. Returns: @@ -371,6 +371,10 @@ async def run( # the initialization lifecycle, but can do so with any available node # rather than requiring initialization for each connection. stateless: bool = False, + # When True, stdin/file-style EOF is treated as "no more inbound messages"; + # accepted request handlers are allowed to finish and flush their responses. + drain_in_flight_on_read_eof: bool = False, + drain_in_flight_on_read_eof_timeout_seconds: float = 5.0, ): async with AsyncExitStack() as stack: lifespan_context = await stack.enter_async_context(self.lifespan(self)) @@ -380,6 +384,7 @@ async def run( write_stream, initialization_options, stateless=stateless, + close_write_stream_on_read_end=not drain_in_flight_on_read_eof, ) ) @@ -408,11 +413,14 @@ async def run( raise_exceptions, ) finally: - # Transport closed: cancel in-flight handlers. Without this the - # TG join waits for them, and when they eventually try to - # respond they hit a closed write stream (the session's - # _receive_loop closed it when the read stream ended). - tg.cancel_scope.cancel() + if not drain_in_flight_on_read_eof: + # Transport closed: cancel in-flight handlers. Without this the + # TG join waits for them, and when they eventually try to + # respond they hit a closed write stream (the session's + # _receive_loop closed it when the read stream ended). + tg.cancel_scope.cancel() + else: + tg.cancel_scope.deadline = anyio.current_time() + drain_in_flight_on_read_eof_timeout_seconds async def _handle_message( self, diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index b3471163b7..83221cc017 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -6,7 +6,7 @@ import inspect import json import re -from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence +from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable, Sequence from contextlib import AbstractAsyncContextManager, asynccontextmanager from typing import Any, Generic, Literal, TypeVar, overload @@ -74,6 +74,8 @@ logger = get_logger(__name__) +STDIO_EOF_DRAIN_TIMEOUT_SECONDS = 5.0 + _CallableT = TypeVar("_CallableT", bound=Callable[..., Any]) @@ -119,7 +121,7 @@ def lifespan_wrapper( lifespan: Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]], ) -> Callable[[Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]]: @asynccontextmanager - async def wrap(_: Server[LifespanResultT]) -> AsyncIterator[LifespanResultT]: + async def wrap(_: Server[LifespanResultT]) -> AsyncGenerator[LifespanResultT]: async with lifespan(app) as context: yield context @@ -852,6 +854,8 @@ async def run_stdio_async(self) -> None: read_stream, write_stream, self._lowlevel_server.create_initialization_options(), + drain_in_flight_on_read_eof=True, + drain_in_flight_on_read_eof_timeout_seconds=STDIO_EOF_DRAIN_TIMEOUT_SECONDS, ) async def run_sse_async( # pragma: no cover diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 20b640527a..7bf0579a19 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -84,8 +84,13 @@ def __init__( write_stream: WriteStream[SessionMessage], init_options: InitializationOptions, stateless: bool = False, + close_write_stream_on_read_end: bool = True, ) -> None: - super().__init__(read_stream, write_stream) + super().__init__( + read_stream, + write_stream, + close_write_stream_on_read_end=close_write_stream_on_read_end, + ) self._stateless = stateless self._initialization_state = ( InitializationState.Initialized if stateless else InitializationState.NotInitialized diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 243eef5ae6..fb7b6eb20c 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -191,9 +191,11 @@ def __init__( write_stream: WriteStream[SessionMessage], # If none, reading will never time out read_timeout_seconds: float | None = None, + close_write_stream_on_read_end: bool = True, ) -> None: self._read_stream = read_stream self._write_stream = write_stream + self._close_write_stream_on_read_end = close_write_stream_on_read_end self._response_streams = {} self._request_id = 0 self._session_read_timeout_seconds = read_timeout_seconds @@ -234,7 +236,11 @@ async def __aexit__( # would be very surprising behavior), so make sure to cancel the tasks # in the task group. self._task_group.cancel_scope.cancel() - return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + try: + return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + finally: + if not self._close_write_stream_on_read_end: + await self._write_stream.aclose() async def send_request( self, @@ -349,7 +355,10 @@ def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]: raise NotImplementedError async def _receive_loop(self) -> None: - async with self._read_stream, self._write_stream: + async with AsyncExitStack() as stack: + await stack.enter_async_context(self._read_stream) + if self._close_write_stream_on_read_end: + await stack.enter_async_context(self._write_stream) try: async def _handle_session_message(message: SessionMessage) -> None: diff --git a/tests/issues/test_2678_stdio_eof_drain.py b/tests/issues/test_2678_stdio_eof_drain.py new file mode 100644 index 0000000000..b839605444 --- /dev/null +++ b/tests/issues/test_2678_stdio_eof_drain.py @@ -0,0 +1,83 @@ +import json +import subprocess +import sys +import textwrap +from pathlib import Path + + +def test_stdio_redirected_stdin_eof_drains_accepted_tool_responses(tmp_path: Path) -> None: + server_py = tmp_path / "server.py" + payload_jsonl = tmp_path / "payload.jsonl" + response_jsonl = tmp_path / "response.jsonl" + + server_py.write_text( + textwrap.dedent( + """ + import asyncio + + from mcp.server.mcpserver import MCPServer + + mcp = MCPServer("repro") + + @mcp.tool() + async def slow_echo(text: str) -> str: + await asyncio.sleep(0.05) + return text + + if __name__ == "__main__": + mcp.run(transport="stdio") + """ + ), + encoding="utf-8", + ) + payload_jsonl.write_text( + "\n".join( + [ + json.dumps( + { + "jsonrpc": "2.0", + "id": 0, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "repro", "version": "0.1"}, + }, + } + ), + json.dumps({"jsonrpc": "2.0", "method": "notifications/initialized", "params": {}}), + json.dumps( + { + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": {"name": "slow_echo", "arguments": {"text": "first"}}, + } + ), + json.dumps( + { + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": {"name": "slow_echo", "arguments": {"text": "second"}}, + } + ), + ] + ) + + "\n", + encoding="utf-8", + ) + + with payload_jsonl.open("rb") as stdin, response_jsonl.open("wb") as stdout: + completed = subprocess.run( + [sys.executable, str(server_py)], + stdin=stdin, + stdout=stdout, + stderr=subprocess.PIPE, + timeout=10, + check=False, + ) + + assert completed.returncode == 0, completed.stderr.decode("utf-8", errors="replace") + response_ids = {json.loads(line)["id"] for line in response_jsonl.read_text(encoding="utf-8").splitlines()} + assert {0, 1, 2}.issubset(response_ids)