diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index f8aec6c9e..05c5b9e1f 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -774,10 +774,17 @@ async def terminate(self) -> None: Once terminated, all requests with this session ID will receive 404 Not Found. """ + if self._terminated: + return self._terminated = True logger.info(f"Terminating session: {self.mcp_session_id}") + sse_stream_writers = list(self._sse_stream_writers.values()) + for writer in sse_stream_writers: + writer.close() + self._sse_stream_writers.clear() + # We need a copy of the keys to avoid modification during iteration request_stream_keys = list(self._request_streams.keys()) diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 648dcc827..e9fbbc9d0 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -145,6 +145,11 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]: yield # Let the application run finally: logger.info("StreamableHTTP session manager shutting down") + for transport in list(self._server_instances.values()): + try: + await transport.terminate() + except Exception: # pragma: no cover + logger.debug("Error terminating StreamableHTTP transport during shutdown", exc_info=True) # Cancel task group to stop all spawned tasks tg.cancel_scope.cancel() self._task_group = None diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 0e8afed50..68662fa07 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -64,6 +64,19 @@ async def try_run(): assert "StreamableHTTPSessionManager .run() can only be called once per instance" in str(errors[0]) +@pytest.mark.anyio +async def test_run_terminates_active_transports_before_shutdown(): + app = Server("test-server") + manager = StreamableHTTPSessionManager(app=app) + transport = AsyncMock() + + async with manager.run(): + manager._server_instances["session-id"] = transport + + transport.terminate.assert_awaited_once_with() + assert not manager._server_instances + + @pytest.mark.anyio async def test_handle_request_without_run_raises_error(): """Test that handle_request raises error if run() hasn't been called.""" @@ -269,6 +282,22 @@ async def mock_receive(): assert len(transport._request_streams) == 0, "Transport should have no active request streams" +@pytest.mark.anyio +async def test_transport_terminate_closes_active_sse_writers(): + transport = StreamableHTTPServerTransport(mcp_session_id="session-id") + writer, reader = anyio.create_memory_object_stream[dict[str, str]](1) + transport._sse_stream_writers["request-id"] = writer + + await transport.terminate() + + assert transport.is_terminated + assert not transport._sse_stream_writers + with pytest.raises(anyio.ClosedResourceError): + writer.send_nowait({"event": "message", "data": "{}"}) + + await reader.aclose() + + @pytest.mark.anyio async def test_unknown_session_id_returns_404(caplog: pytest.LogCaptureFixture): """Test that requests with unknown session IDs return HTTP 404 per MCP spec."""