Skip to content

Commit 0019801

Browse files
GWealesasha-gitg
authored andcommitted
fix: add protection for arbitrary module imports
Close #4947 Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 888296476
1 parent 79ed953 commit 0019801

3 files changed

Lines changed: 292 additions & 3 deletions

File tree

src/google/adk/cli/adk_web_server.py

Lines changed: 178 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import json
2121
import logging
2222
import os
23+
import re
2324
import sys
2425
import time
2526
import traceback
@@ -140,6 +141,158 @@ def _parse_cors_origins(
140141
return literal_origins, combined_regex
141142

142143

144+
def _is_origin_allowed(
145+
origin: str,
146+
allowed_literal_origins: list[str],
147+
allowed_origin_regex: Optional[re.Pattern[str]],
148+
) -> bool:
149+
"""Check whether the given origin matches the allowed origins."""
150+
if "*" in allowed_literal_origins:
151+
return True
152+
if origin in allowed_literal_origins:
153+
return True
154+
if allowed_origin_regex is not None:
155+
return allowed_origin_regex.fullmatch(origin) is not None
156+
return False
157+
158+
159+
def _normalize_origin_scheme(scheme: str) -> str:
160+
"""Normalize request schemes to the browser Origin scheme space."""
161+
if scheme == "ws":
162+
return "http"
163+
if scheme == "wss":
164+
return "https"
165+
return scheme
166+
167+
168+
def _strip_optional_quotes(value: str) -> str:
169+
"""Strip a single pair of wrapping quotes from a header value."""
170+
if len(value) >= 2 and value[0] == '"' and value[-1] == '"':
171+
return value[1:-1]
172+
return value
173+
174+
175+
def _get_scope_header(
176+
scope: dict[str, Any], header_name: bytes
177+
) -> Optional[str]:
178+
"""Return the first matching header value from an ASGI scope."""
179+
for candidate_name, candidate_value in scope.get("headers", []):
180+
if candidate_name == header_name:
181+
return candidate_value.decode("latin-1").split(",", 1)[0].strip()
182+
return None
183+
184+
185+
def _get_request_origin(scope: dict[str, Any]) -> Optional[str]:
186+
"""Compute the effective origin for the current HTTP/WebSocket request."""
187+
forwarded = _get_scope_header(scope, b"forwarded")
188+
if forwarded is not None:
189+
proto = None
190+
host = None
191+
for element in forwarded.split(",", 1)[0].split(";"):
192+
if "=" not in element:
193+
continue
194+
name, value = element.split("=", 1)
195+
if name.strip().lower() == "proto":
196+
proto = _strip_optional_quotes(value.strip())
197+
elif name.strip().lower() == "host":
198+
host = _strip_optional_quotes(value.strip())
199+
if proto is not None and host is not None:
200+
return f"{_normalize_origin_scheme(proto)}://{host}"
201+
202+
host = _get_scope_header(scope, b"x-forwarded-host")
203+
if host is None:
204+
host = _get_scope_header(scope, b"host")
205+
if host is None:
206+
return None
207+
208+
proto = _get_scope_header(scope, b"x-forwarded-proto")
209+
if proto is None:
210+
proto = scope.get("scheme", "http")
211+
return f"{_normalize_origin_scheme(proto)}://{host}"
212+
213+
214+
def _is_request_origin_allowed(
215+
origin: str,
216+
scope: dict[str, Any],
217+
allowed_literal_origins: list[str],
218+
allowed_origin_regex: Optional[re.Pattern[str]],
219+
has_configured_allowed_origins: bool,
220+
) -> bool:
221+
"""Validate an Origin header against explicit config or same-origin."""
222+
if has_configured_allowed_origins and _is_origin_allowed(
223+
origin, allowed_literal_origins, allowed_origin_regex
224+
):
225+
return True
226+
227+
request_origin = _get_request_origin(scope)
228+
if request_origin is None:
229+
return False
230+
return origin == request_origin
231+
232+
233+
_SAFE_HTTP_METHODS = frozenset({"GET", "HEAD", "OPTIONS"})
234+
235+
236+
class _OriginCheckMiddleware:
237+
"""ASGI middleware that blocks cross-origin state-changing requests."""
238+
239+
def __init__(
240+
self,
241+
app: Any,
242+
has_configured_allowed_origins: bool,
243+
allowed_origins: list[str],
244+
allowed_origin_regex: Optional[re.Pattern[str]],
245+
) -> None:
246+
self._app = app
247+
self._has_configured_allowed_origins = has_configured_allowed_origins
248+
self._allowed_origins = allowed_origins
249+
self._allowed_origin_regex = allowed_origin_regex
250+
251+
async def __call__(
252+
self,
253+
scope: dict[str, Any],
254+
receive: Any,
255+
send: Any,
256+
) -> None:
257+
if scope["type"] != "http":
258+
await self._app(scope, receive, send)
259+
return
260+
261+
method = scope.get("method", "GET")
262+
if method in _SAFE_HTTP_METHODS:
263+
await self._app(scope, receive, send)
264+
return
265+
266+
origin = _get_scope_header(scope, b"origin")
267+
if origin is None:
268+
await self._app(scope, receive, send)
269+
return
270+
271+
if _is_request_origin_allowed(
272+
origin,
273+
scope,
274+
self._allowed_origins,
275+
self._allowed_origin_regex,
276+
self._has_configured_allowed_origins,
277+
):
278+
await self._app(scope, receive, send)
279+
return
280+
281+
response_body = b"Forbidden: origin not allowed"
282+
await send({
283+
"type": "http.response.start",
284+
"status": 403,
285+
"headers": [
286+
(b"content-type", b"text/plain"),
287+
(b"content-length", str(len(response_body)).encode()),
288+
],
289+
})
290+
await send({
291+
"type": "http.response.body",
292+
"body": response_body,
293+
})
294+
295+
143296
class ApiServerSpanExporter(export_lib.SpanExporter):
144297

145298
def __init__(self, trace_dict):
@@ -759,8 +912,12 @@ async def internal_lifespan(app: FastAPI):
759912
# Run the FastAPI server.
760913
app = FastAPI(lifespan=internal_lifespan)
761914

915+
has_configured_allowed_origins = bool(allow_origins)
762916
if allow_origins:
763917
literal_origins, combined_regex = _parse_cors_origins(allow_origins)
918+
compiled_origin_regex = (
919+
re.compile(combined_regex) if combined_regex is not None else None
920+
)
764921
app.add_middleware(
765922
CORSMiddleware,
766923
allow_origins=literal_origins,
@@ -769,6 +926,16 @@ async def internal_lifespan(app: FastAPI):
769926
allow_methods=["*"],
770927
allow_headers=["*"],
771928
)
929+
else:
930+
literal_origins = []
931+
compiled_origin_regex = None
932+
933+
app.add_middleware(
934+
_OriginCheckMiddleware,
935+
has_configured_allowed_origins=has_configured_allowed_origins,
936+
allowed_origins=literal_origins,
937+
allowed_origin_regex=compiled_origin_regex,
938+
)
772939

773940
@app.get("/health")
774941
async def health() -> dict[str, str]:
@@ -1755,14 +1922,23 @@ async def run_agent_live(
17551922
enable_affective_dialog: bool | None = Query(default=None),
17561923
enable_session_resumption: bool | None = Query(default=None),
17571924
) -> None:
1925+
ws_origin = websocket.headers.get("origin")
1926+
if ws_origin is not None and not _is_request_origin_allowed(
1927+
ws_origin,
1928+
websocket.scope,
1929+
literal_origins,
1930+
compiled_origin_regex,
1931+
has_configured_allowed_origins,
1932+
):
1933+
await websocket.close(code=1008, reason="Origin not allowed")
1934+
return
1935+
17581936
await websocket.accept()
17591937

17601938
session = await self.session_service.get_session(
17611939
app_name=app_name, user_id=user_id, session_id=session_id
17621940
)
17631941
if not session:
1764-
# Accept first so that the client is aware of connection establishment,
1765-
# then close with a specific code.
17661942
await websocket.close(code=1002, reason="Session not found")
17671943
return
17681944

tests/unittests/cli/test_adk_web_server_run_live.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from google.adk.events.event import Event
2323
from google.adk.sessions.in_memory_session_service import InMemorySessionService
2424
import pytest
25+
from starlette.websockets import WebSocketDisconnect
2526

2627

2728
class _DummyAgent(BaseAgent):
@@ -203,3 +204,75 @@ async def _get_runner_async(_self, _app_name: str):
203204
run_config.session_resumption.transparent
204205
is expected_session_resumption_transparent
205206
)
207+
208+
209+
_WS_BASE_URL = (
210+
"/run_live"
211+
"?app_name=test_app"
212+
"&user_id=user"
213+
"&session_id=session"
214+
"&modalities=AUDIO"
215+
)
216+
217+
218+
def _build_ws_client():
219+
"""Build a TestClient wired to a capturing runner."""
220+
session_service = InMemorySessionService()
221+
asyncio.run(
222+
session_service.create_session(
223+
app_name="test_app",
224+
user_id="user",
225+
session_id="session",
226+
state={},
227+
)
228+
)
229+
230+
runner = _CapturingRunner()
231+
adk_web_server = AdkWebServer(
232+
agent_loader=_DummyAgentLoader(),
233+
session_service=session_service,
234+
memory_service=types.SimpleNamespace(),
235+
artifact_service=types.SimpleNamespace(),
236+
credential_service=types.SimpleNamespace(),
237+
eval_sets_manager=types.SimpleNamespace(),
238+
eval_set_results_manager=types.SimpleNamespace(),
239+
agents_dir=".",
240+
)
241+
242+
async def _get_runner_async(_self, _app_name: str):
243+
return runner
244+
245+
adk_web_server.get_runner_async = _get_runner_async.__get__(adk_web_server) # pytype: disable=attribute-error
246+
247+
fast_api_app = adk_web_server.get_fast_api_app(
248+
setup_observer=lambda _observer, _server: None,
249+
tear_down_observer=lambda _observer, _server: None,
250+
)
251+
return TestClient(fast_api_app)
252+
253+
254+
def test_run_live_rejects_disallowed_origin():
255+
client = _build_ws_client()
256+
with pytest.raises(WebSocketDisconnect) as exc_info:
257+
with client.websocket_connect(
258+
_WS_BASE_URL,
259+
headers={"origin": "https://evil.com"},
260+
) as ws:
261+
ws.receive_text()
262+
assert exc_info.value.code == 1008
263+
264+
265+
def test_run_live_allows_matching_origin():
266+
client = _build_ws_client()
267+
with client.websocket_connect(
268+
_WS_BASE_URL,
269+
headers={"origin": "http://testserver"},
270+
) as ws:
271+
_ = ws.receive_text()
272+
273+
274+
def test_run_live_allows_no_origin_header():
275+
"""Non-browser clients (curl, wscat, SDKs) send no Origin header."""
276+
client = _build_ws_client()
277+
with client.websocket_connect(_WS_BASE_URL) as ws:
278+
_ = ws.receive_text()

tests/unittests/cli/test_fast_api.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,7 @@ def builder_test_client(
593593
session_service_uri="",
594594
artifact_service_uri="",
595595
memory_service_uri="",
596-
allow_origins=["*"],
596+
allow_origins=None,
597597
a2a=False,
598598
host="127.0.0.1",
599599
port=8000,
@@ -1595,6 +1595,46 @@ def test_builder_final_save_preserves_tools_and_cleans_tmp(
15951595
assert not tmp_dir.exists() or not any(tmp_dir.iterdir())
15961596

15971597

1598+
def test_builder_save_rejects_cross_origin_post(builder_test_client, tmp_path):
1599+
response = builder_test_client.post(
1600+
"/builder/save?tmp=true",
1601+
headers={"origin": "https://evil.com"},
1602+
files=[(
1603+
"files",
1604+
("app/root_agent.yaml", b"name: app\n", "application/x-yaml"),
1605+
)],
1606+
)
1607+
1608+
assert response.status_code == 403
1609+
assert response.text == "Forbidden: origin not allowed"
1610+
assert not (tmp_path / "app" / "tmp" / "app").exists()
1611+
1612+
1613+
def test_builder_save_allows_same_origin_post(builder_test_client, tmp_path):
1614+
response = builder_test_client.post(
1615+
"/builder/save?tmp=true",
1616+
headers={"origin": "http://testserver"},
1617+
files=[(
1618+
"files",
1619+
("app/root_agent.yaml", b"name: app\n", "application/x-yaml"),
1620+
)],
1621+
)
1622+
1623+
assert response.status_code == 200
1624+
assert response.json() is True
1625+
assert (tmp_path / "app" / "tmp" / "app" / "root_agent.yaml").is_file()
1626+
1627+
1628+
def test_builder_get_allows_cross_origin_get(builder_test_client):
1629+
response = builder_test_client.get(
1630+
"/builder/app/missing?tmp=true",
1631+
headers={"origin": "https://evil.com"},
1632+
)
1633+
1634+
assert response.status_code == 200
1635+
assert response.text == ""
1636+
1637+
15981638
def test_builder_cancel_deletes_tmp_idempotent(builder_test_client, tmp_path):
15991639
tmp_agent_root = tmp_path / "app" / "tmp" / "app"
16001640
tmp_agent_root.mkdir(parents=True, exist_ok=True)

0 commit comments

Comments
 (0)