From 5462caa84ecf862e29c75946539ce66ad0329ab0 Mon Sep 17 00:00:00 2001 From: BSmick6 Date: Tue, 26 May 2026 13:06:57 -0400 Subject: [PATCH] fix(security): warn on empty allowed_hosts and improve 421/403 response bodies MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When TransportSecuritySettings is constructed with DNS rebinding protection enabled but allowed_hosts=[], every request is silently rejected with a bare HTTP 421 — hard to diagnose without reading the SDK source. Add a model_validator that emits a logger.warning at construction time pointing users at allowed_hosts configuration. Also include the received header value and a configuration hint in the 421/403 response bodies. Removes all pragma: no cover markers from transport_security.py by adding a direct unit test file (test_transport_security.py) that exercises all branches without subprocesses. Updates the existing integration tests to use substring matching now that the response bodies carry extra context. Closes #2688 Co-Authored-By: Claude Sonnet 4.6 --- src/mcp/server/transport_security.py | 50 +++-- tests/server/test_sse_security.py | 6 +- tests/server/test_streamable_http_security.py | 6 +- tests/server/test_transport_security.py | 182 ++++++++++++++++++ 4 files changed, 223 insertions(+), 21 deletions(-) create mode 100644 tests/server/test_transport_security.py diff --git a/src/mcp/server/transport_security.py b/src/mcp/server/transport_security.py index 1ed9842c0e..4084cc8f50 100644 --- a/src/mcp/server/transport_security.py +++ b/src/mcp/server/transport_security.py @@ -2,9 +2,10 @@ import logging -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from starlette.requests import Request from starlette.responses import Response +from typing_extensions import Self logger = logging.getLogger(__name__) @@ -31,6 +32,17 @@ class TransportSecuritySettings(BaseModel): Only applies when `enable_dns_rebinding_protection` is `True`. """ + @model_validator(mode="after") + def _warn_if_protection_enabled_with_empty_allowlist(self) -> Self: + if self.enable_dns_rebinding_protection and not self.allowed_hosts: + logger.warning( + "TransportSecuritySettings has DNS rebinding protection enabled but " + "allowed_hosts is empty — all requests will be rejected with HTTP 421. " + "Set allowed_hosts to your server's hostname(s), e.g. " + 'TransportSecuritySettings(allowed_hosts=["your-host.example.com:*"])' + ) + return self + # TODO(Marcelo): This should be a proper ASGI middleware. I'm sad to see this. class TransportSecurityMiddleware: @@ -40,7 +52,7 @@ def __init__(self, settings: TransportSecuritySettings | None = None): # If not specified, disable DNS rebinding protection by default for backwards compatibility self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False) - def _validate_host(self, host: str | None) -> bool: # pragma: no cover + def _validate_host(self, host: str | None) -> bool: """Validate the Host header against allowed values.""" if not host: logger.warning("Missing Host header in request") @@ -62,7 +74,7 @@ def _validate_host(self, host: str | None) -> bool: # pragma: no cover logger.warning(f"Invalid Host header: {host}") return False - def _validate_origin(self, origin: str | None) -> bool: # pragma: no cover + def _validate_origin(self, origin: str | None) -> bool: """Validate the Origin header against allowed values.""" # Origin can be absent for same-origin requests if not origin: @@ -94,7 +106,7 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res Returns None if validation passes, or an error Response if validation fails. """ # Always validate Content-Type for POST requests - if is_post: # pragma: no branch + if is_post: content_type = request.headers.get("content-type") if not self._validate_content_type(content_type): return Response("Invalid Content-Type header", status_code=400) @@ -103,14 +115,22 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res if not self.settings.enable_dns_rebinding_protection: return None - # Validate Host header # pragma: no cover - host = request.headers.get("host") # pragma: no cover - if not self._validate_host(host): # pragma: no cover - return Response("Invalid Host header", status_code=421) # pragma: no cover - - # Validate Origin header # pragma: no cover - origin = request.headers.get("origin") # pragma: no cover - if not self._validate_origin(origin): # pragma: no cover - return Response("Invalid Origin header", status_code=403) # pragma: no cover - - return None # pragma: no cover + # Validate Host header + host = request.headers.get("host") + if not self._validate_host(host): + return Response( + f"Invalid Host header: {host!r}. " + "Configure TransportSecuritySettings(allowed_hosts=[...]) with your server's hostname.", + status_code=421, + ) + + # Validate Origin header + origin = request.headers.get("origin") + if not self._validate_origin(origin): + return Response( + f"Invalid Origin header: {origin!r}. " + "Configure TransportSecuritySettings(allowed_origins=[...]) with your server's origin.", + status_code=403, + ) + + return None diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py index 010eaf6a25..d184c67c76 100644 --- a/tests/server/test_sse_security.py +++ b/tests/server/test_sse_security.py @@ -105,7 +105,7 @@ async def test_sse_security_invalid_host_header(server_port: int): async with httpx.AsyncClient() as client: response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) assert response.status_code == 421 - assert response.text == "Invalid Host header" + assert "Invalid Host header" in response.text finally: process.terminate() @@ -128,7 +128,7 @@ async def test_sse_security_invalid_origin_header(server_port: int): async with httpx.AsyncClient() as client: response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) assert response.status_code == 403 - assert response.text == "Invalid Origin header" + assert "Invalid Origin header" in response.text finally: process.terminate() @@ -215,7 +215,7 @@ async def test_sse_security_custom_allowed_hosts(server_port: int): async with httpx.AsyncClient() as client: response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) assert response.status_code == 421 - assert response.text == "Invalid Host header" + assert "Invalid Host header" in response.text finally: process.terminate() diff --git a/tests/server/test_streamable_http_security.py b/tests/server/test_streamable_http_security.py index 897555353e..abec538abc 100644 --- a/tests/server/test_streamable_http_security.py +++ b/tests/server/test_streamable_http_security.py @@ -126,7 +126,7 @@ async def test_streamable_http_security_invalid_host_header(server_port: int): headers=headers, ) assert response.status_code == 421 - assert response.text == "Invalid Host header" + assert "Invalid Host header" in response.text finally: process.terminate() @@ -154,7 +154,7 @@ async def test_streamable_http_security_invalid_origin_header(server_port: int): headers=headers, ) assert response.status_code == 403 - assert response.text == "Invalid Origin header" + assert "Invalid Origin header" in response.text finally: process.terminate() @@ -269,7 +269,7 @@ async def test_streamable_http_security_get_request(server_port: int): async with httpx.AsyncClient(timeout=5.0) as client: response = await client.get(f"http://127.0.0.1:{server_port}/", headers=headers) assert response.status_code == 421 - assert response.text == "Invalid Host header" + assert "Invalid Host header" in response.text # Test GET request with valid host header headers = { diff --git a/tests/server/test_transport_security.py b/tests/server/test_transport_security.py new file mode 100644 index 0000000000..877efeae24 --- /dev/null +++ b/tests/server/test_transport_security.py @@ -0,0 +1,182 @@ +"""Unit tests for TransportSecuritySettings and TransportSecurityMiddleware.""" + +import logging + +import pytest +from starlette.requests import Request + +from mcp.server.transport_security import TransportSecurityMiddleware, TransportSecuritySettings + + +def make_request(headers: dict[str, str], method: str = "GET") -> Request: + scope = { + "type": "http", + "method": method, + "headers": [(k.lower().encode(), v.encode()) for k, v in headers.items()], + "path": "/", + "query_string": b"", + } + return Request(scope) + + +# --------------------------------------------------------------------------- +# TransportSecuritySettings — construction-time warning +# --------------------------------------------------------------------------- + + +def test_no_warning_when_protection_disabled(caplog: pytest.LogCaptureFixture) -> None: + with caplog.at_level(logging.WARNING, logger="mcp.server.transport_security"): + TransportSecuritySettings(enable_dns_rebinding_protection=False) + assert not caplog.records + + +def test_no_warning_when_allowed_hosts_populated(caplog: pytest.LogCaptureFixture) -> None: + with caplog.at_level(logging.WARNING, logger="mcp.server.transport_security"): + TransportSecuritySettings( + enable_dns_rebinding_protection=True, + allowed_hosts=["example.com"], + ) + assert not caplog.records + + +def test_warning_when_protection_enabled_with_empty_allowed_hosts(caplog: pytest.LogCaptureFixture) -> None: + with caplog.at_level(logging.WARNING, logger="mcp.server.transport_security"): + TransportSecuritySettings(enable_dns_rebinding_protection=True) + assert len(caplog.records) == 1 + assert "allowed_hosts is empty" in caplog.records[0].message + assert "HTTP 421" in caplog.records[0].message + assert "allowed_hosts=" in caplog.records[0].message + + +# --------------------------------------------------------------------------- +# TransportSecurityMiddleware._validate_host +# --------------------------------------------------------------------------- + + +def test_validate_host_missing_host() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_hosts=["example.com"])) + assert m._validate_host(None) is False + + +def test_validate_host_exact_match() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_hosts=["example.com"])) + assert m._validate_host("example.com") is True + + +def test_validate_host_exact_no_match() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_hosts=["example.com"])) + assert m._validate_host("other.com") is False + + +def test_validate_host_port_wildcard_match() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_hosts=["localhost:*"])) + assert m._validate_host("localhost:8080") is True + + +def test_validate_host_port_wildcard_different_base() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_hosts=["localhost:*"])) + assert m._validate_host("other:8080") is False + + +def test_validate_host_port_wildcard_no_port() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_hosts=["localhost:*"])) + assert m._validate_host("localhost") is False + + +# --------------------------------------------------------------------------- +# TransportSecurityMiddleware._validate_origin +# --------------------------------------------------------------------------- + + +def test_validate_origin_absent_is_allowed() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_origins=["http://example.com"])) + assert m._validate_origin(None) is True + + +def test_validate_origin_exact_match() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_origins=["http://example.com"])) + assert m._validate_origin("http://example.com") is True + + +def test_validate_origin_exact_no_match() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_origins=["http://example.com"])) + assert m._validate_origin("http://other.com") is False + + +def test_validate_origin_port_wildcard_match() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_origins=["http://localhost:*"])) + assert m._validate_origin("http://localhost:3000") is True + + +def test_validate_origin_port_wildcard_different_base() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_origins=["http://localhost:*"])) + assert m._validate_origin("http://other:3000") is False + + +# --------------------------------------------------------------------------- +# TransportSecurityMiddleware.validate_request +# --------------------------------------------------------------------------- + + +@pytest.mark.anyio +async def test_validate_request_post_valid_content_type() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False)) + request = make_request({"content-type": "application/json"}, method="POST") + assert await m.validate_request(request, is_post=True) is None + + +@pytest.mark.anyio +async def test_validate_request_post_invalid_content_type() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False)) + request = make_request({"content-type": "text/plain"}, method="POST") + response = await m.validate_request(request, is_post=True) + assert response is not None + assert response.status_code == 400 + + +@pytest.mark.anyio +async def test_validate_request_get_skips_content_type() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False)) + request = make_request({}) + assert await m.validate_request(request, is_post=False) is None + + +@pytest.mark.anyio +async def test_validate_request_protection_disabled_allows_any_host() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False)) + request = make_request({"host": "attacker.example.com"}) + assert await m.validate_request(request) is None + + +@pytest.mark.anyio +async def test_validate_request_valid_host_and_no_origin() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_hosts=["example.com"])) + request = make_request({"host": "example.com"}) + assert await m.validate_request(request) is None + + +@pytest.mark.anyio +async def test_validate_request_invalid_host_returns_421_with_detail() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_hosts=["example.com"])) + request = make_request({"host": "attacker.com"}) + response = await m.validate_request(request) + assert response is not None + assert response.status_code == 421 + assert b"attacker.com" in response.body + assert b"allowed_hosts" in response.body + + +@pytest.mark.anyio +async def test_validate_request_invalid_origin_returns_403_with_detail() -> None: + m = TransportSecurityMiddleware( + TransportSecuritySettings( + allowed_hosts=["example.com"], + allowed_origins=["http://example.com"], + ) + ) + request = make_request({"host": "example.com", "origin": "http://attacker.com"}) + response = await m.validate_request(request) + assert response is not None + assert response.status_code == 403 + assert b"attacker.com" in response.body + assert b"allowed_origins" in response.body