Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions examples/servers/simple-auth/mcp_simple_auth/auth_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ async def introspect_handler(request: Request) -> Response:
"iat": int(time.time()),
"token_type": "Bearer",
"aud": access_token.resource, # RFC 8707 audience claim
"sub": access_token.subject, # RFC 7662 subject
"iss": str(server_settings.server_url),
}
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ async def handle_simple_callback(self, username: str, password: str, state: str)
scopes=[self.settings.mcp_scope],
code_challenge=code_challenge,
resource=resource, # RFC 8707
subject=username,
)
self.auth_codes[new_code] = auth_code

Expand Down Expand Up @@ -219,6 +220,7 @@ async def exchange_authorization_code(
scopes=authorization_code.scopes,
expires_at=int(time.time()) + 3600,
resource=authorization_code.resource, # RFC 8707
subject=authorization_code.subject,
)

# Store user data mapping for this token
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ async def verify_token(self, token: str) -> AccessToken | None:
scopes=data.get("scope", "").split() if data.get("scope") else [],
expires_at=data.get("exp"),
resource=data.get("aud"), # Include resource in token
subject=data.get("sub"), # RFC 7662 subject (resource owner)
claims=data,
)
except Exception as e:
logger.warning(f"Token introspection failed: {e}")
Expand Down
6 changes: 5 additions & 1 deletion src/mcp/server/auth/provider.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Generic, Literal, Protocol, TypeVar
from typing import Any, Generic, Literal, Protocol, TypeVar
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse

from pydantic import AnyUrl, BaseModel
Expand All @@ -25,13 +25,15 @@ class AuthorizationCode(BaseModel):
redirect_uri: AnyUrl
redirect_uri_provided_explicitly: bool
resource: str | None = None # RFC 8707 resource indicator
subject: str | None = None # resource owner; propagate to the issued AccessToken


class RefreshToken(BaseModel):
token: str
client_id: str
scopes: list[str]
expires_at: int | None = None
subject: str | None = None # resource owner; propagate to refreshed AccessTokens


class AccessToken(BaseModel):
Expand All @@ -40,6 +42,8 @@ class AccessToken(BaseModel):
scopes: list[str]
expires_at: int | None = None
resource: str | None = None # RFC 8707 resource indicator
subject: str | None = None # RFC 7662/9068 `sub`: resource owner; unique only per issuer
claims: dict[str, Any] | None = None # additional claims (e.g. `iss`, `act`)


RegistrationErrorCode = Literal[
Expand Down
7 changes: 6 additions & 1 deletion src/mcp/server/mcpserver/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,14 @@ async def log(
related_request_id=self.request_id,
)

# TODO(maxisbey): see if this is needed otherwise remove
@property
def client_id(self) -> str | None:
"""Get the client ID if available."""
"""Get the client ID if available.

Note: this reads from the MCP request's `_meta` params, not the OAuth
bearer token. For that, use `get_access_token().client_id`.
"""
return self.request_context.meta.get("client_id") if self.request_context.meta else None # pragma: no cover

@property
Expand Down
10 changes: 10 additions & 0 deletions tests/server/mcpserver/auth/test_auth_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ async def authorize(self, client: OAuthClientInformationFull, params: Authorizat
redirect_uri_provided_explicitly=params.redirect_uri_provided_explicitly,
expires_at=time.time() + 300,
scopes=params.scopes or ["read", "write"],
subject="test-user",
)
self.auth_codes[code.code] = code

Expand All @@ -79,6 +80,7 @@ async def exchange_authorization_code(
client_id=client.client_id,
scopes=authorization_code.scopes,
expires_at=int(time.time()) + 3600,
subject=authorization_code.subject,
)

self.refresh_tokens[refresh_token] = access_token
Expand Down Expand Up @@ -108,6 +110,7 @@ async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_t
client_id=token_info.client_id,
scopes=token_info.scopes,
expires_at=token_info.expires_at,
subject=token_info.subject,
)

return refresh_obj
Expand Down Expand Up @@ -141,6 +144,7 @@ async def exchange_refresh_token(
client_id=client.client_id,
scopes=scopes or token_info.scopes,
expires_at=int(time.time()) + 3600,
subject=refresh_token.subject,
)

self.refresh_tokens[new_refresh_token] = new_access_token
Expand Down Expand Up @@ -169,6 +173,7 @@ async def load_access_token(self, token: str) -> AccessToken | None:
client_id=token_info.client_id,
scopes=token_info.scopes,
expires_at=token_info.expires_at,
subject=token_info.subject,
)

async def revoke_token(self, token: AccessToken | RefreshToken) -> None:
Expand Down Expand Up @@ -832,6 +837,7 @@ async def test_authorization_get(
assert auth_info.client_id == client_info["client_id"]
assert "read" in auth_info.scopes
assert "write" in auth_info.scopes
assert auth_info.subject == "test-user"

# 6. Refresh the token
response = await test_client.post(
Expand All @@ -852,6 +858,10 @@ async def test_authorization_get(
assert new_token_response["access_token"] != access_token
assert new_token_response["refresh_token"] != refresh_token

refreshed_auth_info = await mock_oauth_provider.load_access_token(new_token_response["access_token"])
assert refreshed_auth_info
assert refreshed_auth_info.subject == "test-user"

# 7. Revoke the token
response = await test_client.post(
"/revoke",
Expand Down
Loading