diff --git a/packages/google-auth/google/auth/_regional_access_boundary_utils.py b/packages/google-auth/google/auth/_regional_access_boundary_utils.py index c97bf8f484df..6845f0e18e5f 100644 --- a/packages/google-auth/google/auth/_regional_access_boundary_utils.py +++ b/packages/google-auth/google/auth/_regional_access_boundary_utils.py @@ -491,11 +491,36 @@ def start_refresh(self, credentials, request, rab_manager): # A refresh is already in progress. return + # Safely unwrap functools.partial wrappers to isolate the genuine request callable. + actual_request = request + partial_args = () + partial_kwargs = {} + + if isinstance(request, functools.partial): + actual_request = request.func + partial_args = request.args + partial_kwargs = request.keywords + + # Execute the clone protocol on the concrete underlying request adapter. + lookup_request = actual_request + if hasattr(actual_request, "clone"): + lookup_request = actual_request.clone() + + # Re-apply initial partial call arguments to the detached request adapter. + if isinstance(request, functools.partial): + lookup_callable = functools.partial( + lookup_request, *partial_args, **partial_kwargs + ) + else: + lookup_callable = lookup_request + async def _worker(): try: # credentials._lookup_regional_access_boundary should be async in the async creds class regional_access_boundary_info = ( - await credentials._lookup_regional_access_boundary(request) + await credentials._lookup_regional_access_boundary( + lookup_callable + ) ) except Exception as e: if _helpers.is_logging_enabled(_LOGGER): @@ -505,6 +530,29 @@ async def _worker(): exc_info=True, ) regional_access_boundary_info = None + finally: + # Cleanly terminate the detached private socket pool. + if lookup_request is not actual_request and hasattr( + lookup_request, "close" + ): + if inspect.iscoroutinefunction(lookup_request.close): + try: + await lookup_request.close() + except Exception as e: + if _helpers.is_logging_enabled(_LOGGER): + _LOGGER.warning( + "Failed to close cloned async request adapter: %s", + e, + ) + else: + try: + lookup_request.close() + except Exception as e: + if _helpers.is_logging_enabled(_LOGGER): + _LOGGER.warning( + "Failed to close cloned request adapter: %s", + e, + ) rab_manager.process_regional_access_boundary_info( regional_access_boundary_info diff --git a/packages/google-auth/google/auth/aio/transport/__init__.py b/packages/google-auth/google/auth/aio/transport/__init__.py index 166a3be50914..d25ede281c6d 100644 --- a/packages/google-auth/google/auth/aio/transport/__init__.py +++ b/packages/google-auth/google/auth/aio/transport/__init__.py @@ -142,3 +142,7 @@ async def close(self) -> None: Close the underlying session. """ raise NotImplementedError("close must be implemented.") + + def clone(self) -> "Request": + """Create an independent detached copy of this request callable.""" + return self diff --git a/packages/google-auth/google/auth/aio/transport/aiohttp.py b/packages/google-auth/google/auth/aio/transport/aiohttp.py index 642d15927d0f..338c5e2f7a53 100644 --- a/packages/google-auth/google/auth/aio/transport/aiohttp.py +++ b/packages/google-auth/google/auth/aio/transport/aiohttp.py @@ -203,3 +203,28 @@ async def close(self) -> None: if not self._closed and self._session: await self._session.close() self._closed = True + + def clone(self) -> "Request": + """Creates a detached copy of this request adapter. + + Returns: + google.auth.aio.transport.aiohttp.Request: An independent request adapter + running a new aiohttp.ClientSession with identical environment proxy and + trace configurations. + """ + new_session = None + if self._session: + trust_env = getattr(self._session, "_trust_env", True) + trace_configs = getattr(self._session, "_trace_configs", None) + new_session = aiohttp.ClientSession( + auto_decompress=False, + trust_env=trust_env, + trace_configs=list(trace_configs) if trace_configs else None, + ) + else: + new_session = aiohttp.ClientSession( + auto_decompress=False, + trust_env=True, + ) + + return Request(session=new_session) diff --git a/packages/google-auth/google/auth/transport/_aiohttp_requests.py b/packages/google-auth/google/auth/transport/_aiohttp_requests.py index e8321965e0db..33107719ce35 100644 --- a/packages/google-auth/google/auth/transport/_aiohttp_requests.py +++ b/packages/google-auth/google/auth/transport/_aiohttp_requests.py @@ -203,6 +203,37 @@ async def __call__( new_exc = exceptions.TransportError(caught_exc) raise new_exc from caught_exc + def clone(self): + """Create an independent detached copy of this request adapter. + + Returns: + google.auth.transport._aiohttp_requests.Request: An independent request adapter + running an isolated aiohttp.ClientSession with identical environment proxy and + observability configurations. + """ + new_session = None + if self.session: + trust_env = getattr(self.session, "_trust_env", True) + trace_configs = getattr(self.session, "_trace_configs", None) + new_session = aiohttp.ClientSession( + auto_decompress=False, + trust_env=trust_env, + trace_configs=list(trace_configs) if trace_configs else None, + ) + else: + new_session = aiohttp.ClientSession( + auto_decompress=False, + trust_env=True, + ) + + return Request(session=new_session) + + async def close(self): + """Cleanly release the underlying aiohttp ClientSession resources.""" + if not getattr(self, "_closed", False) and self.session: + await self.session.close() + self._closed = True + class AuthorizedSession(aiohttp.ClientSession): """This is an async implementation of the Authorized Session class. We utilize an diff --git a/packages/google-auth/tests/test__regional_access_boundary_utils.py b/packages/google-auth/tests/test__regional_access_boundary_utils.py index c612b60b8ed2..aeaaf68d260f 100644 --- a/packages/google-auth/tests/test__regional_access_boundary_utils.py +++ b/packages/google-auth/tests/test__regional_access_boundary_utils.py @@ -678,6 +678,7 @@ async def test_async_refresh_manager_session_closed_ignored(self): ) request = mock.Mock() + request.clone.return_value = request rab_manager = mock.Mock() manager = ( @@ -694,6 +695,91 @@ async def test_async_refresh_manager_session_closed_ignored(self): credentials._lookup_regional_access_boundary.assert_called_once_with(request) rab_manager.process_regional_access_boundary_info.assert_called_once_with(None) + @pytest.mark.asyncio + async def test_start_refresh_async_clones_request_and_unwraps_partial(self): + import functools + + credentials = mock.AsyncMock() + credentials._lookup_regional_access_boundary.return_value = { + "encodedLocations": "0xA30" + } + + mock_request = mock.Mock() + mock_cloned_request = mock.Mock() + mock_request.clone.return_value = mock_cloned_request + mock_cloned_request.close = mock.AsyncMock() + + # Wrap in a functools.partial to simulate AuthorizedSession.request() timeouts + partial_request = functools.partial(mock_request, timeout=180) + + rab_manager = mock.Mock() + + manager = ( + _regional_access_boundary_utils._AsyncRegionalAccessBoundaryRefreshManager() + ) + manager.start_refresh(credentials, partial_request, rab_manager) + + await manager._worker_task + + # Verify that actual_request.clone() was called + mock_request.clone.assert_called_once() + + # Verify that the lookup ran on a re-wrapped partial of the cloned request + called_arg = credentials._lookup_regional_access_boundary.call_args[0][0] + assert isinstance(called_arg, functools.partial) + assert called_arg.func is mock_cloned_request + assert called_arg.keywords == {"timeout": 180} + + # Verify that the cloned request was closed cleanly in the finally block + mock_cloned_request.close.assert_called_once() + rab_manager.process_regional_access_boundary_info.assert_called_once_with( + {"encodedLocations": "0xA30"} + ) + + @pytest.mark.asyncio + async def test_start_refresh_async_mimics_ephemeral_session_closed_bug(self): + # Specifically mimics the real-world race condition where a fast foreground main call + # pulls the rug out from under the background worker when using an un-cloned session. + import asyncio + + manager = ( + _regional_access_boundary_utils._AsyncRegionalAccessBoundaryRefreshManager() + ) + + class EphemeralRequest: + def __init__(self): + self.closed = False + + async def __call__(self, *args, **kwargs): + await asyncio.sleep(0.05) + if self.closed: + raise RuntimeError("Session is closed") + return "success" + + ephemeral_req = EphemeralRequest() + + credentials = mock.AsyncMock() + + async def mock_lookup(req): + return await req() + + credentials._lookup_regional_access_boundary.side_effect = mock_lookup + + rab_manager = mock.Mock() + + # Start the background refresh worker + manager.start_refresh(credentials, ephemeral_req, rab_manager) + + # Simulate fast foreground primary call (completes in 10ms and closes the session) + await asyncio.sleep(0.01) + ephemeral_req.closed = True + + # Await the background worker task to settle + await manager._worker_task + + # Verify that the background worker hit the "Session is closed" error and failed open cleanly + rab_manager.process_regional_access_boundary_info.assert_called_once_with(None) + def test_get_service_account_rab_endpoint(monkeypatch): from google.auth.transport import _mtls_helper diff --git a/packages/google-auth/tests/transport/aio/test_aiohttp.py b/packages/google-auth/tests/transport/aio/test_aiohttp.py index 553f35775fac..f9845c78df94 100644 --- a/packages/google-auth/tests/transport/aio/test_aiohttp.py +++ b/packages/google-auth/tests/transport/aio/test_aiohttp.py @@ -169,3 +169,21 @@ async def test_request_call_raises_transport_error_for_closed_session( exc.match("session is closed.") aiohttp_request._closed = False + + async def test_request_clone(self): + request = auth_aiohttp.Request() + cloned = request.clone() + assert cloned is not request + assert isinstance(cloned, auth_aiohttp.Request) + assert cloned._session is not request._session + await request.close() + await cloned.close() + + async def test_request_close(self): + request = auth_aiohttp.Request() + assert not getattr(request, "_closed", False) + await request.close() + assert request._closed + # Second call should be idempotent + await request.close() + assert request._closed diff --git a/packages/google-auth/tests_async/test__regional_access_boundary_utils.py b/packages/google-auth/tests_async/test__regional_access_boundary_utils.py index 268ee37261c8..944c2ae408fb 100644 --- a/packages/google-auth/tests_async/test__regional_access_boundary_utils.py +++ b/packages/google-auth/tests_async/test__regional_access_boundary_utils.py @@ -28,6 +28,7 @@ async def test_async_refresh_manager_start_refresh(): } request = mock.Mock() + request.clone.return_value = request rab_manager = mock.Mock() manager = (