From a0bbcdfd21c97f53992c57fd2c55b3daea8c495a Mon Sep 17 00:00:00 2001 From: Jianke LIN Date: Sun, 24 May 2026 22:16:18 +0200 Subject: [PATCH 1/4] fix(auth): make get_access_token per-request in stateful sessions --- .../server/auth/middleware/auth_context.py | 23 ++++ src/mcp/server/lowlevel/server.py | 8 +- .../test_get_access_token_streamable_http.py | 100 ++++++++++++++++++ 3 files changed, 129 insertions(+), 2 deletions(-) create mode 100644 tests/server/auth/test_get_access_token_streamable_http.py diff --git a/src/mcp/server/auth/middleware/auth_context.py b/src/mcp/server/auth/middleware/auth_context.py index 1d34a5546b..f34b98cefd 100644 --- a/src/mcp/server/auth/middleware/auth_context.py +++ b/src/mcp/server/auth/middleware/auth_context.py @@ -1,5 +1,8 @@ import contextvars +from contextvars import Token + +from starlette.requests import Request from starlette.types import ASGIApp, Receive, Scope, Send from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser @@ -20,6 +23,26 @@ def get_access_token() -> AccessToken | None: return auth_user.access_token if auth_user else None +def _push_auth_context_from_request(request: Request | None) -> Token[AuthenticatedUser | None] | None: + """Set auth context for the current task from an incoming request. + + This is primarily used by server transports where request handlers may run + in background tasks that are not part of the original ASGI request task. + """ + if request is None: + return None + user = getattr(request, "user", None) + if isinstance(user, AuthenticatedUser): + return auth_context_var.set(user) + return None + + +def _pop_auth_context(token: Token[AuthenticatedUser | None] | None) -> None: + if token is None: + return + auth_context_var.reset(token) + + class AuthContextMiddleware: """Middleware that extracts the authenticated user from the request and sets it in a contextvar for easy access throughout the request lifecycle. diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 59de0ace45..506861b56f 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -53,7 +53,7 @@ async def main(): from typing_extensions import TypeVar from mcp import types -from mcp.server.auth.middleware.auth_context import AuthContextMiddleware +from mcp.server.auth.middleware.auth_context import AuthContextMiddleware, _pop_auth_context, _push_auth_context_from_request from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes @@ -497,7 +497,11 @@ async def _handle_request( close_sse_stream=close_sse_stream_cb, close_standalone_sse_stream=close_standalone_sse_stream_cb, ) - response = await handler(ctx, req.params) + auth_token = _push_auth_context_from_request(request_data) + try: + response = await handler(ctx, req.params) + finally: + _pop_auth_context(auth_token) except MCPError as err: response = err.error except anyio.get_cancelled_exc_class(): diff --git a/tests/server/auth/test_get_access_token_streamable_http.py b/tests/server/auth/test_get_access_token_streamable_http.py new file mode 100644 index 0000000000..3e0bd24272 --- /dev/null +++ b/tests/server/auth/test_get_access_token_streamable_http.py @@ -0,0 +1,100 @@ +import time + +import httpx +import pytest +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.authentication import AuthenticationMiddleware +from starlette.routing import Mount + +from mcp import Client +from mcp.client.streamable_http import streamable_http_client +from mcp.server import Server, ServerRequestContext +from mcp.server.auth.middleware.auth_context import AuthContextMiddleware, get_access_token +from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend +from mcp.server.auth.provider import AccessToken +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from mcp.server.transport_security import TransportSecuritySettings +from mcp.types import ( + CallToolRequestParams, + CallToolResult, + ListToolsResult, + PaginatedRequestParams, + TextContent, + Tool, +) + + +class _EchoTokenVerifier: + """Accepts any bearer token and echoes it back as the verified AccessToken.""" + + async def verify_token(self, token: str) -> AccessToken | None: + return AccessToken(token=token, client_id=token, scopes=[], expires_at=int(time.time()) + 3600) + + +async def _handle_whoami(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + access = get_access_token() + text = access.token if access else "" + return CallToolResult(content=[TextContent(type="text", text=text)]) + + +async def _handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="whoami", input_schema={"type": "object", "properties": {}})]) + + +class _MutableBearerAuth(httpx.Auth): + def __init__(self, token: str) -> None: + self.token = token + + def auth_flow(self, request: httpx.Request): + request.headers["Authorization"] = f"Bearer {self.token}" + yield request + + +@pytest.mark.anyio +async def test_get_access_token_reflects_current_request_in_stateful_session() -> None: + host = "testserver" + + server = Server( + "auth-test-server", + on_call_tool=_handle_whoami, + on_list_tools=_handle_list_tools, + ) + + security = TransportSecuritySettings( + allowed_hosts=[host, f"{host}:*"], + allowed_origins=[f"http://{host}:*"], + ) + session_manager = StreamableHTTPSessionManager(app=server, security_settings=security, stateless=False) + + asgi_app = Starlette( + routes=[Mount("/mcp", app=session_manager.handle_request)], + middleware=[ + Middleware(AuthenticationMiddleware, backend=BearerAuthBackend(_EchoTokenVerifier())), + Middleware(AuthContextMiddleware), + ], + lifespan=lambda app: session_manager.run(), + ) + + auth = _MutableBearerAuth("token-A") + async with asgi_app.router.lifespan_context(asgi_app): + async with ( + httpx.ASGITransport(asgi_app) as transport, + httpx.AsyncClient( + transport=transport, + base_url=f"http://{host}", + auth=auth, + timeout=httpx.Timeout(30, read=30), + follow_redirects=True, + ) as http_client, + ): + transport_ctx = streamable_http_client(f"http://{host}/mcp", http_client=http_client) + async with Client(transport_ctx) as client: + r1 = await client.call_tool("whoami", {}) + assert isinstance(r1.content[0], TextContent) + assert r1.content[0].text == "token-A" + + auth.token = "token-B" + r2 = await client.call_tool("whoami", {}) + assert isinstance(r2.content[0], TextContent) + assert r2.content[0].text == "token-B" From de67bd78f60e923055e297c5801595ef0dcd969e Mon Sep 17 00:00:00 2001 From: Jianke LIN Date: Sun, 24 May 2026 22:19:07 +0200 Subject: [PATCH 2/4] fix(auth): avoid Request.user assertion without auth middleware --- src/mcp/server/auth/middleware/auth_context.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/mcp/server/auth/middleware/auth_context.py b/src/mcp/server/auth/middleware/auth_context.py index f34b98cefd..1edbc57de6 100644 --- a/src/mcp/server/auth/middleware/auth_context.py +++ b/src/mcp/server/auth/middleware/auth_context.py @@ -31,7 +31,16 @@ def _push_auth_context_from_request(request: Request | None) -> Token[Authentica """ if request is None: return None - user = getattr(request, "user", None) + # Avoid Request.user, which asserts AuthenticationMiddleware is installed. + user = None + scope = getattr(request, "scope", None) + if isinstance(scope, dict): + user = scope.get("user") + if user is None: + try: + user = getattr(request, "user", None) + except AssertionError: + user = None if isinstance(user, AuthenticatedUser): return auth_context_var.set(user) return None From b9bf42a833ceab11f6fe459235b3609fe36c99d9 Mon Sep 17 00:00:00 2001 From: Jianke LIN Date: Sun, 24 May 2026 22:22:53 +0200 Subject: [PATCH 3/4] chore(auth): type-safe auth context push --- src/mcp/server/auth/middleware/auth_context.py | 11 ++++------- src/mcp/server/lowlevel/server.py | 10 +++++++--- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/mcp/server/auth/middleware/auth_context.py b/src/mcp/server/auth/middleware/auth_context.py index 1edbc57de6..0d7b3d6cbf 100644 --- a/src/mcp/server/auth/middleware/auth_context.py +++ b/src/mcp/server/auth/middleware/auth_context.py @@ -1,6 +1,6 @@ import contextvars - from contextvars import Token +from typing import Any from starlette.requests import Request from starlette.types import ASGIApp, Receive, Scope, Send @@ -23,7 +23,7 @@ def get_access_token() -> AccessToken | None: return auth_user.access_token if auth_user else None -def _push_auth_context_from_request(request: Request | None) -> Token[AuthenticatedUser | None] | None: +def push_auth_context_from_request(request: Request | None) -> Token[AuthenticatedUser | None] | None: """Set auth context for the current task from an incoming request. This is primarily used by server transports where request handlers may run @@ -32,10 +32,7 @@ def _push_auth_context_from_request(request: Request | None) -> Token[Authentica if request is None: return None # Avoid Request.user, which asserts AuthenticationMiddleware is installed. - user = None - scope = getattr(request, "scope", None) - if isinstance(scope, dict): - user = scope.get("user") + user: Any | None = request.scope.get("user") if user is None: try: user = getattr(request, "user", None) @@ -46,7 +43,7 @@ def _push_auth_context_from_request(request: Request | None) -> Token[Authentica return None -def _pop_auth_context(token: Token[AuthenticatedUser | None] | None) -> None: +def pop_auth_context(token: Token[AuthenticatedUser | None] | None) -> None: if token is None: return auth_context_var.reset(token) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 506861b56f..122ef3f14e 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -53,7 +53,11 @@ async def main(): from typing_extensions import TypeVar from mcp import types -from mcp.server.auth.middleware.auth_context import AuthContextMiddleware, _pop_auth_context, _push_auth_context_from_request +from mcp.server.auth.middleware.auth_context import ( + AuthContextMiddleware, + pop_auth_context, + push_auth_context_from_request, +) from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes @@ -497,11 +501,11 @@ async def _handle_request( close_sse_stream=close_sse_stream_cb, close_standalone_sse_stream=close_standalone_sse_stream_cb, ) - auth_token = _push_auth_context_from_request(request_data) + auth_token = push_auth_context_from_request(request_data) try: response = await handler(ctx, req.params) finally: - _pop_auth_context(auth_token) + pop_auth_context(auth_token) except MCPError as err: response = err.error except anyio.get_cancelled_exc_class(): From 9e515ac92334fbfef3832ad0d48dd20fb16e74f7 Mon Sep 17 00:00:00 2001 From: Jianke LIN Date: Mon, 25 May 2026 01:09:41 +0200 Subject: [PATCH 4/4] fix auth context reset for streamable HTTP --- .../server/auth/middleware/auth_context.py | 4 +--- .../test_get_access_token_streamable_http.py | 19 ++++++++++--------- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/src/mcp/server/auth/middleware/auth_context.py b/src/mcp/server/auth/middleware/auth_context.py index 0d7b3d6cbf..31eb58b5b9 100644 --- a/src/mcp/server/auth/middleware/auth_context.py +++ b/src/mcp/server/auth/middleware/auth_context.py @@ -38,9 +38,7 @@ def push_auth_context_from_request(request: Request | None) -> Token[Authenticat user = getattr(request, "user", None) except AssertionError: user = None - if isinstance(user, AuthenticatedUser): - return auth_context_var.set(user) - return None + return auth_context_var.set(user if isinstance(user, AuthenticatedUser) else None) def pop_auth_context(token: Token[AuthenticatedUser | None] | None) -> None: diff --git a/tests/server/auth/test_get_access_token_streamable_http.py b/tests/server/auth/test_get_access_token_streamable_http.py index 3e0bd24272..9125fb2aee 100644 --- a/tests/server/auth/test_get_access_token_streamable_http.py +++ b/tests/server/auth/test_get_access_token_streamable_http.py @@ -14,7 +14,6 @@ from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend from mcp.server.auth.provider import AccessToken from mcp.server.streamable_http_manager import StreamableHTTPSessionManager -from mcp.server.transport_security import TransportSecuritySettings from mcp.types import ( CallToolRequestParams, CallToolResult, @@ -43,11 +42,12 @@ async def _handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequest class _MutableBearerAuth(httpx.Auth): - def __init__(self, token: str) -> None: + def __init__(self, token: str | None) -> None: self.token = token def auth_flow(self, request: httpx.Request): - request.headers["Authorization"] = f"Bearer {self.token}" + if self.token is not None: + request.headers["Authorization"] = f"Bearer {self.token}" yield request @@ -61,11 +61,7 @@ async def test_get_access_token_reflects_current_request_in_stateful_session() - on_list_tools=_handle_list_tools, ) - security = TransportSecuritySettings( - allowed_hosts=[host, f"{host}:*"], - allowed_origins=[f"http://{host}:*"], - ) - session_manager = StreamableHTTPSessionManager(app=server, security_settings=security, stateless=False) + session_manager = StreamableHTTPSessionManager(app=server, stateless=False) asgi_app = Starlette( routes=[Mount("/mcp", app=session_manager.handle_request)], @@ -89,7 +85,7 @@ async def test_get_access_token_reflects_current_request_in_stateful_session() - ) as http_client, ): transport_ctx = streamable_http_client(f"http://{host}/mcp", http_client=http_client) - async with Client(transport_ctx) as client: + async with Client(transport_ctx) as client: # pragma: no branch r1 = await client.call_tool("whoami", {}) assert isinstance(r1.content[0], TextContent) assert r1.content[0].text == "token-A" @@ -98,3 +94,8 @@ async def test_get_access_token_reflects_current_request_in_stateful_session() - r2 = await client.call_tool("whoami", {}) assert isinstance(r2.content[0], TextContent) assert r2.content[0].text == "token-B" + + auth.token = None + r3 = await client.call_tool("whoami", {}) + assert isinstance(r3.content[0], TextContent) + assert r3.content[0].text == ""