Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -491,11 +491,36 @@ def start_refresh(self, credentials, request, rab_manager):
# A refresh is already in progress.
return

# Safely unwrap functools.partial wrappers to isolate the genuine request callable.
actual_request = request
partial_args = ()
partial_kwargs = {}

if isinstance(request, functools.partial):
actual_request = request.func
partial_args = request.args
partial_kwargs = request.keywords

# Execute the clone protocol on the concrete underlying request adapter.
lookup_request = actual_request
if hasattr(actual_request, "clone"):
lookup_request = actual_request.clone()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't call .clone() before the background task is safely spawned. If asyncio.create_task fails (for example, if the event loop is shutting down or isn't running on this thread), the background coroutine is discarded without ever starting. Since the coroutine body never runs, the finally block is skipped and the cloned session is permanently leaked.

If we move actual_request.clone() inside the try block of the _worker coroutine, we can guarantee that if a clone is created, it will definitely hit the finally block and get closed.


# Re-apply initial partial call arguments to the detached request adapter.
if isinstance(request, functools.partial):
lookup_callable = functools.partial(
lookup_request, *partial_args, **partial_kwargs
)
else:
lookup_callable = lookup_request

async def _worker():
try:
# credentials._lookup_regional_access_boundary should be async in the async creds class
regional_access_boundary_info = (
await credentials._lookup_regional_access_boundary(request)
await credentials._lookup_regional_access_boundary(
lookup_callable
)
)
except Exception as e:
if _helpers.is_logging_enabled(_LOGGER):
Expand All @@ -505,6 +530,29 @@ async def _worker():
exc_info=True,
)
regional_access_boundary_info = None
finally:
# Cleanly terminate the detached private socket pool.
if lookup_request is not actual_request and hasattr(
lookup_request, "close"
):
if inspect.iscoroutinefunction(lookup_request.close):
try:
await lookup_request.close()
except Exception as e:
if _helpers.is_logging_enabled(_LOGGER):
_LOGGER.warning(
"Failed to close cloned async request adapter: %s",
e,
)
else:
try:
lookup_request.close()
except Exception as e:
if _helpers.is_logging_enabled(_LOGGER):
_LOGGER.warning(
"Failed to close cloned request adapter: %s",
e,
)

rab_manager.process_regional_access_boundary_info(
regional_access_boundary_info
Expand Down
4 changes: 4 additions & 0 deletions packages/google-auth/google/auth/aio/transport/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,7 @@ async def close(self) -> None:
Close the underlying session.
"""
raise NotImplementedError("close must be implemented.")

def clone(self) -> "Request":
    """Return the request callable to use for detached work.

    This default implementation does not copy anything: it hands back
    ``self``, so the caller keeps sharing the very same object. A caller
    can therefore tell that no separate copy was produced by checking
    ``cloned is original``. Transport-specific adapters may override this
    method to return a distinct instance instead.

    Returns:
        Request: This same instance, unchanged.
    """
    return self
25 changes: 25 additions & 0 deletions packages/google-auth/google/auth/aio/transport/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,28 @@ async def close(self) -> None:
if not self._closed and self._session:
await self._session.close()
self._closed = True

def clone(self) -> "Request":
    """Build a separate request adapter backed by a brand-new session.

    A fresh ``aiohttp.ClientSession`` is always constructed with
    ``auto_decompress=False``. When this adapter currently holds a session,
    the ``_trust_env`` and ``_trace_configs`` values are read from it
    (falling back to ``True`` and ``None`` respectively) and forwarded to
    the new session; otherwise ``trust_env=True`` is used and no trace
    configs are passed.

    Returns:
        google.auth.aio.transport.aiohttp.Request: A new adapter wrapping
            the newly created session.
    """
    current = self._session

    # Defaults used when there is no existing session to take settings from.
    session_kwargs = {"auto_decompress": False, "trust_env": True}

    if current:
        session_kwargs["trust_env"] = getattr(current, "_trust_env", True)
        configs = getattr(current, "_trace_configs", None)
        # Copy the list so the new session does not share the original's
        # container; an empty or missing value is forwarded as None.
        session_kwargs["trace_configs"] = list(configs) if configs else None

    return Request(session=aiohttp.ClientSession(**session_kwargs))
31 changes: 31 additions & 0 deletions packages/google-auth/google/auth/transport/_aiohttp_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,37 @@ async def __call__(
new_exc = exceptions.TransportError(caught_exc)
raise new_exc from caught_exc

def clone(self):
    """Build a separate request adapter backed by a brand-new session.

    A fresh ``aiohttp.ClientSession`` is always constructed with
    ``auto_decompress=False``. When ``self.session`` is set, its
    ``_trust_env`` and ``_trace_configs`` values are read (falling back to
    ``True`` and ``None`` respectively) and forwarded to the new session;
    otherwise ``trust_env=True`` is used and no trace configs are passed.

    Returns:
        google.auth.transport._aiohttp_requests.Request: A new adapter
            wrapping the newly created session.
    """
    existing = self.session

    # Defaults used when there is no existing session to take settings from.
    options = {"auto_decompress": False, "trust_env": True}

    if existing:
        options["trust_env"] = getattr(existing, "_trust_env", True)
        tracing = getattr(existing, "_trace_configs", None)
        # Copy the list so the new session does not share the original's
        # container; an empty or missing value is forwarded as None.
        options["trace_configs"] = list(tracing) if tracing else None

    return Request(session=aiohttp.ClientSession(**options))

async def close(self):
"""Release the aiohttp ClientSession held by this request adapter.

The session's ``close()`` coroutine is awaited only when a session is
present and the ``_closed`` flag has not already been set.
"""
# ``_closed`` is read through getattr with a False default, so an instance
# that never assigned the attribute is treated as still open.
if not getattr(self, "_closed", False) and self.session:
await self.session.close()
# NOTE(review): indentation is not visible here, so it is unclear whether the
# flag is set unconditionally or only after the session was closed - confirm.
self._closed = True


class AuthorizedSession(aiohttp.ClientSession):
"""This is an async implementation of the Authorized Session class. We utilize an
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,7 @@ async def test_async_refresh_manager_session_closed_ignored(self):
)

request = mock.Mock()
request.clone.return_value = request
rab_manager = mock.Mock()

manager = (
Expand All @@ -694,6 +695,91 @@ async def test_async_refresh_manager_session_closed_ignored(self):
credentials._lookup_regional_access_boundary.assert_called_once_with(request)
rab_manager.process_regional_access_boundary_info.assert_called_once_with(None)

@pytest.mark.asyncio
async def test_start_refresh_async_clones_request_and_unwraps_partial(self):
import functools

credentials = mock.AsyncMock()
credentials._lookup_regional_access_boundary.return_value = {
"encodedLocations": "0xA30"
}

mock_request = mock.Mock()
mock_cloned_request = mock.Mock()
mock_request.clone.return_value = mock_cloned_request
mock_cloned_request.close = mock.AsyncMock()

# Wrap in a functools.partial to simulate AuthorizedSession.request() timeouts
partial_request = functools.partial(mock_request, timeout=180)

rab_manager = mock.Mock()

manager = (
_regional_access_boundary_utils._AsyncRegionalAccessBoundaryRefreshManager()
)
manager.start_refresh(credentials, partial_request, rab_manager)

await manager._worker_task

# Verify that actual_request.clone() was called
mock_request.clone.assert_called_once()

# Verify that the lookup ran on a re-wrapped partial of the cloned request
called_arg = credentials._lookup_regional_access_boundary.call_args[0][0]
assert isinstance(called_arg, functools.partial)
assert called_arg.func is mock_cloned_request
assert called_arg.keywords == {"timeout": 180}

# Verify that the cloned request was closed cleanly in the finally block
mock_cloned_request.close.assert_called_once()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since close is an async method, we should use assert_awaited_once() instead of assert_called_once(). This ensures the coroutine was actually awaited and didn't just return an unawaited coroutine object.

rab_manager.process_regional_access_boundary_info.assert_called_once_with(
{"encodedLocations": "0xA30"}
)

@pytest.mark.asyncio
async def test_start_refresh_async_mimics_ephemeral_session_closed_bug(self):
# Specifically mimics the real-world race condition where a fast foreground main call
# pulls the rug out from under the background worker when using an un-cloned session.
import asyncio

manager = (
_regional_access_boundary_utils._AsyncRegionalAccessBoundaryRefreshManager()
)

class EphemeralRequest:
def __init__(self):
self.closed = False

async def __call__(self, *args, **kwargs):
await asyncio.sleep(0.05)
if self.closed:
raise RuntimeError("Session is closed")
return "success"

ephemeral_req = EphemeralRequest()

credentials = mock.AsyncMock()

async def mock_lookup(req):
return await req()

credentials._lookup_regional_access_boundary.side_effect = mock_lookup

rab_manager = mock.Mock()

# Start the background refresh worker
manager.start_refresh(credentials, ephemeral_req, rab_manager)

# Simulate fast foreground primary call (completes in 10ms and closes the session)
await asyncio.sleep(0.01)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using hardcoded sleeps (0.01 and 0.05 seconds) to coordinate tasks makes tests flaky on busy CI servers.

We can make this test 100% deterministic with zero sleeps by using asyncio.Event to coordinate exactly when the background worker has started and when the foreground action completes.

ephemeral_req.closed = True

# Await the background worker task to settle
await manager._worker_task

# Verify that the background worker hit the "Session is closed" error and failed open cleanly
rab_manager.process_regional_access_boundary_info.assert_called_once_with(None)


def test_get_service_account_rab_endpoint(monkeypatch):
from google.auth.transport import _mtls_helper
Expand Down
18 changes: 18 additions & 0 deletions packages/google-auth/tests/transport/aio/test_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,21 @@ async def test_request_call_raises_transport_error_for_closed_session(

exc.match("session is closed.")
aiohttp_request._closed = False

async def test_request_clone(self):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new clone test only covers the case where the request doesn't have an active session (self._session is None). We should add a quick test to make sure that cloning a request with an active session correctly copies the trust_env and trace_configs attributes.

request = auth_aiohttp.Request()
cloned = request.clone()
assert cloned is not request
assert isinstance(cloned, auth_aiohttp.Request)
assert cloned._session is not request._session
await request.close()
await cloned.close()

async def test_request_close(self):
    """Closing a request marks it closed, and closing again is harmless."""
    adapter = auth_aiohttp.Request()

    # A freshly constructed adapter must not report itself as closed.
    assert not getattr(adapter, "_closed", False)

    # Two passes: the first performs the close, the second checks that a
    # repeated call is idempotent and leaves the flag set.
    for _ in range(2):
        await adapter.close()
        assert adapter._closed
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ async def test_async_refresh_manager_start_refresh():
}

request = mock.Mock()
request.clone.return_value = request
rab_manager = mock.Mock()

manager = (
Expand Down
Loading