diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 35a83fcf1b..b0c8c8e6fc 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -122,11 +122,11 @@ async def respond(self, response: SendResultT | ErrorData) -> None: Must be called within a context manager block. Raises: RuntimeError: If not used within a context manager - AssertionError: If request was already responded to """ if not self._entered: # pragma: no cover raise RuntimeError("RequestResponder must be used as a context manager") - assert not self._completed, "Request already responded to" + if self._completed: + return if not self.cancelled: # pragma: no branch self._completed = True @@ -143,6 +143,9 @@ async def cancel(self) -> None: raise RuntimeError("No active cancel scope") self._cancel_scope.cancel() + if self._completed: + return + self._completed = True # Mark as completed so it's removed from in_flight # Send an error response to indicate cancellation await self._session._send_response( # type: ignore[reportPrivateUsage] diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index f4010141d8..df12c21b1f 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -1,5 +1,6 @@ from collections.abc import AsyncGenerator from typing import Any +from unittest.mock import AsyncMock, MagicMock import anyio import pytest @@ -10,11 +11,13 @@ from mcp.shared.exceptions import McpError from mcp.shared.memory import create_client_server_memory_streams, create_connected_server_and_client_session from mcp.shared.message import SessionMessage +from mcp.shared.session import RequestResponder from mcp.types import ( CancelledNotification, CancelledNotificationParams, ClientNotification, ClientRequest, + ClientResult, EmptyResult, ErrorData, JSONRPCError, @@ -30,6 +33,20 @@ def mcp_server() -> Server: return Server(name="test server") +def make_request_responder() -> tuple[RequestResponder[ClientRequest, ClientResult], MagicMock]: + mock_session = MagicMock() + mock_session._send_response = AsyncMock() + request = ClientRequest(types.PingRequest()) + responder: RequestResponder[ClientRequest, ClientResult] = RequestResponder( + request_id=1, + request_meta=None, + request=request, + session=mock_session, + on_complete=lambda responder: None, + ) + return responder, mock_session + + @pytest.fixture async def client_connected_to_server( mcp_server: Server, @@ -128,6 +145,33 @@ async def make_request(client_session: ClientSession): await ev_cancelled.wait() +@pytest.mark.anyio +async def test_request_responder_respond_after_cancel_does_not_raise(): + responder, mock_session = make_request_responder() + + with responder: + await responder.cancel() + await responder.respond(ClientResult(root=EmptyResult())) + + mock_session._send_response.assert_awaited_once() + assert mock_session._send_response.await_args.kwargs["response"] == ErrorData( + code=0, message="Request cancelled", data=None + ) + + +@pytest.mark.anyio +async def test_request_responder_cancel_after_respond_does_not_send_error(): + responder, mock_session = make_request_responder() + response = ClientResult(root=EmptyResult()) + + with responder: + await responder.respond(response) + await responder.cancel() + + mock_session._send_response.assert_awaited_once() + assert mock_session._send_response.await_args.kwargs["response"] == response + + @pytest.mark.anyio async def test_response_id_type_mismatch_string_to_int(): """