Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,10 @@ async def run(
# the initialization lifecycle, but can do so with any available node
# rather than requiring initialization for each connection.
stateless: bool = False,
# When True, treat read EOF as a half-close and allow in-flight handlers
# to drain their responses via the still-open write stream (e.g. stdio
# with bash-redirected stdin).
drain_on_read_close: bool = False,
):
async with AsyncExitStack() as stack:
lifespan_context = await stack.enter_async_context(self.lifespan(self))
Expand All @@ -380,6 +384,7 @@ async def run(
write_stream,
initialization_options,
stateless=stateless,
close_write_stream_on_read_close=not drain_on_read_close,
)
)

Expand Down Expand Up @@ -408,11 +413,12 @@ async def run(
raise_exceptions,
)
finally:
# Transport closed: cancel in-flight handlers. Without this the
# TG join waits for them, and when they eventually try to
# respond they hit a closed write stream (the session's
# _receive_loop closed it when the read stream ended).
tg.cancel_scope.cancel()
if not drain_on_read_close:
# Transport closed: cancel in-flight handlers. Without this the
# TG join waits for them, and when they eventually try to
# respond they hit a closed write stream (the session's
# _receive_loop closed it when the read stream ended).
tg.cancel_scope.cancel()

async def _handle_message(
self,
Expand Down
1 change: 1 addition & 0 deletions src/mcp/server/mcpserver/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,7 @@ async def run_stdio_async(self) -> None:
read_stream,
write_stream,
self._lowlevel_server.create_initialization_options(),
drain_on_read_close=True,
)

async def run_sse_async( # pragma: no cover
Expand Down
3 changes: 2 additions & 1 deletion src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,9 @@ def __init__(
write_stream: WriteStream[SessionMessage],
init_options: InitializationOptions,
stateless: bool = False,
close_write_stream_on_read_close: bool = True,
) -> None:
super().__init__(read_stream, write_stream)
super().__init__(read_stream, write_stream, close_write_stream_on_read_close=close_write_stream_on_read_close)
self._stateless = stateless
self._initialization_state = (
InitializationState.Initialized if stateless else InitializationState.NotInitialized
Expand Down
15 changes: 14 additions & 1 deletion src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,16 +191,26 @@ def __init__(
write_stream: WriteStream[SessionMessage],
# If none, reading will never time out
read_timeout_seconds: float | None = None,
# When True, closing/EOF on the read stream closes the write stream too.
#
# For full-duplex transports (e.g., stdio), an input EOF can be a
# half-close: the peer is done sending requests but still expects
# responses on the output stream. In that case, callers may opt out so
# in-flight handlers can drain their responses before shutdown.
close_write_stream_on_read_close: bool = True,
) -> None:
self._read_stream = read_stream
self._write_stream = write_stream
self._response_streams = {}
self._request_id = 0
self._session_read_timeout_seconds = read_timeout_seconds
self._close_write_stream_on_read_close = close_write_stream_on_read_close
self._in_flight = {}
self._progress_callbacks = {}
self._response_routers = []
self._exit_stack = AsyncExitStack()
self._exit_stack.push_async_callback(self._read_stream.aclose)
self._exit_stack.push_async_callback(self._write_stream.aclose)

def add_response_router(self, router: ResponseRouter) -> None:
"""Register a response router to handle responses for non-standard requests.
Expand Down Expand Up @@ -349,7 +359,10 @@ def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]:
raise NotImplementedError

async def _receive_loop(self) -> None:
async with self._read_stream, self._write_stream:
async with AsyncExitStack() as stack:
await stack.enter_async_context(self._read_stream)
if self._close_write_stream_on_read_close:
await stack.enter_async_context(self._write_stream)
try:

async def _handle_session_message(message: SessionMessage) -> None:
Expand Down
149 changes: 130 additions & 19 deletions tests/server/test_cancel_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
InitializeRequestParams,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
ListToolsResult,
PaginatedRequestParams,
TextContent,
Expand Down Expand Up @@ -100,29 +101,142 @@ async def first_request():


@pytest.mark.anyio
async def test_server_cancels_in_flight_handlers_on_transport_close():
"""When the transport closes mid-request, server.run() must cancel in-flight
handlers rather than join on them.
async def test_server_drains_in_flight_handlers_on_transport_read_eof():
"""When the transport's read side hits EOF (e.g., stdio stdin closes), the
server must drain already-started handlers so their responses reach the
peer via the still-open write side."""
handler_started = anyio.Event()
handler_allowed_to_finish = anyio.Event()
server_run_returned = anyio.Event()

Without the cancel, the task group waits for the handler, which then tries
to respond through a write stream that _receive_loop already closed,
raising ClosedResourceError and crashing server.run() with exit code 1.
async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
handler_started.set()
await handler_allowed_to_finish.wait()
return CallToolResult(content=[TextContent(type="text", text="ok")])

This drives server.run() with raw memory streams because InMemoryTransport
wraps it in its own finally-cancel (_memory.py) which masks the bug.
"""
server = Server("test", on_call_tool=handle_call_tool)

to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10)
server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10)

async def run_server():
await server.run(server_read, server_write, server.create_initialization_options(), drain_on_read_close=True)
server_run_returned.set()

init_req = JSONRPCRequest(
jsonrpc="2.0",
id=1,
method="initialize",
params=InitializeRequestParams(
protocol_version=LATEST_PROTOCOL_VERSION,
capabilities=ClientCapabilities(),
client_info=Implementation(name="test", version="1.0"),
).model_dump(by_alias=True, mode="json", exclude_none=True),
)
initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")
call_req = JSONRPCRequest(
jsonrpc="2.0",
id=2,
method="tools/call",
params=CallToolRequestParams(name="slow", arguments={}).model_dump(by_alias=True, mode="json"),
)

with anyio.fail_after(5):
async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server:
tg.start_soon(run_server)

await to_server.send(SessionMessage(init_req))
await from_server.receive() # init response
await to_server.send(SessionMessage(initialized))
await to_server.send(SessionMessage(call_req))

await handler_started.wait()

# Close the server's input stream — this is what stdin EOF does.
# With drain_on_read_close=True, server.run() does not cancel in-flight
# handlers here; the handler can still respond on the open write stream.
await to_server.aclose()

handler_allowed_to_finish.set()

response = await from_server.receive()
assert isinstance(response.message, JSONRPCResponse)
assert response.message.id == 2

await server_run_returned.wait()


@pytest.mark.anyio
async def test_server_reraises_handler_cancellation_when_server_is_cancelled():
"""If the server task is cancelled (e.g. KeyboardInterrupt), in-flight
request handlers will get cancelled too. Cancellation must be re-raised so
the task group can unwind cleanly."""
handler_started = anyio.Event()
handler_cancelled = anyio.Event()
server_run_returned = anyio.Event()
cancel_scope = anyio.CancelScope()

async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
handler_started.set()
await anyio.sleep_forever()
raise AssertionError # pragma: no cover

server = Server("test", on_call_tool=handle_call_tool)

to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10)
server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10)

async def run_server():
try:
await anyio.sleep_forever()
with cancel_scope:
await server.run(server_read, server_write, server.create_initialization_options())
finally:
handler_cancelled.set()
# unreachable: sleep_forever only exits via cancellation
raise AssertionError # pragma: no cover
server_run_returned.set()

init_req = JSONRPCRequest(
jsonrpc="2.0",
id=1,
method="initialize",
params=InitializeRequestParams(
protocol_version=LATEST_PROTOCOL_VERSION,
capabilities=ClientCapabilities(),
client_info=Implementation(name="test", version="1.0"),
).model_dump(by_alias=True, mode="json", exclude_none=True),
)
initialized = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")
call_req = JSONRPCRequest(
jsonrpc="2.0",
id=2,
method="tools/call",
params=CallToolRequestParams(name="slow", arguments={}).model_dump(by_alias=True, mode="json"),
)

with anyio.fail_after(5):
async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server:
tg.start_soon(run_server)

await to_server.send(SessionMessage(init_req))
await from_server.receive() # init response
await to_server.send(SessionMessage(initialized))
await to_server.send(SessionMessage(call_req))

await handler_started.wait()
cancel_scope.cancel()
await server_run_returned.wait()


@pytest.mark.anyio
async def test_server_drops_response_when_write_stream_closes_mid_request():
"""If the write side closes while a handler is in-flight, responding may
raise (ClosedResourceError/BrokenResourceError). The handler task should
exit without crashing the server."""
handler_started = anyio.Event()
allow_finish = anyio.Event()
server_run_returned = anyio.Event()

async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
handler_started.set()
await allow_finish.wait()
return CallToolResult(content=[TextContent(type="text", text="ok")])

server = Server("test", on_call_tool=handle_call_tool)

Expand Down Expand Up @@ -161,16 +275,13 @@ async def run_server():
await to_server.send(SessionMessage(call_req))

await handler_started.wait()
await server_write.aclose()

# Close the server's input stream — this is what stdin EOF does.
# server.run()'s incoming_messages loop ends, finally-cancel fires,
# handler gets CancelledError, server.run() returns.
allow_finish.set()
await to_server.aclose()

await server_run_returned.wait()

assert handler_cancelled.is_set()


@pytest.mark.anyio
async def test_server_handles_transport_close_with_pending_server_to_client_requests():
Expand Down
103 changes: 102 additions & 1 deletion tests/server/test_stdio.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,26 @@
import anyio
import pytest

from mcp.server import Server, ServerRequestContext
from mcp.server.stdio import stdio_server
from mcp.shared.message import SessionMessage
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter
from mcp.types import (
LATEST_PROTOCOL_VERSION,
CallToolRequestParams,
CallToolResult,
ClientCapabilities,
Implementation,
InitializeRequestParams,
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
ListToolsResult,
PaginatedRequestParams,
TextContent,
Tool,
jsonrpc_message_adapter,
)


@pytest.mark.anyio
Expand Down Expand Up @@ -92,3 +109,87 @@ async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch):
second = await read_stream.receive()
assert isinstance(second, SessionMessage)
assert second.message == valid


@pytest.mark.anyio
async def test_stdio_server_drains_in_flight_responses_on_stdin_eof():
    """Requests read before stdin EOF must still get their responses on stdout.

    The whole client script is written to an in-memory stdin up front, so
    nothing follows the two tool calls (as with bash-redirected input). The
    server runs with ``drain_on_read_close=True`` and the test checks that a
    response for each tool call ends up on stdout.
    """
    fake_stdin = io.StringIO()
    fake_stdout = io.StringIO()

    # Both tool calls block until released, so they are in flight together.
    started_calls = 0
    both_tools_started = anyio.Event()
    allow_tools_to_finish = anyio.Event()

    async def handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult:
        slow_tool = Tool(name="slow", description="test", input_schema={})
        return ListToolsResult(tools=[slow_tool])

    async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
        nonlocal started_calls
        started_calls += 1
        if started_calls == 2:
            both_tools_started.set()
        await allow_tools_to_finish.wait()
        return CallToolResult(content=[TextContent(type="text", text="ok")])

    def make_tool_call(request_id: int) -> JSONRPCRequest:
        """Build a ``tools/call`` request for the "slow" tool."""
        return JSONRPCRequest(
            jsonrpc="2.0",
            id=request_id,
            method="tools/call",
            params=CallToolRequestParams(name="slow", arguments={}).model_dump(by_alias=True, mode="json"),
        )

    server = Server("test", on_list_tools=handle_list_tools, on_call_tool=handle_call_tool)

    init_params = InitializeRequestParams(
        protocol_version=LATEST_PROTOCOL_VERSION,
        capabilities=ClientCapabilities(),
        client_info=Implementation(name="test", version="1.0"),
    )
    client_script = (
        JSONRPCRequest(
            jsonrpc="2.0",
            id=0,
            method="initialize",
            params=init_params.model_dump(by_alias=True, mode="json", exclude_none=True),
        ),
        JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized"),
        JSONRPCRequest(jsonrpc="2.0", id=10, method="tools/list"),
        make_tool_call(1),
        make_tool_call(2),
    )

    # One JSON document per line, then rewind so the server reads from the start.
    fake_stdin.writelines(
        outgoing.model_dump_json(by_alias=True, exclude_none=True) + "\n" for outgoing in client_script
    )
    fake_stdin.seek(0)

    async with stdio_server(stdin=anyio.AsyncFile(fake_stdin), stdout=anyio.AsyncFile(fake_stdout)) as (
        read_stream,
        write_stream,
    ):

        async def run_server() -> None:
            await server.run(
                read_stream,
                write_stream,
                server.create_initialization_options(),
                drain_on_read_close=True,
            )

        with anyio.fail_after(5):
            async with anyio.create_task_group() as tg:  # pragma: no branch
                tg.start_soon(run_server)
                # Release the handlers only once both calls are running.
                await both_tools_started.wait()
                allow_tools_to_finish.set()

    fake_stdout.seek(0)
    responses = [jsonrpc_message_adapter.validate_json(raw_line.strip()) for raw_line in fake_stdout]

    seen_ids: set[int | str] = set()
    for response in responses:
        assert isinstance(response, JSONRPCResponse)
        seen_ids.add(response.id)

    assert 1 in seen_ids
    assert 2 in seen_ids
Loading