Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 99 additions & 86 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
from collections.abc import AsyncGenerator, Awaitable, Callable
from contextlib import asynccontextmanager
from dataclasses import dataclass
from functools import partial
from http import HTTPStatus
from typing import Any
from typing import Any, Final

import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
Expand Down Expand Up @@ -60,13 +61,20 @@
# Special key for the standalone GET stream
GET_STREAM_KEY = "_GET_stream"

# Buffer for the per-request `_request_streams` so the serial `message_router`
# can deposit a response and move on instead of head-of-line blocking the
# whole session on a lazily-started `sse_writer`. See #1764.
REQUEST_STREAM_BUFFER_SIZE: Final = 16

# Session ID validation pattern (visible ASCII characters ranging from 0x21 to 0x7E)
# Pattern ensures entire string contains only valid characters by using ^ and $ anchors
SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$")

# Type aliases
StreamId = str
EventId = str
# An SSE event-dict as accepted by sse-starlette (`event`, `data`, `id`, `retry`).
SSEEvent = dict[str, Any]


@dataclass
Expand Down Expand Up @@ -178,7 +186,7 @@
MemoryObjectReceiveStream[EventMessage],
],
] = {}
self._sse_stream_writers: dict[RequestId, MemoryObjectSendStream[dict[str, str]]] = {}
self._sse_stream_writers: dict[RequestId, MemoryObjectSendStream[SSEEvent]] = {}
self._terminated = False
# Idle timeout cancel scope; managed by the session manager.
self.idle_scope: anyio.CancelScope | None = None
Expand Down Expand Up @@ -267,31 +275,48 @@

return SessionMessage(message, metadata=metadata)

async def _maybe_send_priming_event(
self,
request_id: RequestId,
sse_stream_writer: MemoryObjectSendStream[dict[str, Any]],
protocol_version: str,
) -> None:
"""Send priming event for SSE resumability if event_store is configured.
async def _mint_priming_event(self, stream_id: StreamId, protocol_version: str) -> SSEEvent | None:
"""Store the priming cursor for `stream_id` and return its SSE wire form.

Only sends priming events to clients with protocol version >= 2025-11-25,
which includes the fix for handling empty SSE data. Older clients would
crash trying to parse empty data as JSON.
Called before the request is dispatched so the priming row precedes
anything `message_router` can store for this stream. Returns `None`
when no event store is configured or the client predates 2025-11-25
(older clients cannot parse the empty-data event).
"""
if not self._event_store:
return
# Priming events have empty data which older clients cannot handle.
return None
if protocol_version < "2025-11-25":
return
priming_event_id = await self._event_store.store_event(
str(request_id), # Convert RequestId to StreamId (str)
None, # Priming event has no payload
)
priming_event: dict[str, str | int] = {"id": priming_event_id, "data": ""}
return None
priming_event_id = await self._event_store.store_event(stream_id, None)
priming_event: SSEEvent = {"id": priming_event_id, "data": ""}
if self._retry_interval is not None:
priming_event["retry"] = self._retry_interval
await sse_stream_writer.send(priming_event)
return priming_event

async def _run_sse_writer( # pragma: no cover
self,
request_id: RequestId,
sse_stream_writer: MemoryObjectSendStream[SSEEvent],
request_stream_reader: MemoryObjectReceiveStream[EventMessage],
priming_event: SSEEvent | None,
) -> None:
"""Forward `_request_streams[request_id]` onto the SSE wire for one POST."""
try:
async with sse_stream_writer, request_stream_reader:
if priming_event is not None:
await sse_stream_writer.send(priming_event)
async for event_message in request_stream_reader:
await sse_stream_writer.send(self._create_event_data(event_message))
if isinstance(event_message.message.root, JSONRPCResponse | JSONRPCError):
break
except anyio.ClosedResourceError:
logger.debug("SSE stream closed by close_sse_stream()")
except Exception:
logger.exception("Error in SSE writer")
finally:
logger.debug("Closing SSE writer")
self._sse_stream_writers.pop(request_id, None)
await self._clean_up_memory_streams(request_id)

def _create_error_response(
self,
Expand Down Expand Up @@ -348,7 +373,7 @@
"""Extract the session ID from request headers."""
return request.headers.get(MCP_SESSION_ID_HEADER)

def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: # pragma: no cover
def _create_event_data(self, event_message: EventMessage) -> SSEEvent: # pragma: no cover
"""Create event data dictionary from an EventMessage."""
event_data = {
"event": "message",
Expand Down Expand Up @@ -530,13 +555,13 @@
else request.headers.get(MCP_PROTOCOL_VERSION_HEADER, DEFAULT_NEGOTIATED_VERSION)
)

# Extract the request ID outside the try block for proper scope
request_id = str(message.root.id) # pragma: no cover
# Register this stream for the request ID
self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](0) # pragma: no cover
request_stream_reader = self._request_streams[request_id][1] # pragma: no cover
request_id = str(message.root.id)

if self.is_json_response_enabled: # pragma: no cover
self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](
REQUEST_STREAM_BUFFER_SIZE
)
request_stream_reader = self._request_streams[request_id][1]
# Process the message
metadata = ServerMessageMetadata(request_context=request)
session_message = SessionMessage(message, metadata=metadata)
Expand Down Expand Up @@ -580,53 +605,30 @@
finally:
await self._clean_up_memory_streams(request_id)
else: # pragma: no cover
# Create SSE stream
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)
# Mint the priming event before any per-request state exists:
# `EventStore.store_event` is user code and may raise, in which
# case the outer handler returns a 500 with nothing to clean up.
# Still strictly precedes dispatch, so storage order == wire order.
priming_event = await self._mint_priming_event(request_id, protocol_version)

# Store writer reference so close_sse_stream() can close it
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[SSEEvent](0)
self._sse_stream_writers[request_id] = sse_stream_writer
self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](
REQUEST_STREAM_BUFFER_SIZE
)
request_stream_reader = self._request_streams[request_id][1]

async def sse_writer():
# Get the request ID from the incoming request message
try:
async with sse_stream_writer, request_stream_reader:
# Send priming event for SSE resumability
await self._maybe_send_priming_event(request_id, sse_stream_writer, protocol_version)

# Process messages from the request-specific stream
async for event_message in request_stream_reader:
# Build the event data
event_data = self._create_event_data(event_message)
await sse_stream_writer.send(event_data)

# If response, remove from pending streams and close
if isinstance(
event_message.message.root,
JSONRPCResponse | JSONRPCError,
):
break
except anyio.ClosedResourceError:
# Expected when close_sse_stream() is called
logger.debug("SSE stream closed by close_sse_stream()")
except Exception:
logger.exception("Error in SSE writer")
finally:
logger.debug("Closing SSE writer")
self._sse_stream_writers.pop(request_id, None)
await self._clean_up_memory_streams(request_id)

# Create and start EventSourceResponse
# SSE stream mode (original behavior)
# Set up headers
headers = {
"Cache-Control": "no-cache, no-transform",
"Connection": "keep-alive",

Check warning on line 623 in src/mcp/server/streamable_http.py

View check run for this annotation

Claude / Claude Code Review

_sse_stream_writers entry leaks when SSE setup fails before _run_sse_writer starts

If the SSE branch's "SSE response error" except path fires before sse-starlette ever invokes the data_sender_callable, the entry registered in `self._sse_stream_writers[request_id]` is never popped — only `_run_sse_writer`'s `finally` and `close_sse_stream()` remove it — leaving a stale closed writer in the per-session dict for the transport's lifetime. This gap also existed in the pre-PR code, but since this block is rewritten here, adding `self._sse_stream_writers.pop(request_id, None)` to the
Comment thread
maxisbey marked this conversation as resolved.
"Content-Type": CONTENT_TYPE_SSE,
**({MCP_SESSION_ID_HEADER: self.mcp_session_id} if self.mcp_session_id else {}),
}
response = EventSourceResponse(
content=sse_stream_reader,
data_sender_callable=sse_writer,
data_sender_callable=partial(
self._run_sse_writer, request_id, sse_stream_writer, request_stream_reader, priming_event
),
headers=headers,
)

Expand All @@ -644,16 +646,15 @@
await sse_stream_reader.aclose()
await self._clean_up_memory_streams(request_id)

except Exception as err: # pragma: no cover
except Exception as err:
logger.exception("Error handling POST request")
response = self._create_error_response(
f"Error handling POST request: {err}",
"Error handling POST request",
HTTPStatus.INTERNAL_SERVER_ERROR,
INTERNAL_ERROR,
)
await response(scope, receive, send)
if writer:
await writer.send(Exception(err))
await writer.send(Exception(err))
return

async def _handle_get_request(self, request: Request, send: Send) -> None: # pragma: no cover
Expand Down Expand Up @@ -706,13 +707,15 @@
return

# Create SSE stream
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[SSEEvent](0)

async def standalone_sse_writer():
try:
# Create a standalone message stream for server-initiated messages

self._request_streams[GET_STREAM_KEY] = anyio.create_memory_object_stream[EventMessage](0)
self._request_streams[GET_STREAM_KEY] = anyio.create_memory_object_stream[EventMessage](
REQUEST_STREAM_BUFFER_SIZE
)
standalone_stream_reader = self._request_streams[GET_STREAM_KEY][1]

async with sse_stream_writer, standalone_stream_reader:
Expand Down Expand Up @@ -903,7 +906,7 @@
replay_protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER, DEFAULT_NEGOTIATED_VERSION)

# Create SSE stream for replay
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[SSEEvent](0)

async def replay_sender():
try:
Expand All @@ -918,22 +921,32 @@

# If stream ID not in mapping, create it
if stream_id and stream_id not in self._request_streams:
# Register SSE writer so close_sse_stream() can close it
self._sse_stream_writers[stream_id] = sse_stream_writer

# Send priming event for this new connection
await self._maybe_send_priming_event(stream_id, sse_stream_writer, replay_protocol_version)

# Create new request streams for this connection
self._request_streams[stream_id] = anyio.create_memory_object_stream[EventMessage](0)
msg_reader = self._request_streams[stream_id][1]

# Forward messages to SSE
async with msg_reader:
async for event_message in msg_reader:
event_data = self._create_event_data(event_message)

await sse_stream_writer.send(event_data)
try:
# Register SSE writer so close_sse_stream() can close it
self._sse_stream_writers[stream_id] = sse_stream_writer

# Prime the resumed connection so the client sees the stream
# is re-registered. The replay→live-tail ordering window here
# is pre-existing and tracked separately.
priming_event = await self._mint_priming_event(stream_id, replay_protocol_version)
if priming_event is not None:
await sse_stream_writer.send(priming_event)

# Create new request streams for this connection
self._request_streams[stream_id] = anyio.create_memory_object_stream[EventMessage](
REQUEST_STREAM_BUFFER_SIZE
)
msg_reader = self._request_streams[stream_id][1]

# Forward messages to SSE
async with msg_reader:
async for event_message in msg_reader:
event_data = self._create_event_data(event_message)

await sse_stream_writer.send(event_data)
finally:
self._sse_stream_writers.pop(stream_id, None)
await self._clean_up_memory_streams(stream_id)
except anyio.ClosedResourceError:
# Expected when close_sse_stream() is called
logger.debug("Replay SSE stream closed by close_sse_stream()")
Expand Down
Loading
Loading