2020import json
2121import logging
2222import os
23+ import re
2324import sys
2425import time
2526import traceback
@@ -140,6 +141,158 @@ def _parse_cors_origins(
140141 return literal_origins , combined_regex
141142
142143
144+ def _is_origin_allowed (
145+ origin : str ,
146+ allowed_literal_origins : list [str ],
147+ allowed_origin_regex : Optional [re .Pattern [str ]],
148+ ) -> bool :
149+ """Check whether the given origin matches the allowed origins."""
150+ if "*" in allowed_literal_origins :
151+ return True
152+ if origin in allowed_literal_origins :
153+ return True
154+ if allowed_origin_regex is not None :
155+ return allowed_origin_regex .fullmatch (origin ) is not None
156+ return False
157+
158+
159+ def _normalize_origin_scheme (scheme : str ) -> str :
160+ """Normalize request schemes to the browser Origin scheme space."""
161+ if scheme == "ws" :
162+ return "http"
163+ if scheme == "wss" :
164+ return "https"
165+ return scheme
166+
167+
168+ def _strip_optional_quotes (value : str ) -> str :
169+ """Strip a single pair of wrapping quotes from a header value."""
170+ if len (value ) >= 2 and value [0 ] == '"' and value [- 1 ] == '"' :
171+ return value [1 :- 1 ]
172+ return value
173+
174+
175+ def _get_scope_header (
176+ scope : dict [str , Any ], header_name : bytes
177+ ) -> Optional [str ]:
178+ """Return the first matching header value from an ASGI scope."""
179+ for candidate_name , candidate_value in scope .get ("headers" , []):
180+ if candidate_name == header_name :
181+ return candidate_value .decode ("latin-1" ).split ("," , 1 )[0 ].strip ()
182+ return None
183+
184+
185+ def _get_request_origin (scope : dict [str , Any ]) -> Optional [str ]:
186+ """Compute the effective origin for the current HTTP/WebSocket request."""
187+ forwarded = _get_scope_header (scope , b"forwarded" )
188+ if forwarded is not None :
189+ proto = None
190+ host = None
191+ for element in forwarded .split ("," , 1 )[0 ].split (";" ):
192+ if "=" not in element :
193+ continue
194+ name , value = element .split ("=" , 1 )
195+ if name .strip ().lower () == "proto" :
196+ proto = _strip_optional_quotes (value .strip ())
197+ elif name .strip ().lower () == "host" :
198+ host = _strip_optional_quotes (value .strip ())
199+ if proto is not None and host is not None :
200+ return f"{ _normalize_origin_scheme (proto )} ://{ host } "
201+
202+ host = _get_scope_header (scope , b"x-forwarded-host" )
203+ if host is None :
204+ host = _get_scope_header (scope , b"host" )
205+ if host is None :
206+ return None
207+
208+ proto = _get_scope_header (scope , b"x-forwarded-proto" )
209+ if proto is None :
210+ proto = scope .get ("scheme" , "http" )
211+ return f"{ _normalize_origin_scheme (proto )} ://{ host } "
212+
213+
214+ def _is_request_origin_allowed (
215+ origin : str ,
216+ scope : dict [str , Any ],
217+ allowed_literal_origins : list [str ],
218+ allowed_origin_regex : Optional [re .Pattern [str ]],
219+ has_configured_allowed_origins : bool ,
220+ ) -> bool :
221+ """Validate an Origin header against explicit config or same-origin."""
222+ if has_configured_allowed_origins and _is_origin_allowed (
223+ origin , allowed_literal_origins , allowed_origin_regex
224+ ):
225+ return True
226+
227+ request_origin = _get_request_origin (scope )
228+ if request_origin is None :
229+ return False
230+ return origin == request_origin
231+
232+
233+ _SAFE_HTTP_METHODS = frozenset ({"GET" , "HEAD" , "OPTIONS" })
234+
235+
236+ class _OriginCheckMiddleware :
237+ """ASGI middleware that blocks cross-origin state-changing requests."""
238+
239+ def __init__ (
240+ self ,
241+ app : Any ,
242+ has_configured_allowed_origins : bool ,
243+ allowed_origins : list [str ],
244+ allowed_origin_regex : Optional [re .Pattern [str ]],
245+ ) -> None :
246+ self ._app = app
247+ self ._has_configured_allowed_origins = has_configured_allowed_origins
248+ self ._allowed_origins = allowed_origins
249+ self ._allowed_origin_regex = allowed_origin_regex
250+
251+ async def __call__ (
252+ self ,
253+ scope : dict [str , Any ],
254+ receive : Any ,
255+ send : Any ,
256+ ) -> None :
257+ if scope ["type" ] != "http" :
258+ await self ._app (scope , receive , send )
259+ return
260+
261+ method = scope .get ("method" , "GET" )
262+ if method in _SAFE_HTTP_METHODS :
263+ await self ._app (scope , receive , send )
264+ return
265+
266+ origin = _get_scope_header (scope , b"origin" )
267+ if origin is None :
268+ await self ._app (scope , receive , send )
269+ return
270+
271+ if _is_request_origin_allowed (
272+ origin ,
273+ scope ,
274+ self ._allowed_origins ,
275+ self ._allowed_origin_regex ,
276+ self ._has_configured_allowed_origins ,
277+ ):
278+ await self ._app (scope , receive , send )
279+ return
280+
281+ response_body = b"Forbidden: origin not allowed"
282+ await send ({
283+ "type" : "http.response.start" ,
284+ "status" : 403 ,
285+ "headers" : [
286+ (b"content-type" , b"text/plain" ),
287+ (b"content-length" , str (len (response_body )).encode ()),
288+ ],
289+ })
290+ await send ({
291+ "type" : "http.response.body" ,
292+ "body" : response_body ,
293+ })
294+
295+
143296class ApiServerSpanExporter (export_lib .SpanExporter ):
144297
145298 def __init__ (self , trace_dict ):
@@ -759,8 +912,12 @@ async def internal_lifespan(app: FastAPI):
759912 # Run the FastAPI server.
760913 app = FastAPI (lifespan = internal_lifespan )
761914
915+ has_configured_allowed_origins = bool (allow_origins )
762916 if allow_origins :
763917 literal_origins , combined_regex = _parse_cors_origins (allow_origins )
918+ compiled_origin_regex = (
919+ re .compile (combined_regex ) if combined_regex is not None else None
920+ )
764921 app .add_middleware (
765922 CORSMiddleware ,
766923 allow_origins = literal_origins ,
@@ -769,6 +926,16 @@ async def internal_lifespan(app: FastAPI):
769926 allow_methods = ["*" ],
770927 allow_headers = ["*" ],
771928 )
929+ else :
930+ literal_origins = []
931+ compiled_origin_regex = None
932+
933+ app .add_middleware (
934+ _OriginCheckMiddleware ,
935+ has_configured_allowed_origins = has_configured_allowed_origins ,
936+ allowed_origins = literal_origins ,
937+ allowed_origin_regex = compiled_origin_regex ,
938+ )
772939
773940 @app .get ("/health" )
774941 async def health () -> dict [str , str ]:
@@ -1755,14 +1922,23 @@ async def run_agent_live(
17551922 enable_affective_dialog : bool | None = Query (default = None ),
17561923 enable_session_resumption : bool | None = Query (default = None ),
17571924 ) -> None :
1925+ ws_origin = websocket .headers .get ("origin" )
1926+ if ws_origin is not None and not _is_request_origin_allowed (
1927+ ws_origin ,
1928+ websocket .scope ,
1929+ literal_origins ,
1930+ compiled_origin_regex ,
1931+ has_configured_allowed_origins ,
1932+ ):
1933+ await websocket .close (code = 1008 , reason = "Origin not allowed" )
1934+ return
1935+
17581936 await websocket .accept ()
17591937
17601938 session = await self .session_service .get_session (
17611939 app_name = app_name , user_id = user_id , session_id = session_id
17621940 )
17631941 if not session :
1764- # Accept first so that the client is aware of connection establishment,
1765- # then close with a specific code.
17661942 await websocket .close (code = 1002 , reason = "Session not found" )
17671943 return
17681944
0 commit comments