diff --git a/.github/workflows/deploy-docs.yml b/.github/workflows/deploy-docs.yml new file mode 100644 index 0000000000..d9362afd57 --- /dev/null +++ b/.github/workflows/deploy-docs.yml @@ -0,0 +1,57 @@ +name: Deploy Docs + +on: + push: + branches: + - main + - v1.x + paths: + - docs/** + - mkdocs.yml + - src/mcp/** + - scripts/build-docs.sh + - pyproject.toml + - uv.lock + - .github/workflows/deploy-docs.yml + workflow_dispatch: + +concurrency: + group: deploy-docs + cancel-in-progress: false + +jobs: + deploy-docs: + runs-on: ubuntu-latest + + permissions: + contents: read + pages: write + id-token: write + + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + + steps: + - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 + + - name: Install uv + uses: astral-sh/setup-uv@803947b9bd8e9f986429fa0c5a41c367cd732b41 # v7.2.1 + with: + enable-cache: true + version: 0.9.5 + + - name: Build combined docs (v1.x at /, main at /v2/) + run: bash scripts/build-docs.sh site + + - name: Configure Pages + uses: actions/configure-pages@45bfe0192ca1faeb007ade9deae92b16b8254a0d # v6.0.0 + + - name: Upload Pages artifact + uses: actions/upload-pages-artifact@fc324d3547104276b827a68afc52ff2a11cc49c9 # v5.0.0 + with: + path: site + + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@cd2ce8fcbc39b97be8ca5fce6e763baed58fa128 # v5.0.0 diff --git a/.github/workflows/publish-docs-manually.yml b/.github/workflows/publish-docs-manually.yml deleted file mode 100644 index befe44d31c..0000000000 --- a/.github/workflows/publish-docs-manually.yml +++ /dev/null @@ -1,33 +0,0 @@ -name: Publish Docs manually - -on: - workflow_dispatch: - -jobs: - docs-publish: - runs-on: ubuntu-latest - permissions: - contents: write - steps: - - uses: actions/checkout@v4 - - name: Configure Git Credentials - run: | - git config user.name github-actions[bot] - git config user.email 41898282+github-actions[bot]@users.noreply.github.com 
- - - name: Install uv - uses: astral-sh/setup-uv@v3 - with: - enable-cache: true - version: 0.9.5 - - - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV - - uses: actions/cache@v4 - with: - key: mkdocs-material-${{ env.cache_id }} - path: .cache - restore-keys: | - mkdocs-material- - - - run: uv sync --frozen --group docs - - run: uv run --frozen --no-sync mkdocs gh-deploy --force diff --git a/.github/workflows/publish-pypi.yml b/.github/workflows/publish-pypi.yml index 59ede84172..085f82d833 100644 --- a/.github/workflows/publish-pypi.yml +++ b/.github/workflows/publish-pypi.yml @@ -51,32 +51,3 @@ jobs: - name: Publish package distributions to PyPI uses: pypa/gh-action-pypi-publish@release/v1 - - docs-publish: - runs-on: ubuntu-latest - needs: ["pypi-publish"] - permissions: - contents: write - steps: - - uses: actions/checkout@v4 - - name: Configure Git Credentials - run: | - git config user.name github-actions[bot] - git config user.email 41898282+github-actions[bot]@users.noreply.github.com - - - name: Install uv - uses: astral-sh/setup-uv@v3 - with: - enable-cache: true - version: 0.9.5 - - - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV - - uses: actions/cache@v4 - with: - key: mkdocs-material-${{ env.cache_id }} - path: .cache - restore-keys: | - mkdocs-material- - - - run: uv sync --frozen --group docs - - run: uv run --frozen --no-sync mkdocs gh-deploy --force diff --git a/.gitignore b/.gitignore index 2478cac4b3..348785e4e1 100644 --- a/.gitignore +++ b/.gitignore @@ -143,6 +143,7 @@ venv.bak/ # mkdocs documentation /site +/.worktrees/ # mypy .mypy_cache/ diff --git a/docs/experimental/tasks-server.md b/docs/experimental/tasks-server.md index 761dc5de5c..c6b94814cd 100644 --- a/docs/experimental/tasks-server.md +++ b/docs/experimental/tasks-server.md @@ -53,6 +53,29 @@ That's it. 
`enable_tasks()` automatically: - Registers handlers for `tasks/get`, `tasks/result`, `tasks/list`, `tasks/cancel` - Updates server capabilities +## Task Visibility + +Task IDs generated by `run_task()` embed an opaque marker identifying the session that +created the task, and the default handlers use it to restrict each session to its own +tasks: `tasks/get`, `tasks/result`, and `tasks/cancel` respond with "task not found" for +another session's task, and `tasks/list` returns only the requesting session's tasks. A +client that reconnects gets a new session and can no longer reach tasks it created on the +previous one. + +A task ID has no session marker when it was passed to `run_task()` explicitly, when the +task was created directly through the `TaskStore`, or when the server runs in stateless +mode (each request gets a fresh session, so tasks must remain reachable across requests). +Such tasks are accessible to any requestor that presents the exact task ID, and are never +included in `tasks/list` responses because the server cannot tell which session they +belong to. Treat these task IDs as capabilities: generate them with enough entropy that +they cannot be guessed, share them only with the intended recipient, and prefer short +TTLs. Passing an explicit `task_id` to `run_task()` is deprecated for this reason. + +To scope tasks to something other than the session — for example a user identity from your +authorization layer — register your own handlers with `@server.experimental.get_task()`, +`@server.experimental.get_task_result()`, `@server.experimental.list_tasks()`, and +`@server.experimental.cancel_task()` instead of relying on the defaults. + ## Tool Declaration Tools declare task support via the `execution.taskSupport` field: diff --git a/docs/index.md b/docs/index.md index 061a2f5bcf..48f3ace1ce 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,5 +1,8 @@ # MCP Python SDK +!!! tip "Looking for the upcoming v2?" 
+ See the [v2 development documentation](https://py.sdk.modelcontextprotocol.io/v2/). + The **Model Context Protocol (MCP)** allows applications to provide context for LLMs in a standardized way, separating the concerns of providing context from the actual LLM interaction. This Python SDK implements the full MCP specification, making it easy to: diff --git a/examples/servers/simple-auth/mcp_simple_auth/auth_server.py b/examples/servers/simple-auth/mcp_simple_auth/auth_server.py index 80a2e8b8a3..aa8306ba2a 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/auth_server.py +++ b/examples/servers/simple-auth/mcp_simple_auth/auth_server.py @@ -123,6 +123,8 @@ async def introspect_handler(request: Request) -> Response: "iat": int(time.time()), "token_type": "Bearer", "aud": access_token.resource, # RFC 8707 audience claim + "sub": access_token.subject, # RFC 7662 subject + "iss": str(server_settings.server_url), } ) diff --git a/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py b/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py index e3a25d3e8c..fc1ef1df94 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py +++ b/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py @@ -186,6 +186,7 @@ async def handle_simple_callback(self, username: str, password: str, state: str) scopes=[self.settings.mcp_scope], code_challenge=code_challenge, resource=resource, # RFC 8707 + subject=username, ) self.auth_codes[new_code] = auth_code @@ -224,6 +225,7 @@ async def exchange_authorization_code( scopes=authorization_code.scopes, expires_at=int(time.time()) + 3600, resource=authorization_code.resource, # RFC 8707 + subject=authorization_code.subject, ) # Store user data mapping for this token diff --git a/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py b/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py index 5228d034e4..641095a125 100644 --- 
a/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py +++ b/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py @@ -75,6 +75,8 @@ async def verify_token(self, token: str) -> AccessToken | None: scopes=data.get("scope", "").split() if data.get("scope") else [], expires_at=data.get("exp"), resource=data.get("aud"), # Include resource in token + subject=data.get("sub"), # RFC 7662 subject (resource owner) + claims=data, ) except Exception as e: logger.warning(f"Token introspection failed: {e}") diff --git a/mkdocs.yml b/mkdocs.yml index 6f327d006b..7245f239b4 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -5,7 +5,7 @@ strict: true repo_name: modelcontextprotocol/python-sdk repo_url: https://github.com/modelcontextprotocol/python-sdk edit_uri: edit/v1.x/docs/ -site_url: https://modelcontextprotocol.github.io/python-sdk +site_url: https://py.sdk.modelcontextprotocol.io/ # TODO(Marcelo): Add Anthropic copyright? # copyright: © Model Context Protocol 2025 to present diff --git a/pyproject.toml b/pyproject.toml index 9b28741aeb..6c88c8e789 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,6 +85,7 @@ bump = true [project.urls] Homepage = "https://modelcontextprotocol.io" +Documentation = "https://py.sdk.modelcontextprotocol.io/" Repository = "https://github.com/modelcontextprotocol/python-sdk" Issues = "https://github.com/modelcontextprotocol/python-sdk/issues" diff --git a/scripts/build-docs.sh b/scripts/build-docs.sh new file mode 100755 index 0000000000..5a61309acf --- /dev/null +++ b/scripts/build-docs.sh @@ -0,0 +1,54 @@ +#!/usr/bin/env bash +# +# Build combined v1 + v2 MkDocs documentation for GitHub Pages. +# +# v1 docs (from the v1.x branch) are placed at the site root. +# v2 docs (from main) are placed under /v2/. +# +# Both branches are fetched fresh from origin, so the output is identical +# regardless of which branch triggered the workflow. 
This script is intended +# to run in CI; for local single-branch preview use `uv run mkdocs serve`. +# +# Usage: +# scripts/build-docs.sh [output-dir] +# +# Default output directory: site +# +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" +OUTPUT_DIR="$(cd "$REPO_ROOT" && mkdir -p "${1:-site}" && cd "${1:-site}" && pwd)" +V1_WORKTREE="$REPO_ROOT/.worktrees/v1-docs" +V2_WORKTREE="$REPO_ROOT/.worktrees/v2-docs" + +cleanup() { + cd "$REPO_ROOT" + git worktree remove --force "$V1_WORKTREE" 2>/dev/null || true + git worktree remove --force "$V2_WORKTREE" 2>/dev/null || true + rmdir "$REPO_ROOT/.worktrees" 2>/dev/null || true +} +trap cleanup EXIT + +rm -rf "${OUTPUT_DIR:?}"/* + +build_branch() { + local branch="$1" worktree="$2" dest="$3" + + echo "=== Building docs for ${branch} ===" + git fetch origin "$branch" + git worktree remove --force "$worktree" 2>/dev/null || true + rm -rf "$worktree" + git worktree add --detach "$worktree" "origin/${branch}" + + ( + cd "$worktree" + uv sync --frozen --group docs + uv run --frozen --no-sync mkdocs build --site-dir "$dest" + ) +} + +build_branch v1.x "$V1_WORKTREE" "$OUTPUT_DIR" +build_branch main "$V2_WORKTREE" "$OUTPUT_DIR/v2" + +echo "=== Combined docs built at $OUTPUT_DIR ===" diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index 64c9b8841f..300b298924 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -1,6 +1,6 @@ import json import time -from typing import Any +from typing import Any, TypedDict from pydantic import AnyHttpUrl from starlette.authentication import AuthCredentials, AuthenticationBackend, SimpleUser @@ -19,6 +19,30 @@ def __init__(self, auth_info: AccessToken): self.scopes = auth_info.scopes +class AuthorizationContext(TypedDict): + client_id: str + issuer: str | None + subject: str | None + + +def 
authorization_context(user: AuthenticatedUser) -> AuthorizationContext: + """Identify the principal `user` represents, for transports to compare + against the principal that created a session. Components the token + verifier does not supply are `None`, so the comparison degrades to the + remaining components. + + See `examples/servers/simple-auth/mcp_simple_auth/token_verifier.py` for + a verifier that populates `subject` and `claims` from an introspection + response.""" + token = user.access_token + issuer = (token.claims or {}).get("iss") + return AuthorizationContext( + client_id=token.client_id, + issuer=str(issuer) if issuer is not None else None, + subject=token.subject, + ) + + class BearerAuthBackend(AuthenticationBackend): """ Authentication backend that validates Bearer tokens using a TokenVerifier. diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index 96296c148e..310baff5fd 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Generic, Literal, Protocol, TypeVar +from typing import Any, Generic, Literal, Protocol, TypeVar from urllib.parse import parse_qs, urlencode, urlparse, urlunparse from pydantic import AnyUrl, BaseModel @@ -25,6 +25,7 @@ class AuthorizationCode(BaseModel): redirect_uri: AnyUrl redirect_uri_provided_explicitly: bool resource: str | None = None # RFC 8707 resource indicator + subject: str | None = None # resource owner; propagate to the issued AccessToken class RefreshToken(BaseModel): @@ -32,6 +33,7 @@ class RefreshToken(BaseModel): client_id: str scopes: list[str] expires_at: int | None = None + subject: str | None = None # resource owner; propagate to refreshed AccessTokens class AccessToken(BaseModel): @@ -40,6 +42,8 @@ class AccessToken(BaseModel): scopes: list[str] expires_at: int | None = None resource: str | None = None # RFC 8707 resource indicator + subject: str | None = None # RFC 7662/9068 
`sub`: resource owner; unique only per issuer + claims: dict[str, Any] | None = None # additional claims (e.g. `iss`, `act`) RegistrationErrorCode = Literal[ diff --git a/src/mcp/server/experimental/__init__.py b/src/mcp/server/experimental/__init__.py index 824bb8b8be..91c6dcf3e8 100644 --- a/src/mcp/server/experimental/__init__.py +++ b/src/mcp/server/experimental/__init__.py @@ -8,4 +8,5 @@ - mcp.server.experimental.task_support.TaskSupport - mcp.server.experimental.task_result_handler.TaskResultHandler - mcp.server.experimental.request_context.Experimental +- mcp.server.experimental.task_scope (session scoping of task IDs) """ diff --git a/src/mcp/server/experimental/request_context.py b/src/mcp/server/experimental/request_context.py index 78e75beb6a..0d69836355 100644 --- a/src/mcp/server/experimental/request_context.py +++ b/src/mcp/server/experimental/request_context.py @@ -7,11 +7,15 @@ WARNING: These APIs are experimental and may change without notice. """ +import warnings from collections.abc import Awaitable, Callable from dataclasses import dataclass, field -from typing import Any +from typing import Any, overload + +from typing_extensions import deprecated from mcp.server.experimental.task_context import ServerTaskContext +from mcp.server.experimental.task_scope import scoped_task_id from mcp.server.experimental.task_support import TaskSupport from mcp.server.session import ServerSession from mcp.shared.exceptions import McpError @@ -29,6 +33,14 @@ Tool, ) +EXPLICIT_TASK_ID_DEPRECATION = ( + "Passing an explicit task_id to run_task is deprecated. A task created with an " + "explicit ID is not associated with the session that created it: any requestor " + "that presents the ID can read its status and result or cancel it, and it never " + "appears in tasks/list. Omit task_id to let the SDK generate an ID associated " + "with the creating session." 
+) + @dataclass class Experimental: @@ -143,6 +155,25 @@ def can_use_tool(self, tool_task_mode: TaskExecutionMode | None) -> bool: return False return True + @overload + async def run_task( + self, + work: Callable[[ServerTaskContext], Awaitable[Result]], + *, + task_id: None = None, + model_immediate_response: str | None = None, + ) -> CreateTaskResult: ... + + @overload + @deprecated(EXPLICIT_TASK_ID_DEPRECATION) + async def run_task( + self, + work: Callable[[ServerTaskContext], Awaitable[Result]], + *, + task_id: str, + model_immediate_response: str | None = None, + ) -> CreateTaskResult: ... + async def run_task( self, work: Callable[[ServerTaskContext], Awaitable[Result]], @@ -167,9 +198,17 @@ async def run_task( When work() returns a Result, the task is auto-completed with that result. If work() raises an exception, the task is auto-failed. + Generated task IDs embed the session's task scope so that the default + task handlers only serve the task to the session that created it. An + explicitly provided `task_id` is used verbatim and is not associated + with the session, so any session can access it through the default + handlers; passing one is deprecated for that reason. + Args: work: Async function that does the actual work - task_id: Optional task ID (generated if not provided) + task_id: Deprecated. Optional task ID, used verbatim and not + associated with the creating session. Omit it to let the SDK + generate one. model_immediate_response: Optional string to include in _meta as io.modelcontextprotocol/model-immediate-response @@ -196,6 +235,8 @@ async def work(task: ServerTaskContext) -> CallToolResult: WARNING: This API is experimental and may change without notice. """ + if task_id is not None: + warnings.warn(EXPLICIT_TASK_ID_DEPRECATION, DeprecationWarning, stacklevel=2) if self._task_support is None: raise RuntimeError("Task support not enabled. 
Call server.experimental.enable_tasks() first.") if self._session is None: @@ -210,6 +251,11 @@ async def work(task: ServerTaskContext) -> CallToolResult: # Access task_group via TaskSupport - raises if not in run() context task_group = support.task_group + if task_id is None: + session_scope = self._session.experimental.task_session_scope + if session_scope is not None: + task_id = scoped_task_id(session_scope) + task = await support.store.create_task(self.task_metadata, task_id) task_ctx = ServerTaskContext( diff --git a/src/mcp/server/experimental/session_features.py b/src/mcp/server/experimental/session_features.py index 4842da5175..c118537fa2 100644 --- a/src/mcp/server/experimental/session_features.py +++ b/src/mcp/server/experimental/session_features.py @@ -40,6 +40,12 @@ class ExperimentalServerSessionFeatures: def __init__(self, session: "ServerSession") -> None: self._session = session + # Opaque marker identifying this session for task scoping. Assigned by + # TaskSupport.configure_session(). Task IDs generated by run_task() + # embed it so the default task handlers can restrict task access to + # the session that created the task. None means tasks created on this + # session are not associated with it (e.g. stateless servers). + self.task_session_scope: str | None = None async def get_task(self, task_id: str) -> types.GetTaskResult: """ diff --git a/src/mcp/server/experimental/task_result_handler.py b/src/mcp/server/experimental/task_result_handler.py index 0b869216e8..1cf7f69749 100644 --- a/src/mcp/server/experimental/task_result_handler.py +++ b/src/mcp/server/experimental/task_result_handler.py @@ -46,6 +46,11 @@ class TaskResultHandler: 4. Blocks until task reaches terminal state 5. Returns the final result + Prefer `server.experimental.enable_tasks()`, whose default tasks/result + handler wraps `handle()` and only serves tasks created by the requesting + session. 
A custom handler that calls `handle()` directly is responsible + for deciding which requestors may access which tasks. + Usage: # Create handler with store and queue handler = TaskResultHandler(task_store, message_queue) @@ -55,9 +60,6 @@ class TaskResultHandler: async def handle_task_result(req: GetTaskPayloadRequest) -> GetTaskPayloadResult: ctx = server.request_context return await handler.handle(req, ctx.session, ctx.request_id) - - # Or use the convenience method - handler.register(server) """ def __init__( diff --git a/src/mcp/server/experimental/task_scope.py b/src/mcp/server/experimental/task_scope.py new file mode 100644 index 0000000000..c33cf55725 --- /dev/null +++ b/src/mcp/server/experimental/task_scope.py @@ -0,0 +1,75 @@ +""" +Session scoping for experimental task identifiers. + +Task IDs generated by `run_task()` embed an opaque, per-session marker (the +"session scope") so that the default task handlers can tell which session +created a task. The default handlers for tasks/get, tasks/result, tasks/list, +and tasks/cancel only operate on tasks created by the requesting session. + +Task IDs without a session scope (explicitly provided IDs, IDs created +directly through a TaskStore, or IDs created in stateless mode) have no known +creator. They can be used with tasks/get, tasks/result, and tasks/cancel from +any session - possession of the ID is what grants access - but they are never +included in tasks/list responses. + +WARNING: These APIs are experimental and may change without notice. +""" + +import re +from uuid import uuid4 + +__all__ = [ + "new_session_scope", + "scoped_task_id", + "session_scope_of", + "task_in_session_scope", + "task_listable_in_session_scope", +] + +# A scoped task ID has the form "<32 hex chars>:<uuid4>". Both halves must +# match exactly so that explicitly chosen task IDs are never mistaken for +# scoped ones. \Z rather than $ so a trailing newline cannot match. 
+_SCOPED_TASK_ID = re.compile( + r"\A(?P<scope>[0-9a-f]{32}):" + r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\Z" +) + + +def new_session_scope() -> str: + """Create a new opaque session scope token.""" + return uuid4().hex + + +def scoped_task_id(session_scope: str) -> str: + """Generate a task ID associated with the given session scope.""" + return f"{session_scope}:{uuid4()}" + + +def session_scope_of(task_id: str) -> str | None: + """Return the session scope embedded in a task ID, or None if it has none.""" + match = _SCOPED_TASK_ID.match(task_id) + return match.group("scope") if match else None + + +def task_in_session_scope(task_id: str, session_scope: str | None) -> bool: + """Whether a task may be used by a requestor with the given session scope. + + Used by tasks/get, tasks/result, and tasks/cancel. A task whose ID carries + no session scope has no known creator, so possession of the ID is what + grants access to it: it can be used from any session. + """ + embedded = session_scope_of(task_id) + return embedded is None or embedded == session_scope + + +def task_listable_in_session_scope(task_id: str, session_scope: str | None) -> bool: + """Whether a task may be included in a tasks/list response for the given session scope. + + Used by tasks/list. Listing is stricter than access by ID: a task is only + listed to the session that created it. Tasks with no session scope are + never listed because they have no known creator, and requestors with no + session scope are never shown any tasks because the server cannot tell + them apart. 
+ """ + embedded = session_scope_of(task_id) + return embedded is not None and embedded == session_scope diff --git a/src/mcp/server/experimental/task_support.py b/src/mcp/server/experimental/task_support.py index dbb2ed6d2b..8e91faf73b 100644 --- a/src/mcp/server/experimental/task_support.py +++ b/src/mcp/server/experimental/task_support.py @@ -13,6 +13,7 @@ from anyio.abc import TaskGroup from mcp.server.experimental.task_result_handler import TaskResultHandler +from mcp.server.experimental.task_scope import new_session_scope from mcp.server.session import ServerSession from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue, TaskMessageQueue @@ -83,7 +84,7 @@ async def run(self) -> AsyncIterator[None]: finally: self._task_group = None - def configure_session(self, session: ServerSession) -> None: + def configure_session(self, session: ServerSession, *, stateless: bool = False) -> None: """ Configure a session for task support. @@ -91,12 +92,22 @@ def configure_session(self, session: ServerSession) -> None: responses to queued requests (elicitation, sampling) are routed back to the waiting resolvers. + It also assigns the session a task session scope. Task IDs generated + by `run_task()` embed this scope, and the default task handlers only + operate on tasks created by the requesting session. Stateless sessions + are not assigned a scope: each request runs on a fresh session, so a + task created by one request could never be retrieved by a later one if + tasks were bound to the session that created them. + Called automatically by Server.run() for each new session. 
Args: session: The session to configure + stateless: Whether the session belongs to a stateless server run """ session.add_response_router(self.handler) + if not stateless and session.experimental.task_session_scope is None: + session.experimental.task_session_scope = new_session_scope() @classmethod def in_memory(cls) -> "TaskSupport": diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 7a43bd7cf0..8f62ce2e54 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -1282,9 +1282,14 @@ async def log( related_request_id=self.request_id, ) + # TODO(maxisbey): see if this is needed otherwise remove @property def client_id(self) -> str | None: - """Get the client ID if available.""" + """Get the client ID if available. + + Note: this reads from the MCP request's `_meta` params, not the OAuth + bearer token. For that, use `get_access_token().client_id`. + """ return ( getattr(self.request_context.meta, "client_id", None) if self.request_context.meta else None ) # pragma: no cover diff --git a/src/mcp/server/lowlevel/experimental.py b/src/mcp/server/lowlevel/experimental.py index 42353e4ea0..737d6bb2cd 100644 --- a/src/mcp/server/lowlevel/experimental.py +++ b/src/mcp/server/lowlevel/experimental.py @@ -9,6 +9,7 @@ from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING +from mcp.server.experimental.task_scope import task_in_session_scope, task_listable_in_session_scope from mcp.server.experimental.task_support import TaskSupport from mcp.server.lowlevel.func_inspection import create_call_wrapper from mcp.shared.exceptions import McpError @@ -31,6 +32,7 @@ ServerResult, ServerTasksCapability, ServerTasksRequestsCapability, + Task, TasksCallCapability, TasksCancelCapability, TasksListCapability, @@ -125,8 +127,38 @@ def enable_tasks( return self._task_support + def _requestor_session_scope(self) -> str | None: + """Return the task session scope of the session making the current 
request.""" + return self._server.request_context.session.experimental.task_session_scope + + def _require_task_in_requestor_scope(self, task_id: str) -> None: + """Reject task IDs that belong to a different session. + + Task IDs generated by `run_task()` embed the creating session's + scope. The default handlers treat a task created by another session + exactly like a task that does not exist, so a requestor cannot tell + whether such a task exists. Task IDs without an embedded scope are + accepted from any session. + + Raises: + McpError: With INVALID_PARAMS if the task belongs to another session. + """ + if not task_in_session_scope(task_id, self._requestor_session_scope()): + raise McpError( + ErrorData( + code=INVALID_PARAMS, + message=f"Task not found: {task_id}", + ) + ) + def _register_default_task_handlers(self) -> None: - """Register default handlers for task operations.""" + """Register default handlers for task operations. + + Each default handler only operates on tasks created by the requesting + session (see `_require_task_in_requestor_scope`), and tasks/list only + returns the requesting session's own tasks (see + `task_listable_in_session_scope`). 
+ """ assert self._task_support is not None support = self._task_support @@ -134,6 +166,7 @@ def _register_default_task_handlers(self) -> None: if GetTaskRequest not in self._request_handlers: async def _default_get_task(req: GetTaskRequest) -> ServerResult: + self._require_task_in_requestor_scope(req.params.taskId) task = await support.store.get_task(req.params.taskId) if task is None: raise McpError( @@ -160,6 +193,7 @@ async def _default_get_task(req: GetTaskRequest) -> ServerResult: if GetTaskPayloadRequest not in self._request_handlers: async def _default_get_task_result(req: GetTaskPayloadRequest) -> ServerResult: + self._require_task_in_requestor_scope(req.params.taskId) ctx = self._server.request_context result = await support.handler.handle(req, ctx.session, ctx.request_id) return ServerResult(result) @@ -170,9 +204,26 @@ async def _default_get_task_result(req: GetTaskPayloadRequest) -> ServerResult: if ListTasksRequest not in self._request_handlers: async def _default_list_tasks(req: ListTasksRequest) -> ServerResult: - cursor = req.params.cursor if req.params else None - tasks, next_cursor = await support.store.list_tasks(cursor) - return ServerResult(ListTasksResult(tasks=tasks, nextCursor=next_cursor)) + requestor_scope = self._requestor_session_scope() + if requestor_scope is None: + # The server cannot tell this requestor apart from any + # other, so there are no tasks it can be shown. + return ServerResult(ListTasksResult(tasks=[])) + # Return every task that belongs to the requesting session in + # a single page. The store's pagination cursor is never sent + # to the requestor: it is derived from the unfiltered listing, + # so it could identify a task belonging to a different + # session. For the same reason the request's cursor is not + # forwarded to the store. 
+ own_tasks: list[Task] = [] + cursor: str | None = None + while True: + page, cursor = await support.store.list_tasks(cursor) + own_tasks.extend( + task for task in page if task_listable_in_session_scope(task.taskId, requestor_scope) + ) + if cursor is None: + return ServerResult(ListTasksResult(tasks=own_tasks)) self._request_handlers[ListTasksRequest] = _default_list_tasks @@ -180,6 +231,7 @@ async def _default_list_tasks(req: ListTasksRequest) -> ServerResult: if CancelTaskRequest not in self._request_handlers: async def _default_cancel_task(req: CancelTaskRequest) -> ServerResult: + self._require_task_in_requestor_scope(req.params.taskId) result = await cancel_task(support.store, req.params.taskId) return ServerResult(result) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 2dd1a8277a..7d925de32b 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -667,7 +667,7 @@ async def run( # Configure task support for this session if enabled task_support = self._experimental_handlers.task_support if self._experimental_handlers else None if task_support is not None: - task_support.configure_session(session) + task_support.configure_session(session, stateless=stateless) await stack.enter_async_context(task_support.run()) async with anyio.create_task_group() as tg: diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 19af93fd16..489785c4c9 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -52,6 +52,7 @@ async def handle_sse(request): from starlette.types import Receive, Scope, Send import mcp.types as types +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, AuthorizationContext, authorization_context from mcp.server.transport_security import ( TransportSecurityMiddleware, TransportSecuritySettings, @@ -75,6 +76,9 @@ class SseServerTransport: _endpoint: str _read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]] + # 
Identity of the credential that created each session; requests for a + # session must present the same credential. + _session_owners: dict[UUID, AuthorizationContext] _security: TransportSecurityMiddleware def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | None = None) -> None: @@ -115,6 +119,7 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | self._endpoint = endpoint self._read_stream_writers = {} + self._session_owners = {} self._security = TransportSecurityMiddleware(security_settings) logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}") @@ -142,6 +147,9 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): # prag write_stream, write_stream_reader = anyio.create_memory_object_stream(0) session_id = uuid4() + user = scope.get("user") + if isinstance(user, AuthenticatedUser): + self._session_owners[session_id] = authorization_context(user) self._read_stream_writers[session_id] = read_stream_writer logger.debug(f"Created new session with ID: {session_id}") @@ -177,26 +185,34 @@ async def sse_writer(): } ) - async with anyio.create_task_group() as tg: - - async def response_wrapper(scope: Scope, receive: Receive, send: Send): - """ - The EventSourceResponse returning signals a client close / disconnect. - In this case we close our side of the streams to signal the client that - the connection has been closed. 
- """ - await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)( - scope, receive, send - ) - await read_stream_writer.aclose() - await write_stream_reader.aclose() - logging.debug(f"Client session disconnected {session_id}") - - logger.debug("Starting SSE response task") - tg.start_soon(response_wrapper, scope, receive, send) - - logger.debug("Yielding read and write streams") - yield (read_stream, write_stream) + try: + async with anyio.create_task_group() as tg: + + async def response_wrapper(scope: Scope, receive: Receive, send: Send): + """ + The EventSourceResponse returning signals a client close / disconnect. + In this case we close our side of the streams to signal the client that + the connection has been closed. + """ + await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)( + scope, receive, send + ) + await read_stream_writer.aclose() + await write_stream_reader.aclose() + await sse_stream_reader.aclose() + logging.debug(f"Client session disconnected {session_id}") + + logger.debug("Starting SSE response task") + tg.start_soon(response_wrapper, scope, receive, send) + + logger.debug("Yielding read and write streams") + yield (read_stream, write_stream) + finally: + # The connection is gone: stop routing messages to this session + # and drop its entries so they do not accumulate for the lifetime + # of the transport. 
+ self._read_stream_writers.pop(session_id, None) + self._session_owners.pop(session_id, None) async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: # pragma: no cover logger.debug("Handling POST message") @@ -227,6 +243,15 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) response = Response("Could not find session", status_code=404) return await response(scope, receive, send) + user = scope.get("user") + requestor = authorization_context(user) if isinstance(user, AuthenticatedUser) else None + if requestor != self._session_owners.get(session_id): + # A session can only be used with the credential that created it. + # Respond exactly as if the session did not exist. + logger.warning("Rejecting message for session %s: credential does not match", session_id) + response = Response("Could not find session", status_code=404) + return await response(scope, receive, send) + body = await request.body() logger.debug(f"Received JSON: {body}") diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 8a7b765e86..1a1a85721d 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -5,7 +5,6 @@ import contextlib import logging from collections.abc import AsyncIterator -from http import HTTPStatus from typing import Any from uuid import uuid4 @@ -15,6 +14,7 @@ from starlette.responses import Response from starlette.types import Receive, Scope, Send +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, AuthorizationContext, authorization_context from mcp.server.lowlevel.server import Server as MCPServer from mcp.server.streamable_http import ( MCP_SESSION_ID_HEADER, @@ -88,6 +88,9 @@ def __init__( # Session tracking (only used if not stateless) self._session_creation_lock = anyio.Lock() self._server_instances: dict[str, StreamableHTTPServerTransport] = {} + # Identity of the credential that created 
each session; requests for a + # session must present the same credential. + self._session_owners: dict[str, AuthorizationContext] = {} # The task group will be set during lifespan self._task_group = None @@ -135,6 +138,7 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]: self._task_group = None # Clear any remaining server instances self._server_instances.clear() + self._session_owners.clear() async def handle_request( self, @@ -227,12 +231,32 @@ async def _handle_stateful_request( request = Request(scope, receive) request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) + user = scope.get("user") + requestor = authorization_context(user) if isinstance(user, AuthenticatedUser) else None + # Existing session case - if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances: # pragma: no cover + if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances: transport = self._server_instances[request_mcp_session_id] + if requestor != self._session_owners.get(request_mcp_session_id): + # A session can only be used with the credential that created + # it. Respond exactly as if the session did not exist. 
+ logger.warning( + "Rejecting request for session %s: credential does not match the one that created the session", + request_mcp_session_id[:64], + ) + body = JSONRPCError( + jsonrpc="2.0", id="server-error", error=ErrorData(code=INVALID_REQUEST, message="Session not found") + ) + response = Response( + body.model_dump_json(by_alias=True, exclude_none=True), + status_code=404, + media_type="application/json", + ) + await response(scope, receive, send) + return logger.debug("Session already exists, handling request directly") # Push back idle deadline on activity - if transport.idle_scope is not None and self.session_idle_timeout is not None: + if transport.idle_scope is not None and self.session_idle_timeout is not None: # pragma: no cover transport.idle_scope.deadline = anyio.current_time() + self.session_idle_timeout await transport.handle_request(scope, receive, send) return @@ -251,6 +275,8 @@ async def _handle_stateful_request( ) assert http_transport.mcp_session_id is not None + if requestor is not None: + self._session_owners[http_transport.mcp_session_id] = requestor self._server_instances[http_transport.mcp_session_id] = http_transport logger.info(f"Created new transport with session ID: {new_session_id}") @@ -281,6 +307,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE assert http_transport.mcp_session_id is not None logger.info(f"Session {http_transport.mcp_session_id} idle timeout") self._server_instances.pop(http_transport.mcp_session_id, None) + self._session_owners.pop(http_transport.mcp_session_id, None) await http_transport.terminate() except Exception: logger.exception(f"Session {http_transport.mcp_session_id} crashed") @@ -296,6 +323,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE "active instances." 
) del self._server_instances[http_transport.mcp_session_id] + self._session_owners.pop(http_transport.mcp_session_id, None) # Assert task group is not None for type checking assert self._task_group is not None @@ -306,19 +334,10 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE await http_transport.handle_request(scope, receive, send) else: # Unknown or expired session ID - return 404 per MCP spec - # TODO: Align error code once spec clarifies - # See: https://github.com/modelcontextprotocol/python-sdk/issues/1821 - error_response = JSONRPCError( - jsonrpc="2.0", - id="server-error", - error=ErrorData( - code=INVALID_REQUEST, - message="Session not found", - ), + body = JSONRPCError( + jsonrpc="2.0", id="server-error", error=ErrorData(code=INVALID_REQUEST, message="Session not found") ) response = Response( - content=error_response.model_dump_json(by_alias=True, exclude_none=True), - status_code=HTTPStatus.NOT_FOUND, - media_type="application/json", + body.model_dump_json(by_alias=True, exclude_none=True), status_code=404, media_type="application/json" ) await response(scope, receive, send) diff --git a/tests/experimental/tasks/server/test_server.py b/tests/experimental/tasks/server/test_server.py index 7209ed412a..64a1dabb04 100644 --- a/tests/experimental/tasks/server/test_server.py +++ b/tests/experimental/tasks/server/test_server.py @@ -506,13 +506,14 @@ async def run_server() -> None: # Create a task directly in the store for testing task = await store.create_task(TaskMetadata(ttl=60000)) - # Test list_tasks (default handler) + # Test list_tasks (default handler). Tasks created directly in the + # store have no session scope, so they are reachable by ID but not + # included in tasks/list (see test_task_scope.py). 
list_result = await client_session.send_request( ClientRequest(ListTasksRequest()), ListTasksResult, ) - assert len(list_result.tasks) == 1 - assert list_result.tasks[0].taskId == task.taskId + assert list_result.tasks == [] # Test get_task (default handler - found) get_result = await client_session.send_request( diff --git a/tests/experimental/tasks/server/test_task_scope.py b/tests/experimental/tasks/server/test_task_scope.py new file mode 100644 index 0000000000..c13b728a86 --- /dev/null +++ b/tests/experimental/tasks/server/test_task_scope.py @@ -0,0 +1,150 @@ +"""Unit tests for the task session-scope helpers. + +A session scope is an opaque marker assigned to each session by +TaskSupport.configure_session(). Task IDs generated by run_task() embed it so +the default task handlers can tell which session created a task. See +test_task_visibility.py for the end-to-end behaviour these helpers produce. +""" + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager + +import anyio +import pytest + +from mcp.server import Server +from mcp.server.experimental.task_scope import ( + new_session_scope, + scoped_task_id, + session_scope_of, + task_in_session_scope, + task_listable_in_session_scope, +) +from mcp.server.experimental.task_support import TaskSupport +from mcp.server.lowlevel import NotificationOptions +from mcp.server.models import InitializationOptions +from mcp.server.session import ServerSession +from mcp.shared.message import SessionMessage + + +def test_new_session_scope_is_unique() -> None: + assert new_session_scope() != new_session_scope() + + +def test_scoped_task_id_round_trips_its_scope() -> None: + scope = new_session_scope() + + task_id = scoped_task_id(scope) + + assert session_scope_of(task_id) == scope + + +def test_scoped_task_ids_are_unique_within_a_scope() -> None: + scope = new_session_scope() + + assert scoped_task_id(scope) != scoped_task_id(scope) + + +@pytest.mark.parametrize( + "task_id", + [ + 
"plain-task-id", + "550e8400-e29b-41d4-a716-446655440000", # bare uuid4 + "", + # Right shape but the scope half is not 32 hex chars. + "not-a-scope:550e8400-e29b-41d4-a716-446655440000", + # Right scope half but the suffix is not a uuid4. + "0123456789abcdef0123456789abcdef:not-a-uuid", + # Uppercase hex is not produced by new_session_scope(). + "0123456789ABCDEF0123456789ABCDEF:550e8400-e29b-41d4-a716-446655440000", + ], +) +def test_session_scope_of_returns_none_for_unscoped_ids(task_id: str) -> None: + assert session_scope_of(task_id) is None + + +def test_a_scoped_task_is_usable_only_from_the_scope_that_created_it() -> None: + scope = new_session_scope() + task_id = scoped_task_id(scope) + + assert task_in_session_scope(task_id, scope) is True + assert task_in_session_scope(task_id, new_session_scope()) is False + assert task_in_session_scope(task_id, None) is False + + +def test_an_unscoped_task_is_usable_from_any_scope() -> None: + assert task_in_session_scope("plain-task-id", new_session_scope()) is True + assert task_in_session_scope("plain-task-id", None) is True + + +def test_a_scoped_task_is_listable_only_in_the_scope_that_created_it() -> None: + scope = new_session_scope() + task_id = scoped_task_id(scope) + + assert task_listable_in_session_scope(task_id, scope) is True + assert task_listable_in_session_scope(task_id, new_session_scope()) is False + assert task_listable_in_session_scope(task_id, None) is False + + +def test_an_unscoped_task_is_never_listable() -> None: + assert task_listable_in_session_scope("plain-task-id", new_session_scope()) is False + assert task_listable_in_session_scope("plain-task-id", None) is False + + +@asynccontextmanager +async def _make_session() -> AsyncIterator[ServerSession]: + """Create a ServerSession suitable for inspecting configure_session().""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = 
anyio.create_memory_object_stream[SessionMessage | Exception](1) + options = InitializationOptions( + server_name="test", + server_version="0", + capabilities=Server("test").get_capabilities(NotificationOptions(), {}), + ) + async with ( + server_to_client_receive, + client_to_server_send, + ServerSession(client_to_server_receive, server_to_client_send, options) as session, + ): + yield session + + +@pytest.mark.anyio +async def test_configure_session_assigns_a_scope() -> None: + support = TaskSupport.in_memory() + async with _make_session() as session: + assert session.experimental.task_session_scope is None + + support.configure_session(session) + + assert session.experimental.task_session_scope is not None + + +@pytest.mark.anyio +async def test_configure_session_assigns_distinct_scopes_per_session() -> None: + support = TaskSupport.in_memory() + async with _make_session() as first, _make_session() as second: + support.configure_session(first) + support.configure_session(second) + + assert first.experimental.task_session_scope != second.experimental.task_session_scope + + +@pytest.mark.anyio +async def test_configure_session_is_idempotent() -> None: + support = TaskSupport.in_memory() + async with _make_session() as session: + support.configure_session(session) + scope = session.experimental.task_session_scope + support.configure_session(session) + + assert session.experimental.task_session_scope == scope + + +@pytest.mark.anyio +async def test_configure_session_assigns_no_scope_to_stateless_sessions() -> None: + support = TaskSupport.in_memory() + async with _make_session() as session: + support.configure_session(session, stateless=True) + + assert session.experimental.task_session_scope is None diff --git a/tests/experimental/tasks/server/test_task_visibility.py b/tests/experimental/tasks/server/test_task_visibility.py new file mode 100644 index 0000000000..2ee16399ea --- /dev/null +++ b/tests/experimental/tasks/server/test_task_visibility.py @@ -0,0 +1,321 @@ 
+"""End-to-end tests for which clients can see and control a task. + +Every test runs a real server and one or more in-memory client sessions. A +task started with run_task() belongs to the client session that started it: +that session can poll it, list it, and cancel it, while every other session +is told the task does not exist. Tasks whose IDs carry no session marker +(explicitly chosen IDs, or tasks on stateless servers) are usable by any +session that knows the ID, but are never listed. +""" + +from collections.abc import AsyncIterator, Awaitable, Callable +from contextlib import AsyncExitStack +from typing import Any + +import anyio +import pytest +from anyio.abc import TaskGroup + +from mcp.client.session import ClientSession +from mcp.server import Server +from mcp.server.experimental.task_context import ServerTaskContext +from mcp.shared.exceptions import McpError +from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore +from mcp.shared.experimental.tasks.store import TaskStore +from mcp.shared.message import SessionMessage +from mcp.types import ( + TASK_REQUIRED, + CallToolResult, + CreateTaskResult, + ListTasksResult, + TextContent, + Tool, + ToolExecution, +) + +# The `connect` fixture: each call opens a new client session against the test server. +Connect = Callable[..., Awaitable[ClientSession]] + +# Enough tasks that the bundled in-memory store needs more than one page (of 10) +# to list them, so listings that span store pages are exercised. +MORE_TASKS_THAN_ONE_STORE_PAGE = 11 + + +def build_task_server(store: TaskStore | None = None) -> Server: + """Build a server exposing three task tools. + + - "greet" finishes immediately and returns a greeting. + - "long_running_job" keeps running until the server shuts down. + - "nightly_export" is a singleton job: every invocation uses the + explicitly chosen task ID "the-nightly-export". 
+ """ + server = Server("task-visibility-test-server") + server.experimental.enable_tasks(store=store) + + @server.list_tools() + async def list_tools() -> list[Tool]: + return [ + Tool( + name=name, + description=name, + inputSchema={"type": "object"}, + execution=ToolExecution(taskSupport=TASK_REQUIRED), + ) + for name in ("greet", "long_running_job", "nightly_export") + ] + + @server.call_tool() + async def call_tool(name: str, arguments: dict[str, Any]) -> CallToolResult | CreateTaskResult: + async def greet(task: ServerTaskContext) -> CallToolResult: + return CallToolResult(content=[TextContent(type="text", text=f"Hello, {arguments['name']}!")]) + + async def long_running_job(task: ServerTaskContext) -> CallToolResult: + await anyio.sleep_forever() + raise AssertionError("unreachable") # pragma: no cover + + async def nightly_export(task: ServerTaskContext) -> CallToolResult: + return CallToolResult(content=[TextContent(type="text", text="exported")]) + + run_task = server.request_context.experimental.run_task + if name == "nightly_export": + return await run_task(nightly_export, task_id="the-nightly-export") + return await run_task(greet if name == "greet" else long_running_job) + + return server + + +async def open_client( + server: Server, task_group: TaskGroup, stack: AsyncExitStack, *, stateless: bool = False +) -> ClientSession: + """Connect a new client session to `server` over in-memory streams.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + + async def run_server() -> None: + await server.run( + client_to_server_receive, + server_to_client_send, + server.create_initialization_options(), + stateless=stateless, + ) + + task_group.start_soon(run_server) + client = await stack.enter_async_context(ClientSession(server_to_client_receive, client_to_server_send)) + await client.initialize() + 
return client + + +@pytest.fixture +def task_server() -> Server: + return build_task_server() + + +@pytest.fixture +async def connect(task_server: Server) -> AsyncIterator[Connect]: + """A factory that opens a new client session against the test server on each call.""" + async with anyio.create_task_group() as task_group, AsyncExitStack() as stack: + + async def _connect(*, stateless: bool = False) -> ClientSession: + return await open_client(task_server, task_group, stack, stateless=stateless) + + yield _connect + task_group.cancel_scope.cancel() + + +async def start_task(client: ClientSession, tool: str = "long_running_job", **arguments: Any) -> str: + """Start `tool` as a task and return the new task's ID.""" + result = await client.experimental.call_tool_as_task(tool, arguments) + return result.task.taskId + + +async def wait_until_finished(client: ClientSession, task_id: str) -> None: + """Poll the task until it reaches a terminal status.""" + with anyio.fail_after(5): + async for _ in client.experimental.poll_task(task_id): + pass + + +async def listed_task_ids(client: ClientSession) -> list[str]: + """Return the IDs of every task the server lists for this client.""" + return [task.taskId for task in (await client.experimental.list_tasks()).tasks] + + +# --- What the client that started a task can do with it --- + + +@pytest.mark.anyio +async def test_a_client_can_poll_its_own_task_to_completion_and_read_the_result(connect: Connect) -> None: + client = await connect() + task_id = await start_task(client, "greet", name="Ada") + await wait_until_finished(client, task_id) + + result = await client.experimental.get_task_result(task_id, CallToolResult) + + assert result.content == [TextContent(type="text", text="Hello, Ada!")] + + +@pytest.mark.anyio +async def test_a_client_sees_its_own_task_when_listing_tasks(connect: Connect) -> None: + client = await connect() + task_id = await start_task(client) + + listed = await listed_task_ids(client) + + assert listed == 
[task_id] + + +@pytest.mark.anyio +async def test_a_client_can_cancel_its_own_task(connect: Connect) -> None: + client = await connect() + task_id = await start_task(client) + + cancelled = await client.experimental.cancel_task(task_id) + + assert cancelled.status == "cancelled" + + +# --- What a client cannot do with a task started by another client --- + + +@pytest.mark.anyio +async def test_a_client_cannot_get_the_status_of_another_clients_task(connect: Connect) -> None: + creator = await connect() + other_client = await connect() + task_id = await start_task(creator) + + with pytest.raises(McpError, match="Task not found"): + await other_client.experimental.get_task(task_id) + + +@pytest.mark.anyio +async def test_a_client_cannot_get_the_result_of_another_clients_task(connect: Connect) -> None: + creator = await connect() + other_client = await connect() + task_id = await start_task(creator, "greet", name="Ada") + await wait_until_finished(creator, task_id) + + with pytest.raises(McpError, match="Task not found"): + await other_client.experimental.get_task_result(task_id, CallToolResult) + + +@pytest.mark.anyio +async def test_a_client_cannot_cancel_another_clients_task(connect: Connect) -> None: + creator = await connect() + other_client = await connect() + task_id = await start_task(creator) + + with pytest.raises(McpError, match="Task not found"): + await other_client.experimental.cancel_task(task_id) + + # The task is unaffected. 
+ assert (await creator.experimental.get_task(task_id)).status == "working" + + +@pytest.mark.anyio +async def test_a_client_does_not_see_another_clients_task_when_listing_tasks(connect: Connect) -> None: + creator = await connect() + other_client = await connect() + await start_task(creator) + + listed = await listed_task_ids(other_client) + + assert listed == [] + + +@pytest.mark.anyio +async def test_each_client_lists_only_its_own_tasks(connect: Connect) -> None: + first_client = await connect() + second_client = await connect() + first_task = await start_task(first_client) + second_task = await start_task(second_client) + + assert await listed_task_ids(first_client) == [first_task] + assert await listed_task_ids(second_client) == [second_task] + + +@pytest.mark.anyio +async def test_listing_tasks_reveals_nothing_about_other_clients_tasks_however_many_there_are( + connect: Connect, +) -> None: + """The listing must not identify other clients' tasks through any field, including the pagination cursor.""" + creator = await connect() + other_client = await connect() + for _ in range(MORE_TASKS_THAN_ONE_STORE_PAGE): + await start_task(creator) + + listing = await other_client.experimental.list_tasks() + + assert listing == ListTasksResult(tasks=[], nextCursor=None) + + +@pytest.mark.anyio +async def test_a_client_with_more_than_one_store_page_of_tasks_lists_all_of_them(connect: Connect) -> None: + client = await connect() + started = {await start_task(client) for _ in range(MORE_TASKS_THAN_ONE_STORE_PAGE)} + + listing = await client.experimental.list_tasks() + + assert {task.taskId for task in listing.tasks} == started + assert listing.nextCursor is None + + +# --- Tasks that do not belong to any client session --- + + +@pytest.mark.anyio +# Choosing the task ID instead of letting the SDK generate one is deprecated for +# exactly the behaviour this test demonstrates: the task is not tied to the +# session that created it. 
+@pytest.mark.filterwarnings("ignore:Passing an explicit task_id") +async def test_a_task_whose_id_was_chosen_by_the_server_is_accessible_to_every_client(connect: Connect) -> None: + creator = await connect() + other_client = await connect() + await wait_until_finished(creator, await start_task(creator, "nightly_export")) + + status = await other_client.experimental.get_task("the-nightly-export") + + assert status.status == "completed" + + +@pytest.mark.anyio +async def test_a_stateless_server_serves_a_task_to_any_session_that_knows_its_id(connect: Connect) -> None: + first_session = await connect(stateless=True) + second_session = await connect(stateless=True) + task_id = await start_task(first_session, "greet", name="Ada") + await wait_until_finished(second_session, task_id) + + result = await second_session.experimental.get_task_result(task_id, CallToolResult) + + assert result.content == [TextContent(type="text", text="Hello, Ada!")] + + +@pytest.mark.anyio +async def test_a_stateless_server_lists_no_tasks(connect: Connect) -> None: + session = await connect(stateless=True) + await start_task(session) + + listed = await listed_task_ids(session) + + assert listed == [] + + +# --- The behaviour does not depend on the bundled in-memory store --- + + +@pytest.mark.anyio +async def test_clients_are_isolated_when_the_server_uses_a_custom_task_store() -> None: + class CustomTaskStore(InMemoryTaskStore): + """A stand-in for a user-provided TaskStore implementation.""" + + server = build_task_server(store=CustomTaskStore()) + + async with anyio.create_task_group() as task_group, AsyncExitStack() as stack: + creator = await open_client(server, task_group, stack) + other_client = await open_client(server, task_group, stack) + task_id = await start_task(creator) + + with pytest.raises(McpError, match="Task not found"): + await other_client.experimental.get_task(task_id) + + assert (await creator.experimental.get_task(task_id)).status == "working" + 
task_group.cancel_scope.cancel() diff --git a/tests/experimental/tasks/test_request_context.py b/tests/experimental/tasks/test_request_context.py index 5fa5da81af..e52ec4403a 100644 --- a/tests/experimental/tasks/test_request_context.py +++ b/tests/experimental/tasks/test_request_context.py @@ -3,6 +3,7 @@ import pytest from mcp.server.experimental.request_context import Experimental +from mcp.server.experimental.task_context import ServerTaskContext from mcp.shared.exceptions import McpError from mcp.types import ( METHOD_NOT_FOUND, @@ -11,6 +12,7 @@ TASK_REQUIRED, ClientCapabilities, ClientTasksCapability, + Result, TaskMetadata, Tool, ToolExecution, @@ -164,3 +166,30 @@ def test_can_use_tool_forbidden_without_task_support() -> None: def test_can_use_tool_none_without_task_support() -> None: exp = Experimental(_client_capabilities=ClientCapabilities()) assert exp.can_use_tool(None) is True + + +@pytest.mark.anyio +async def test_run_task_with_an_explicit_task_id_emits_a_deprecation_warning() -> None: + """An explicitly provided task ID is not associated with the creating session, so passing one is deprecated.""" + exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) + + async def work(task: ServerTaskContext) -> Result: + raise AssertionError("unreachable") # pragma: no cover + + with pytest.warns(DeprecationWarning, match="not associated with the session"): + # Task support is not configured, so the call fails after the + # deprecated argument has been reported. + with pytest.raises(RuntimeError, match="Task support not enabled"): + # The deliberate use of the deprecated overload is the point of this test. 
+ await exp.run_task(work, task_id="explicitly-chosen") # pyright: ignore[reportDeprecated] + + +@pytest.mark.anyio +async def test_run_task_without_a_task_id_does_not_warn() -> None: + exp = Experimental(task_metadata=TaskMetadata(ttl=60000)) + + async def work(task: ServerTaskContext) -> Result: + raise AssertionError("unreachable") # pragma: no cover + + with pytest.raises(RuntimeError, match="Task support not enabled"): + await exp.run_task(work) diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 08fcabf276..7b7f68c425 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -54,6 +54,7 @@ async def authorize(self, client: OAuthClientInformationFull, params: Authorizat redirect_uri_provided_explicitly=params.redirect_uri_provided_explicitly, expires_at=time.time() + 300, scopes=params.scopes or ["read", "write"], + subject="test-user", ) self.auth_codes[code.code] = code @@ -80,6 +81,7 @@ async def exchange_authorization_code( client_id=client.client_id, scopes=authorization_code.scopes, expires_at=int(time.time()) + 3600, + subject=authorization_code.subject, ) self.refresh_tokens[refresh_token] = access_token @@ -109,6 +111,7 @@ async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_t client_id=token_info.client_id, scopes=token_info.scopes, expires_at=token_info.expires_at, + subject=token_info.subject, ) return refresh_obj @@ -142,6 +145,7 @@ async def exchange_refresh_token( client_id=client.client_id, scopes=scopes or token_info.scopes, expires_at=int(time.time()) + 3600, + subject=refresh_token.subject, ) self.refresh_tokens[new_refresh_token] = new_access_token @@ -170,6 +174,7 @@ async def load_access_token(self, token: str) -> AccessToken | None: client_id=token_info.client_id, scopes=token_info.scopes, expires_at=token_info.expires_at, + subject=token_info.subject, ) async def 
revoke_token(self, token: AccessToken | RefreshToken) -> None: @@ -783,6 +788,7 @@ async def test_authorization_get( assert auth_info.client_id == client_info["client_id"] assert "read" in auth_info.scopes assert "write" in auth_info.scopes + assert auth_info.subject == "test-user" # 6. Refresh the token response = await test_client.post( @@ -803,6 +809,10 @@ async def test_authorization_get( assert new_token_response["access_token"] != access_token assert new_token_response["refresh_token"] != refresh_token + refreshed_auth_info = await mock_oauth_provider.load_access_token(new_token_response["access_token"]) + assert refreshed_auth_info + assert refreshed_auth_info.subject == "test-user" + # 7. Revoke the token response = await test_client.post( "/revoke", diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py index 010eaf6a25..716a308a53 100644 --- a/tests/server/test_sse_security.py +++ b/tests/server/test_sse_security.py @@ -1,9 +1,13 @@ -"""Tests for SSE server DNS rebinding protection.""" +"""Tests for SSE server request validation.""" import logging import multiprocessing +import re import socket +from collections.abc import Iterator +from typing import Any +import anyio import httpx import pytest import uvicorn @@ -11,8 +15,11 @@ from starlette.requests import Request from starlette.responses import Response from starlette.routing import Mount, Route +from starlette.types import Message from mcp.server import Server +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser +from mcp.server.auth.provider import AccessToken from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings from mcp.types import Tool @@ -22,6 +29,23 @@ SERVER_NAME = "test_sse_security_server" +@pytest.fixture(autouse=True) +def reset_sse_starlette_exit_event() -> Iterator[None]: + """sse-starlette<2 caches a module-level anyio.Event on AppStatus; clear it + around each test so it is never 
bound to a closed event loop. Clearing it + afterwards matters too: later test modules fork uvicorn subprocesses on + Linux and would otherwise inherit a stale event.""" + from sse_starlette.sse import AppStatus + + def clear() -> None: + if hasattr(AppStatus, "should_exit_event"): # pragma: no cover + setattr(AppStatus, "should_exit_event", None) + + clear() + yield + clear() + + @pytest.fixture def server_port() -> int: with socket.socket() as s: @@ -291,3 +315,112 @@ async def test_sse_security_post_valid_content_type(server_port: int): finally: process.terminate() process.join() + + +def _authenticated_user(client_id: str, subject: str | None = None, issuer: str | None = None) -> AuthenticatedUser: + """Build the scope["user"] value that AuthenticationMiddleware would set for this principal.""" + claims = {"iss": issuer} if issuer is not None else None + return AuthenticatedUser(AccessToken(token="token", client_id=client_id, scopes=[], subject=subject, claims=claims)) + + +def _sse_scope(method: str, path: str, user: AuthenticatedUser | None) -> dict[str, Any]: + """Build an ASGI scope for a request to the SSE transport.""" + scope: dict[str, Any] = { + "type": "http", + "method": method, + "path": path, + "root_path": "", + "query_string": b"", + "headers": [(b"content-type", b"application/json")], + } + if user is not None: + scope["user"] = user + return scope + + +async def _post_message(transport: SseServerTransport, session_id: str, user: AuthenticatedUser | None) -> int: + """POST a message to an SSE session as `user` and return the response status.""" + body = b'{"jsonrpc": "2.0", "id": 1, "method": "ping", "params": null}' + scope = _sse_scope("POST", "/messages/", user) + scope["query_string"] = f"session_id={session_id}".encode() + sent: list[Message] = [] + + async def receive() -> Message: + return {"type": "http.request", "body": body, "more_body": False} + + async def send(message: Message) -> None: + sent.append(message) + + await 
transport.handle_post_message(scope, receive, send) + response_start = next(msg for msg in sent if msg["type"] == "http.response.start") + return response_start["status"] + + +_Principal = tuple[str] | tuple[str, str] | tuple[str, str, str] + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("creator", "sender", "expected"), + [ + pytest.param(("client-a",), ("client-b",), 404, id="different-client"), + pytest.param(("client-a",), None, 404, id="unauthenticated-sender"), + pytest.param(("client-a", "alice"), ("client-a", "bob"), 404, id="same-client-different-subject"), + pytest.param(("client-a", "alice"), ("client-a",), 404, id="same-client-no-subject"), + pytest.param( + ("client-a", "alice", "https://i1"), ("client-a", "alice", "https://i2"), 404, id="different-issuer" + ), + pytest.param(None, ("client-a",), 404, id="unauthenticated-creator"), + pytest.param(("client-a",), ("client-a",), 202, id="same-client"), + pytest.param(("client-a", "alice"), ("client-a", "alice"), 202, id="same-client-and-subject"), + pytest.param(None, None, 202, id="both-unauthenticated"), + ], +) +async def test_sse_post_requires_the_credential_that_created_the_session( + creator: _Principal | None, + sender: _Principal | None, + expected: int, +): + """The session endpoint URL issued to one authenticated principal must not + accept messages from a request authenticated as a different one.""" + transport = SseServerTransport("/messages/") + session_id_received = anyio.Event() + session_ids: list[str] = [] + client_disconnected = anyio.Event() + + async def get_send(message: Message) -> None: + # The first body chunk is the SSE event announcing the session URI to POST messages to. 
+ if message["type"] == "http.response.body" and not session_ids: + match = re.search(rb"session_id=([0-9a-f]{32})", message.get("body", b"")) + assert match is not None, f"expected the endpoint event first, got {message!r}" + session_ids.append(match.group(1).decode()) + session_id_received.set() + + async def get_receive() -> Message: + # The SSE client stays connected until the test signals otherwise. + await client_disconnected.wait() + return {"type": "http.disconnect"} + + creator_user = _authenticated_user(*creator) if creator is not None else None + sender_user = _authenticated_user(*sender) if sender is not None else None + + async def hold_sse_connection() -> None: + """Establish the SSE session as `creator` and keep it open, as a server would.""" + scope = _sse_scope("GET", "/sse", creator_user) + with anyio.fail_after(5): + async with transport.connect_sse(scope, get_receive, get_send) as (read_stream, write_stream): + async with read_stream, write_stream: + async for _ in read_stream: + pass + + async with anyio.create_task_group() as tg: + tg.start_soon(hold_sse_connection) + with anyio.fail_after(5): + await session_id_received.wait() + + assert await _post_message(transport, session_ids[0], sender_user) == expected + + client_disconnected.set() + + # Once the connection is gone the session is no longer routable. 
+ assert await _post_message(transport, session_ids[0], creator_user) == 404 diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 33bcb5f2aa..0ae07c43ad 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -6,9 +6,11 @@ import anyio import pytest -from starlette.types import Message +from starlette.types import Message, Scope from mcp.server import streamable_http_manager +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser +from mcp.server.auth.provider import AccessToken from mcp.server.lowlevel import Server from mcp.server.streamable_http import MCP_SESSION_ID_HEADER, StreamableHTTPServerTransport from mcp.server.streamable_http_manager import StreamableHTTPSessionManager @@ -390,3 +392,166 @@ def test_session_idle_timeout_rejects_non_positive(): def test_session_idle_timeout_rejects_stateless(): with pytest.raises(RuntimeError, match="not supported in stateless"): StreamableHTTPSessionManager(app=Server("test"), session_idle_timeout=30, stateless=True) + + +def _user(client_id: str, subject: str | None = None, issuer: str | None = None) -> AuthenticatedUser: + """Build the scope["user"] value that AuthenticationMiddleware would set for this principal.""" + claims = {"iss": issuer} if issuer is not None else None + return AuthenticatedUser(AccessToken(token="token", client_id=client_id, scopes=[], subject=subject, claims=claims)) + + +def _request_scope( + *, session_id: str | None = None, user: AuthenticatedUser | None = None, method: str = "POST" +) -> Scope: + """Build an ASGI scope for a request to the MCP endpoint.""" + headers = [ + (b"content-type", b"application/json"), + (b"accept", b"application/json, text/event-stream"), + ] + if session_id is not None: + headers.append((b"mcp-session-id", session_id.encode())) + scope: Scope = { + "type": "http", + "method": method, + "path": "/mcp", + "headers": headers, + } + if 
user is not None: + scope["user"] = user + return scope + + +async def _open_session(manager: StreamableHTTPSessionManager, user: AuthenticatedUser | None) -> str: + """Create a new session as `user` and return its session ID.""" + sent_messages: list[Message] = [] + + async def mock_send(message: Message) -> None: + sent_messages.append(message) + + async def mock_receive() -> Message: + return {"type": "http.request", "body": b"", "more_body": False} + + await manager.handle_request(_request_scope(user=user), mock_receive, mock_send) + + response_start = next(msg for msg in sent_messages if msg["type"] == "http.response.start") + headers = dict(response_start.get("headers", [])) + return headers[MCP_SESSION_ID_HEADER.encode()].decode() + + +async def _request_session( + manager: StreamableHTTPSessionManager, session_id: str, user: AuthenticatedUser | None, method: str = "POST" +) -> int: + """Send a request for an existing session as `user` and return the response status.""" + sent_messages: list[Message] = [] + + async def mock_send(message: Message) -> None: + sent_messages.append(message) + + async def mock_receive() -> Message: + return {"type": "http.request", "body": b"", "more_body": False} + + await manager.handle_request( + _request_scope(session_id=session_id, user=user, method=method), mock_receive, mock_send + ) + + response_start = next(msg for msg in sent_messages if msg["type"] == "http.response.start") + return response_start["status"] + + +@pytest.fixture +async def manager_with_live_session(): + """A running manager around a real `Server`. 
Sessions remain registered until + `manager.run()` exits because `Server.run` blocks waiting for an initialize message.""" + manager = StreamableHTTPSessionManager(app=Server("test-session-credentials")) + async with manager.run(): + yield manager + + +@pytest.mark.anyio +async def test_session_accepts_requests_from_the_credential_that_created_it( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """Requests presenting the same credential as the one that created the session are served.""" + manager = manager_with_live_session + session_id = await _open_session(manager, _user("client-a")) + + status = await _request_session(manager, session_id, _user("client-a")) + + # The request passes the manager's credential check and reaches the + # session's transport, instead of being answered with 404 by the manager. + assert status != 404 + + +@pytest.mark.anyio +@pytest.mark.parametrize("method", ["POST", "GET", "DELETE"]) +async def test_session_rejects_requests_from_a_different_credential( + manager_with_live_session: StreamableHTTPSessionManager, method: str +) -> None: + """A session created by one credential cannot be used with another credential, whatever the method.""" + manager = manager_with_live_session + session_id = await _open_session(manager, _user("client-a")) + + assert await _request_session(manager, session_id, _user("client-b"), method) == 404 + # The session is still registered and still serves its creator. 
+ assert await _request_session(manager, session_id, _user("client-a")) != 404 + + +@pytest.mark.anyio +async def test_session_rejects_requests_from_a_different_subject_of_the_same_client( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """Two end-users that share an OAuth client cannot use each other's sessions.""" + manager = manager_with_live_session + session_id = await _open_session(manager, _user("client-a", subject="alice")) + + assert await _request_session(manager, session_id, _user("client-a", subject="bob")) == 404 + assert await _request_session(manager, session_id, _user("client-a", subject=None)) == 404 + assert await _request_session(manager, session_id, _user("client-a", subject="alice")) != 404 + + +@pytest.mark.anyio +async def test_session_rejects_requests_with_the_same_subject_from_a_different_issuer( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """A subject is unique only per issuer, so a colliding subject from a different issuer is not the same principal.""" + manager = manager_with_live_session + creator = _user("client-a", subject="alice", issuer="https://issuer.one") + session_id = await _open_session(manager, creator) + + other_issuer = _user("client-a", subject="alice", issuer="https://issuer.two") + assert await _request_session(manager, session_id, other_issuer) == 404 + assert await _request_session(manager, session_id, _user("client-a", subject="alice")) == 404 + assert await _request_session(manager, session_id, creator) != 404 + + +@pytest.mark.anyio +async def test_session_rejects_unauthenticated_requests_for_an_authenticated_session( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """A session created with a credential cannot be used without one.""" + manager = manager_with_live_session + session_id = await _open_session(manager, _user("client-a")) + + assert await _request_session(manager, session_id, None) == 404 + + +@pytest.mark.anyio +async def 
test_session_rejects_authenticated_requests_for_an_anonymous_session( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """A session created without a credential cannot be used with one.""" + manager = manager_with_live_session + session_id = await _open_session(manager, None) + + assert await _request_session(manager, session_id, _user("client-a")) == 404 + + +@pytest.mark.anyio +async def test_anonymous_session_accepts_anonymous_requests( + manager_with_live_session: StreamableHTTPSessionManager, +) -> None: + """Servers without authentication keep working: no credential on either side.""" + manager = manager_with_live_session + session_id = await _open_session(manager, None) + + assert await _request_session(manager, session_id, None) != 404