Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,10 +774,17 @@ async def terminate(self) -> None:

Once terminated, all requests with this session ID will receive 404 Not Found.
"""
if self._terminated:
return

self._terminated = True
logger.info(f"Terminating session: {self.mcp_session_id}")

sse_stream_writers = list(self._sse_stream_writers.values())
for writer in sse_stream_writers:
writer.close()
self._sse_stream_writers.clear()

# We need a copy of the keys to avoid modification during iteration
request_stream_keys = list(self._request_streams.keys())

Expand Down
5 changes: 5 additions & 0 deletions src/mcp/server/streamable_http_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,11 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]:
yield # Let the application run
finally:
logger.info("StreamableHTTP session manager shutting down")
for transport in list(self._server_instances.values()):
try:
await transport.terminate()
except Exception: # pragma: no cover
logger.debug("Error terminating StreamableHTTP transport during shutdown", exc_info=True)
# Cancel task group to stop all spawned tasks
tg.cancel_scope.cancel()
self._task_group = None
Expand Down
29 changes: 29 additions & 0 deletions tests/server/test_streamable_http_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,19 @@ async def try_run():
assert "StreamableHTTPSessionManager .run() can only be called once per instance" in str(errors[0])


@pytest.mark.anyio
async def test_run_terminates_active_transports_before_shutdown():
app = Server("test-server")
manager = StreamableHTTPSessionManager(app=app)
transport = AsyncMock()

async with manager.run():
manager._server_instances["session-id"] = transport

transport.terminate.assert_awaited_once_with()
assert not manager._server_instances


@pytest.mark.anyio
async def test_handle_request_without_run_raises_error():
"""Test that handle_request raises error if run() hasn't been called."""
Expand Down Expand Up @@ -269,6 +282,22 @@ async def mock_receive():
assert len(transport._request_streams) == 0, "Transport should have no active request streams"


@pytest.mark.anyio
async def test_transport_terminate_closes_active_sse_writers():
transport = StreamableHTTPServerTransport(mcp_session_id="session-id")
writer, reader = anyio.create_memory_object_stream[dict[str, str]](1)
transport._sse_stream_writers["request-id"] = writer

await transport.terminate()

assert transport.is_terminated
assert not transport._sse_stream_writers
with pytest.raises(anyio.ClosedResourceError):
writer.send_nowait({"event": "message", "data": "{}"})

await reader.aclose()


@pytest.mark.anyio
async def test_unknown_session_id_returns_404(caplog: pytest.LogCaptureFixture):
"""Test that requests with unknown session IDs return HTTP 404 per MCP spec."""
Expand Down
Loading