Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 173 additions & 0 deletions src/mcp/shared/direct_dispatcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
"""In-memory `Dispatcher` that wires two peers together with no transport.

`DirectDispatcher` is the simplest possible `Dispatcher` implementation: a
request on one side directly invokes the other side's `on_request`. There is no
serialization, no JSON-RPC framing, and no streams. It exists to:

* prove the `Dispatcher` Protocol is implementable without JSON-RPC
* provide a fast substrate for testing the layers above the dispatcher
(`ServerRunner`, `Context`, `Connection`) without wire-level moving parts
* embed a server in-process when the JSON-RPC overhead is unnecessary

Unlike `JSONRPCDispatcher`, exceptions raised in a handler propagate directly
to the caller — there is no exception-to-`ErrorData` boundary here.
"""

from __future__ import annotations

from collections.abc import Awaitable, Callable, Mapping
from dataclasses import dataclass, field
from typing import Any

import anyio

from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest, ProgressFnT
from mcp.shared.exceptions import MCPError, NoBackChannelError
from mcp.shared.transport_context import TransportContext
from mcp.types import INTERNAL_ERROR, REQUEST_TIMEOUT

__all__ = ["DirectDispatcher", "create_direct_dispatcher_pair"]

DIRECT_TRANSPORT_KIND = "direct"


_Request = Callable[[str, Mapping[str, Any] | None, CallOptions | None], Awaitable[dict[str, Any]]]
_Notify = Callable[[str, Mapping[str, Any] | None], Awaitable[None]]


@dataclass
class _DirectDispatchContext:
"""`DispatchContext` for an inbound request on a `DirectDispatcher`.

The back-channel callables target the *originating* side, so a handler's
`send_request` reaches the peer that made the inbound request.
"""

transport: TransportContext
_back_request: _Request
_back_notify: _Notify
_on_progress: ProgressFnT | None = None
cancel_requested: anyio.Event = field(default_factory=anyio.Event)

async def notify(self, method: str, params: Mapping[str, Any] | None) -> None:
await self._back_notify(method, params)

async def send_request(
self,
method: str,
params: Mapping[str, Any] | None,
opts: CallOptions | None = None,
) -> dict[str, Any]:
if not self.transport.can_send_request:
raise NoBackChannelError(method)
return await self._back_request(method, params, opts)

async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None:
if self._on_progress is not None:
await self._on_progress(progress, total, message)


class DirectDispatcher:
"""A `Dispatcher` that calls a peer's handlers directly, in-process.

Two instances are wired together with `create_direct_dispatcher_pair`; each
holds a reference to the other. `send_request` on one awaits the peer's
`on_request`. `run` parks until `close` is called.
"""

def __init__(self, transport_ctx: TransportContext):
self._transport_ctx = transport_ctx
self._peer: DirectDispatcher | None = None
self._on_request: OnRequest | None = None
self._on_notify: OnNotify | None = None
self._ready = anyio.Event()
self._closed = anyio.Event()

def connect_to(self, peer: DirectDispatcher) -> None:
self._peer = peer

async def send_request(
self,
method: str,
params: Mapping[str, Any] | None,
opts: CallOptions | None = None,
) -> dict[str, Any]:
if self._peer is None:
raise RuntimeError("DirectDispatcher has no peer; use create_direct_dispatcher_pair()")
return await self._peer._dispatch_request(method, params, opts)

async def notify(self, method: str, params: Mapping[str, Any] | None) -> None:
if self._peer is None:
raise RuntimeError("DirectDispatcher has no peer; use create_direct_dispatcher_pair()")
await self._peer._dispatch_notify(method, params)

async def run(self, on_request: OnRequest, on_notify: OnNotify) -> None:
self._on_request = on_request
self._on_notify = on_notify
self._ready.set()
await self._closed.wait()

def close(self) -> None:
self._closed.set()

def _make_context(self, on_progress: ProgressFnT | None = None) -> _DirectDispatchContext:
assert self._peer is not None
peer = self._peer
return _DirectDispatchContext(
transport=self._transport_ctx,
_back_request=lambda m, p, o: peer._dispatch_request(m, p, o),
_back_notify=lambda m, p: peer._dispatch_notify(m, p),
_on_progress=on_progress,
)

async def _dispatch_request(
self,
method: str,
params: Mapping[str, Any] | None,
opts: CallOptions | None,
) -> dict[str, Any]:
await self._ready.wait()
assert self._on_request is not None
opts = opts or {}
dctx = self._make_context(on_progress=opts.get("on_progress"))
try:
with anyio.fail_after(opts.get("timeout")):
try:
return await self._on_request(dctx, method, params)
except MCPError:
raise
except Exception as e:
raise MCPError(code=INTERNAL_ERROR, message=str(e)) from e
except TimeoutError:
raise MCPError(
code=REQUEST_TIMEOUT,
message=f"Timed out after {opts.get('timeout')}s waiting for {method!r}",
) from None

async def _dispatch_notify(self, method: str, params: Mapping[str, Any] | None) -> None:
await self._ready.wait()
assert self._on_notify is not None
dctx = self._make_context()
await self._on_notify(dctx, method, params)


def create_direct_dispatcher_pair(
*,
can_send_request: bool = True,
) -> tuple[DirectDispatcher, DirectDispatcher]:
"""Create two `DirectDispatcher` instances wired to each other.

Args:
can_send_request: Sets `TransportContext.can_send_request` on both
sides. Pass ``False`` to simulate a transport with no back-channel.

Returns:
A ``(left, right)`` pair. Conventionally ``left`` is the client side
and ``right`` is the server side, but the wiring is symmetric.
"""
ctx = TransportContext(kind=DIRECT_TRANSPORT_KIND, can_send_request=can_send_request)
left = DirectDispatcher(ctx)
right = DirectDispatcher(ctx)
left.connect_to(right)
right.connect_to(left)
return left, right
145 changes: 145 additions & 0 deletions src/mcp/shared/dispatcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
"""Dispatcher Protocol — the call/return boundary between transports and handlers.

A Dispatcher turns a duplex message channel into two things:

* an outbound API: ``send_request(method, params)`` and ``notify(method, params)``
* an inbound pump: ``run(on_request, on_notify)`` that drives the receive loop
and invokes the supplied handlers for each incoming request/notification

It is deliberately *not* MCP-aware. Method names are strings, params and
results are ``dict[str, Any]``. The MCP type layer (request/result models,
capability negotiation, ``Context``) sits above this; the wire encoding
(JSON-RPC, gRPC, in-process direct calls) sits below it.

See ``JSONRPCDispatcher`` for the production implementation and
``DirectDispatcher`` for an in-memory implementation used in tests and for
embedding a server in-process.
"""

from collections.abc import Awaitable, Callable, Mapping
from typing import Any, Protocol, TypedDict, TypeVar, runtime_checkable

import anyio

from mcp.shared.transport_context import TransportContext

__all__ = [
"CallOptions",
"DispatchContext",
"DispatchMiddleware",
"Dispatcher",
"OnNotify",
"OnRequest",
"Outbound",
"ProgressFnT",
]

TransportT_co = TypeVar("TransportT_co", bound=TransportContext, covariant=True)


class ProgressFnT(Protocol):
"""Callback invoked when a progress notification arrives for a pending request."""

async def __call__(self, progress: float, total: float | None, message: str | None) -> None: ...


class CallOptions(TypedDict, total=False):
"""Per-call options for `Outbound.send_request`.

All keys are optional. Dispatchers ignore keys they do not understand.
"""

timeout: float
"""Seconds to wait for a result before raising and sending ``notifications/cancelled``."""

on_progress: ProgressFnT
"""Receive ``notifications/progress`` updates for this request."""

resumption_token: str
"""Opaque token to resume a previously interrupted request (transport-dependent)."""

on_resumption_token: Callable[[str], Awaitable[None]]
"""Receive a resumption token when the transport issues one."""


@runtime_checkable
class Outbound(Protocol):
"""Anything that can send requests and notifications to the peer.

Both `Dispatcher` (top-level outbound) and `DispatchContext` (back-channel
during an inbound request) extend this. `PeerMixin` wraps an `Outbound` to
provide typed MCP request/notification methods.
"""

async def send_request(
self,
method: str,
params: Mapping[str, Any] | None,
opts: CallOptions | None = None,
) -> dict[str, Any]:
"""Send a request and await its result.

Raises:
MCPError: If the peer responded with an error, or the handler
raised. Implementations normalize all handler exceptions to
`MCPError` so callers see a single exception type.
"""
...

async def notify(self, method: str, params: Mapping[str, Any] | None) -> None:
"""Send a fire-and-forget notification."""
...


class DispatchContext(Outbound, Protocol[TransportT_co]):
"""Per-request context handed to ``on_request`` / ``on_notify``.

Carries the transport metadata for the inbound message and provides the
back-channel for sending requests/notifications to the peer while handling
it. `send_request` raises `NoBackChannelError` if
``transport.can_send_request`` is ``False``.
"""

@property
def transport(self) -> TransportT_co:
"""Transport-specific metadata for this inbound message."""
...

@property
def cancel_requested(self) -> anyio.Event:
"""Set when the peer sends ``notifications/cancelled`` for this request."""
...

async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None:
"""Report progress for the inbound request, if the peer supplied a progress token.

A no-op when no token was supplied.
"""
...


OnRequest = Callable[[DispatchContext[TransportContext], str, Mapping[str, Any] | None], Awaitable[dict[str, Any]]]
"""Handler for inbound requests: ``(ctx, method, params) -> result``. Raise ``MCPError`` to send an error response."""

OnNotify = Callable[[DispatchContext[TransportContext], str, Mapping[str, Any] | None], Awaitable[None]]
"""Handler for inbound notifications: ``(ctx, method, params)``."""

DispatchMiddleware = Callable[[OnRequest], OnRequest]
"""Wraps an ``OnRequest`` to produce another ``OnRequest``. Applied outermost-first."""


class Dispatcher(Outbound, Protocol[TransportT_co]):
"""A duplex request/notification channel with call-return semantics.

Implementations own correlation of outbound requests to inbound results, the
receive loop, per-request concurrency, and cancellation/progress wiring.
"""

async def run(self, on_request: OnRequest, on_notify: OnNotify) -> None:
"""Drive the receive loop until the underlying channel closes.

Each inbound request is dispatched to ``on_request`` in its own task;
the returned dict (or raised ``MCPError``) is sent back as the response.
Inbound notifications go to ``on_notify``.
"""
...
21 changes: 20 additions & 1 deletion src/mcp/shared/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Any, cast

from mcp.types import URL_ELICITATION_REQUIRED, ElicitRequestURLParams, ErrorData, JSONRPCError
from mcp.types import INVALID_REQUEST, URL_ELICITATION_REQUIRED, ElicitRequestURLParams, ErrorData, JSONRPCError


class MCPError(Exception):
Expand Down Expand Up @@ -41,6 +41,25 @@ def __str__(self) -> str:
return self.message


class NoBackChannelError(MCPError):
"""Raised when sending a server-initiated request over a transport that cannot deliver it.

Stateless HTTP and JSON-response-mode HTTP have no channel for the server to
push requests (sampling, elicitation, roots/list) to the client. This is
raised by `DispatchContext.send_request` when `transport.can_send_request`
is ``False``, and serializes to an ``INVALID_REQUEST`` error response.
"""

def __init__(self, method: str):
super().__init__(
code=INVALID_REQUEST,
message=(
f"Cannot send {method!r}: this transport context has no back-channel for server-initiated requests."
),
)
self.method = method


class StatelessModeNotSupported(RuntimeError):
"""Raised when attempting to use a method that is not supported in stateless mode.

Expand Down
30 changes: 30 additions & 0 deletions src/mcp/shared/transport_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Transport-specific metadata attached to each inbound message.

`TransportContext` is the base; each transport defines its own subclass with
whatever fields make sense (HTTP request id, ASGI scope, stdio process handle,
etc.). The dispatcher passes it through opaquely; only the layers above the
dispatcher (`ServerRunner`, `Context`, user handlers) read its concrete fields.
"""

from dataclasses import dataclass

__all__ = ["TransportContext"]


@dataclass(kw_only=True, frozen=True)
class TransportContext:
"""Base transport metadata for an inbound message.

Subclass per transport and add fields as needed. Instances are immutable.
"""

kind: str
"""Short identifier for the transport (e.g. ``"stdio"``, ``"streamable-http"``)."""

can_send_request: bool
"""Whether the transport can deliver server-initiated requests to the peer.

``False`` for stateless HTTP and HTTP with JSON response mode; ``True`` for
stdio, SSE, and stateful streamable HTTP. When ``False``,
`DispatchContext.send_request` raises `NoBackChannelError`.
"""
Loading
Loading