Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions src/mcp/server/auth/middleware/auth_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import contextvars
from contextvars import Token
from typing import Any

from starlette.requests import Request
from starlette.types import ASGIApp, Receive, Scope, Send

from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
Expand All @@ -20,6 +23,30 @@ def get_access_token() -> AccessToken | None:
return auth_user.access_token if auth_user else None


def push_auth_context_from_request(request: Request | None) -> Token[AuthenticatedUser | None] | None:
"""Set auth context for the current task from an incoming request.

This is primarily used by server transports where request handlers may run
in background tasks that are not part of the original ASGI request task.
"""
if request is None:
return None
# Avoid Request.user, which asserts AuthenticationMiddleware is installed.
user: Any | None = request.scope.get("user")
if user is None:
try:
user = getattr(request, "user", None)
except AssertionError:
user = None
return auth_context_var.set(user if isinstance(user, AuthenticatedUser) else None)


def pop_auth_context(token: Token[AuthenticatedUser | None] | None) -> None:
if token is None:
return
auth_context_var.reset(token)


class AuthContextMiddleware:
"""Middleware that extracts the authenticated user from the request
and sets it in a contextvar for easy access throughout the request lifecycle.
Expand Down
12 changes: 10 additions & 2 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,11 @@ async def main():
from typing_extensions import TypeVar

from mcp import types
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware
from mcp.server.auth.middleware.auth_context import (
AuthContextMiddleware,
pop_auth_context,
push_auth_context_from_request,
)
from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier
from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes
Expand Down Expand Up @@ -497,7 +501,11 @@ async def _handle_request(
close_sse_stream=close_sse_stream_cb,
close_standalone_sse_stream=close_standalone_sse_stream_cb,
)
response = await handler(ctx, req.params)
auth_token = push_auth_context_from_request(request_data)
try:
response = await handler(ctx, req.params)
finally:
pop_auth_context(auth_token)
except MCPError as err:
response = err.error
except anyio.get_cancelled_exc_class():
Expand Down
101 changes: 101 additions & 0 deletions tests/server/auth/test_get_access_token_streamable_http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import time

import httpx
import pytest
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.routing import Mount

from mcp import Client
from mcp.client.streamable_http import streamable_http_client
from mcp.server import Server, ServerRequestContext
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware, get_access_token
from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend
from mcp.server.auth.provider import AccessToken
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from mcp.types import (
CallToolRequestParams,
CallToolResult,
ListToolsResult,
PaginatedRequestParams,
TextContent,
Tool,
)


class _EchoTokenVerifier:
"""Accepts any bearer token and echoes it back as the verified AccessToken."""

async def verify_token(self, token: str) -> AccessToken | None:
return AccessToken(token=token, client_id=token, scopes=[], expires_at=int(time.time()) + 3600)


async def _handle_whoami(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
access = get_access_token()
text = access.token if access else "<none>"
return CallToolResult(content=[TextContent(type="text", text=text)])


async def _handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult:
return ListToolsResult(tools=[Tool(name="whoami", input_schema={"type": "object", "properties": {}})])


class _MutableBearerAuth(httpx.Auth):
def __init__(self, token: str | None) -> None:
self.token = token

def auth_flow(self, request: httpx.Request):
if self.token is not None:
request.headers["Authorization"] = f"Bearer {self.token}"
yield request


@pytest.mark.anyio
async def test_get_access_token_reflects_current_request_in_stateful_session() -> None:
host = "testserver"

server = Server(
"auth-test-server",
on_call_tool=_handle_whoami,
on_list_tools=_handle_list_tools,
)

session_manager = StreamableHTTPSessionManager(app=server, stateless=False)

asgi_app = Starlette(
routes=[Mount("/mcp", app=session_manager.handle_request)],
middleware=[
Middleware(AuthenticationMiddleware, backend=BearerAuthBackend(_EchoTokenVerifier())),
Middleware(AuthContextMiddleware),
],
lifespan=lambda app: session_manager.run(),
)

auth = _MutableBearerAuth("token-A")
async with asgi_app.router.lifespan_context(asgi_app):
async with (
httpx.ASGITransport(asgi_app) as transport,
httpx.AsyncClient(
transport=transport,
base_url=f"http://{host}",
auth=auth,
timeout=httpx.Timeout(30, read=30),
follow_redirects=True,
) as http_client,
):
transport_ctx = streamable_http_client(f"http://{host}/mcp", http_client=http_client)
async with Client(transport_ctx) as client: # pragma: no branch
r1 = await client.call_tool("whoami", {})
assert isinstance(r1.content[0], TextContent)
assert r1.content[0].text == "token-A"

auth.token = "token-B"
r2 = await client.call_tool("whoami", {})
assert isinstance(r2.content[0], TextContent)
assert r2.content[0].text == "token-B"

auth.token = None
r3 = await client.call_tool("whoami", {})
assert isinstance(r3.content[0], TextContent)
assert r3.content[0].text == "<none>"
Loading