Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from mcp.shared.response_router import ResponseRouter
from mcp.types import (
CONNECTION_CLOSED,
INTERNAL_ERROR,
INVALID_PARAMS,
REQUEST_TIMEOUT,
CancelledNotification,
Expand Down Expand Up @@ -184,6 +185,7 @@ class BaseSession(
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
_progress_callbacks: dict[RequestId, ProgressFnT]
_response_routers: list[ResponseRouter]
_propagate_errors: dict[RequestId, BaseException]

def __init__(
self,
Expand All @@ -201,6 +203,7 @@ def __init__(
self._progress_callbacks = {}
self._response_routers = []
self._exit_stack = AsyncExitStack()
self._propagate_errors = {}

def add_response_router(self, router: ResponseRouter) -> None:
"""Register a response router to handle responses for non-standard requests.
Expand Down Expand Up @@ -295,6 +298,11 @@ async def send_request(
class_name = request.__class__.__name__
message = f"Timed out while waiting for response to {class_name}. Waited {timeout} seconds."
raise MCPError(code=REQUEST_TIMEOUT, message=message)
except anyio.EndOfStream:
propagate = self._propagate_errors.pop(request_id, None)
if propagate is not None:
raise propagate from None
raise

if isinstance(response_or_error, JSONRPCError):
raise MCPError.from_jsonrpc_error(response_or_error)
Expand Down Expand Up @@ -374,7 +382,20 @@ async def _handle_session_message(message: SessionMessage) -> None:

if not responder._completed: # type: ignore[reportPrivateUsage]
await self._handle_incoming(responder)
except Exception:
except Exception as e:
if getattr(e, "__mcp_propagate__", False):
error_response = JSONRPCError(
jsonrpc="2.0",
id=message.message.id,
error=ErrorData(code=INTERNAL_ERROR, message="Handler raised", data=""),
)
await self._write_stream.send(SessionMessage(message=error_response))
self._in_flight.pop(message.message.id, None)
for in_flight_id, stream in list(self._response_streams.items()):
self._propagate_errors[in_flight_id] = e
await stream.aclose()
return

# For request validation errors, send a proper JSON-RPC error
# response instead of crashing the server
logging.warning("Failed to validate request", exc_info=True)
Expand Down Expand Up @@ -451,7 +472,7 @@ async def _handle_session_message(message: SessionMessage) -> None:
try:
await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error))
await stream.aclose()
except Exception: # pragma: no cover
except Exception:
# Stream might already be closed
pass
self._response_streams.clear()
Expand Down
114 changes: 114 additions & 0 deletions tests/shared/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
from mcp import Client, types
from mcp.client.session import ClientSession
from mcp.server import Server, ServerRequestContext
from mcp.shared._context import RequestContext
from mcp.shared.exceptions import MCPError
from mcp.shared.memory import create_client_server_memory_streams
from mcp.shared.message import SessionMessage
from mcp.shared.session import RequestResponder
from mcp.types import (
INTERNAL_ERROR,
PARSE_ERROR,
CancelledNotification,
CancelledNotificationParams,
Expand Down Expand Up @@ -416,3 +418,115 @@ async def make_request(client_session: ClientSession):
# Pending request completed successfully
assert len(result_holder) == 1
assert isinstance(result_holder[0], EmptyResult)


@pytest.mark.anyio
async def test_callback_exception_propagation():
"""Verify that exceptions raised in callbacks with __mcp_propagate__ = True
are propagated to the awaiter of send_request, and result in INTERNAL_ERROR to peer.
"""

class CustomPropagatedException(Exception):
__mcp_propagate__ = True

ev_server_received_error = anyio.Event()
server_error_holder: list[JSONRPCError] = []

async with create_client_server_memory_streams() as (client_streams, server_streams):
client_read, client_write = client_streams
server_read, server_write = server_streams

async def mock_server():
# Wait for client's ping request
msg = await server_read.receive()
assert isinstance(msg, SessionMessage)
assert isinstance(msg.message, JSONRPCRequest)

# Trigger list_roots callback on client by sending roots/list request
roots_request = JSONRPCRequest(
jsonrpc="2.0",
id=1,
method="roots/list",
)
await server_write.send(SessionMessage(message=roots_request))

# Receive the client's response (which should be an error due to propagated exception)
response_msg = await server_read.receive()
assert isinstance(response_msg, SessionMessage)
assert isinstance(response_msg.message, JSONRPCError)
server_error_holder.append(response_msg.message)
ev_server_received_error.set()

async def mock_list_roots(context: RequestContext[ClientSession]):
raise CustomPropagatedException("Callback error that should propagate")

async def make_request(client_session: ClientSession):
# Send a ping request and assert that CustomPropagatedException propagates to it
with pytest.raises(CustomPropagatedException) as exc_info:
await client_session.send_ping()
assert "Callback error that should propagate" in str(exc_info.value)

async with (
anyio.create_task_group() as tg,
ClientSession(
read_stream=client_read,
write_stream=client_write,
list_roots_callback=mock_list_roots,
) as client_session,
):
tg.start_soon(mock_server)
tg.start_soon(make_request, client_session)

with anyio.fail_after(2): # pragma: no branch
await ev_server_received_error.wait()

assert len(server_error_holder) == 1
assert server_error_holder[0].error.code == INTERNAL_ERROR


@pytest.mark.anyio
async def test_send_request_end_of_stream_without_propagated_error():
"""Ensure EndOfStream is surfaced when no propagated error is present."""
async with create_client_server_memory_streams() as (client_streams, server_streams):
client_read, client_write = client_streams
server_read, _server_write = server_streams

async def mock_server(client_session: ClientSession):
message = await server_read.receive()
assert isinstance(message, SessionMessage)
assert isinstance(message.message, JSONRPCRequest)
response_stream = client_session._response_streams[message.message.id]
await response_stream.aclose()

async def make_request(client_session: ClientSession):
with pytest.raises(anyio.EndOfStream):
await client_session.send_ping()

async with (
anyio.create_task_group() as tg,
ClientSession(read_stream=client_read, write_stream=client_write) as client_session,
):
tg.start_soon(mock_server, client_session)
tg.start_soon(make_request, client_session)


@pytest.mark.anyio
async def test_receive_loop_handles_closed_response_stream():
"""Cover receive loop cleanup when a response stream is already closed."""
async with create_client_server_memory_streams() as (client_streams, server_streams):
client_read, client_write = client_streams
_server_read, server_write = server_streams

async with ClientSession(read_stream=client_read, write_stream=client_write) as client_session:
response_stream, response_stream_reader = anyio.create_memory_object_stream[JSONRPCResponse | JSONRPCError](
1
)
await response_stream.aclose()
await response_stream_reader.aclose()
client_session._response_streams[0] = response_stream

server_write.close()

with anyio.fail_after(2): # pragma: no branch
while client_session._response_streams:
await anyio.sleep(0)
Loading