Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/mcp/client/stdio.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,12 @@ async def stdin_writer():
except ProcessLookupError: # pragma: no cover
# Process already exited, which is fine
pass

if process.stdout: # pragma: no branch
try:
await process.stdout.aclose()
except Exception: # pragma: no cover
pass
await read_stream.aclose()
await write_stream.aclose()
await read_stream_writer.aclose()
Expand Down
7 changes: 5 additions & 2 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import anyio
import httpx
from anyio.abc import TaskGroup
from anyio.abc import TaskGroup, TaskStatus
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
from pydantic import ValidationError

Expand Down Expand Up @@ -437,10 +437,13 @@ async def post_writer(
write_stream: ContextSendStream[SessionMessage],
start_get_stream: Callable[[], None],
tg: TaskGroup,
*,
task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED,
) -> None:
"""Handle writing requests to the server."""
try:
async with write_stream_reader, read_stream_writer, write_stream:
task_status.started(None)

async def _handle_message(session_message: SessionMessage) -> None:
message = session_message.message
Expand Down Expand Up @@ -570,7 +573,7 @@ async def streamable_http_client(
def start_get_stream() -> None:
tg.start_soon(transport.handle_get_stream, client, read_stream_writer)

tg.start_soon(
await tg.start(
transport.post_writer,
client,
write_stream_reader,
Expand Down
39 changes: 39 additions & 0 deletions tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,6 +1071,45 @@ async def test_streamable_http_client_basic_connection(basic_server: None, basic
assert result.server_info.name == SERVER_NAME


@pytest.mark.anyio
async def test_streamable_http_client_no_race_on_consecutive_requests(basic_server: None, basic_server_url: str):
"""Regression test for a start-up race immediately after initialize().

In some cases, the first request after initialize() (e.g. list_tools())
could behave inconsistently. This test runs multiple short-lived sessions
to reliably catch any start-up race.
"""
for iteration in range(10): # pragma: no branch
async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()

tools = await session.list_tools()
assert len(tools.tools) == 10, f"Iteration {iteration}: expected 10 tools, got {len(tools.tools)}"
assert tools.tools[0].name == "test_tool"

tools2 = await session.list_tools()
assert len(tools2.tools) == 10

resource = await session.read_resource(uri="foobar://test-iteration")
assert len(resource.contents) == 1


@pytest.mark.anyio
async def test_streamable_http_client_rapid_request_sequence(basic_server: None, basic_server_url: str):
"""Stress test for rapid sequences of requests."""
async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()

for i in range(20):
tools = await session.list_tools()
assert len(tools.tools) == 10, f"Request {i}: expected 10 tools, got {len(tools.tools)}"

resource = await session.read_resource(uri="foobar://final-test")
assert len(resource.contents) == 1


@pytest.mark.anyio
async def test_streamable_http_client_resource_read(initialized_client_session: ClientSession):
"""Test client resource read functionality."""
Expand Down
Loading