diff --git a/packages/google-auth/google/auth/_regional_access_boundary_utils.py b/packages/google-auth/google/auth/_regional_access_boundary_utils.py index 81011911df3d..2dab4addbbe9 100644 --- a/packages/google-auth/google/auth/_regional_access_boundary_utils.py +++ b/packages/google-auth/google/auth/_regional_access_boundary_utils.py @@ -14,6 +14,7 @@ """Utilities for Regional Access Boundary management.""" +import asyncio import copy import datetime import functools @@ -384,3 +385,61 @@ def start_refresh(self, credentials, request, rab_manager): credentials, copied_request, rab_manager ) self._worker.start() + + +class _AsyncRegionalAccessBoundaryRefreshManager(object): + """Manages a task for background refreshing of the Regional Access Boundary in async flows.""" + + def __init__(self): + self._lock = threading.Lock() + self._worker_task = None + + def __getstate__(self): + """Pickle helper that serializes the _lock and _worker_task attributes.""" + state = self.__dict__.copy() + state["_lock"] = None + state["_worker_task"] = None + return state + + def __setstate__(self, state): + """Pickle helper that deserializes the _lock and _worker_task attributes.""" + self.__dict__.update(state) + self._lock = threading.Lock() + self._worker_task = None + + def start_refresh(self, credentials, request, rab_manager): + """ + Starts a background task to refresh the Regional Access Boundary if one is not already running. + + Args: + credentials (CredentialsWithRegionalAccessBoundary): The credentials + to refresh. + request (google.auth.aio.transport.Request): The object used to make + HTTP requests. + rab_manager (_RegionalAccessBoundaryManager): The manager container to update. + """ + with self._lock: + if self._worker_task and not self._worker_task.done(): + # A refresh is already in progress. + return + + async def _worker(): + try: + # credentials._lookup_regional_access_boundary should be async in the async creds class + regional_access_boundary_info = ( + await credentials._lookup_regional_access_boundary(request) + ) + except Exception as e: + if _helpers.is_logging_enabled(_LOGGER): + _LOGGER.warning( + "Asynchronous Regional Access Boundary lookup raised an exception: %s", + e, + exc_info=True, + ) + regional_access_boundary_info = None + + rab_manager.process_regional_access_boundary_info( + regional_access_boundary_info + ) + + self._worker_task = asyncio.create_task(_worker()) diff --git a/packages/google-auth/google/auth/credentials.py b/packages/google-auth/google/auth/credentials.py index 4a686cb01907..a3d845e9196b 100644 --- a/packages/google-auth/google/auth/credentials.py +++ b/packages/google-auth/google/auth/credentials.py @@ -239,9 +239,25 @@ def before_request(self, request, method, url, headers): else: self._blocking_refresh(request) + self._after_refresh(request, method, url, headers) + metrics.add_metric_header(headers, self._metric_header_for_usage()) self.apply(headers) + def _after_refresh(self, request, method, url, headers): + """Hook for subclasses to perform actions after refresh but before + applying credentials to headers. + + Args: + request (google.auth.transport.Request): The object used to make + HTTP requests. + method (str): The request's HTTP method or the RPC method being + invoked. + url (str): The request's URI or the RPC service's URI. + headers (Mapping): The request's headers. + """ + pass + def with_non_blocking_refresh(self): self._use_non_blocking_refresh = True @@ -309,6 +325,22 @@ def __init__(self): _regional_access_boundary_utils._RegionalAccessBoundaryManager() ) + def __setstate__(self, state): + """Pickle helper that restores state, safely reconstructing RAB fields if missing.""" + self.__dict__.update(state) + if "_rab_manager" not in self.__dict__: + from google.auth import _regional_access_boundary_utils + + self._rab_manager = ( + _regional_access_boundary_utils._RegionalAccessBoundaryManager() + ) + if "_use_non_blocking_refresh" not in self.__dict__: + self._use_non_blocking_refresh = False + if "_refresh_worker" not in self.__dict__: + from google.auth._refresh_worker import RefreshThreadManager + + self._refresh_worker = RefreshThreadManager() + @property def regional_access_boundary(self): """Optional[str]: The encoded Regional Access Boundary locations.""" @@ -369,6 +401,8 @@ def _copy_regional_access_boundary_manager(self, target): # but share the immutable data reference to avoid unnecessary initial lookups. new_manager = _regional_access_boundary_utils._RegionalAccessBoundaryManager() new_manager._data = self._rab_manager._data + # Preserve the type of refresh manager (sync or async) + new_manager.refresh_manager = self._rab_manager.refresh_manager.__class__() target._rab_manager = new_manager def _set_regional_access_boundary(self, seed): @@ -459,20 +493,10 @@ def apply(self, headers, token=None): super().apply(headers, token) self._rab_manager.apply_headers(headers) - def before_request(self, request, method, url, headers): - """Refreshes the access token and triggers the Regional Access Boundary - lookup if necessary. - """ - if self._use_non_blocking_refresh: - self._non_blocking_refresh(request) - else: - self._blocking_refresh(request) - + def _after_refresh(self, request, method, url, headers): + """Triggers the Regional Access Boundary lookup if necessary.""" self._maybe_start_regional_access_boundary_refresh(request, url) - metrics.add_metric_header(headers, self._metric_header_for_usage()) - self.apply(headers) - def refresh(self, request): """Refreshes the access token. diff --git a/packages/google-auth/google/auth/jwt.py b/packages/google-auth/google/auth/jwt.py index b6fe60736fa1..38a84bfd97aa 100644 --- a/packages/google-auth/google/auth/jwt.py +++ b/packages/google-auth/google/auth/jwt.py @@ -55,6 +55,7 @@ from google.auth import _service_account_info from google.auth import crypt from google.auth import exceptions +from google.auth import iam import google.auth.credentials try: @@ -317,7 +318,9 @@ def decode(token, certs=None, verify=True, audience=None, clock_skew_in_seconds= class Credentials( - google.auth.credentials.Signing, google.auth.credentials.CredentialsWithQuotaProject + google.auth.credentials.Signing, + google.auth.credentials.CredentialsWithQuotaProject, + google.auth.credentials.CredentialsWithRegionalAccessBoundary, ): """Credentials that use a JWT as the bearer token. @@ -490,7 +493,15 @@ def from_signing_credentials(cls, credentials, audience, **kwargs): """ kwargs.setdefault("issuer", credentials.signer_email) kwargs.setdefault("subject", credentials.signer_email) - return cls(credentials.signer, audience=audience, **kwargs) + jwt_creds = cls(credentials.signer, audience=audience, **kwargs) + + if isinstance( + credentials, + google.auth.credentials.CredentialsWithRegionalAccessBoundary, + ): + credentials._copy_regional_access_boundary_manager(jwt_creds) + + return jwt_creds def with_claims( self, issuer=None, subject=None, audience=None, additional_claims=None @@ -514,7 +525,7 @@ def with_claims( new_additional_claims = copy.deepcopy(self._additional_claims) new_additional_claims.update(additional_claims or {}) - return self.__class__( + cred = self.__class__( self._signer, issuer=issuer if issuer is not None else self._issuer, subject=subject if subject is not None else self._subject, @@ -522,10 +533,12 @@ def with_claims( additional_claims=new_additional_claims, quota_project_id=self._quota_project_id, ) + self._copy_regional_access_boundary_manager(cred) + return cred @_helpers.copy_docstring(google.auth.credentials.CredentialsWithQuotaProject) def with_quota_project(self, quota_project_id): - return self.__class__( + cred = self.__class__( self._signer, issuer=self._issuer, subject=self._subject, @@ -533,6 +546,8 @@ def with_quota_project(self, quota_project_id): additional_claims=self._additional_claims, quota_project_id=quota_project_id, ) + self._copy_regional_access_boundary_manager(cred) + return cred def _make_jwt(self): """Make a signed JWT. @@ -559,7 +574,7 @@ def _make_jwt(self): return jwt, expiry - def refresh(self, request): + def _perform_refresh_token(self, request): """Refreshes the access token. Args: @@ -569,6 +584,15 @@ def refresh(self, request): # (pylint doesn't correctly recognize overridden methods.) self.token, self.expiry = self._make_jwt() + def _build_regional_access_boundary_lookup_url(self, request=None): + """Builds the lookup URL using the service account's email address.""" + if not self.signer_email: + return None + + return iam._SERVICE_ACCOUNT_REGIONAL_ACCESS_BOUNDARY_LOOKUP_ENDPOINT.format( + service_account_email=self.signer_email + ) + @_helpers.copy_docstring(google.auth.credentials.Signing) def sign_bytes(self, message): return self._signer.sign(message) diff --git a/packages/google-auth/google/oauth2/_client_async.py b/packages/google-auth/google/oauth2/_client_async.py index a6201fbdcb94..d5e5a8483660 100644 --- a/packages/google-auth/google/oauth2/_client_async.py +++ b/packages/google-auth/google/oauth2/_client_async.py @@ -23,6 +23,7 @@ .. _Section 3.1 of rfc6749: https://tools.ietf.org/html/rfc6749#section-3.2 """ +import asyncio import http.client as http_client import json import urllib @@ -288,3 +289,145 @@ async def refresh_grant( request, token_uri, body, can_retry=can_retry ) return client._handle_refresh_grant_response(response_data, refresh_token) + + +async def _lookup_regional_access_boundary(request, url, headers=None, fail_fast=False): + """Implements the global lookup of a credential Regional Access Boundary. + For the lookup, we send a request to the global lookup endpoint and then + parse the response. Service account credentials, workload identity + pools and workforce pools implementation may have Regional Access Boundaries configured. + Args: + request (google.auth.aio.transport.Request): A callable used to make + HTTP requests. + url (str): The Regional Access Boundary lookup url. + headers (Optional[Mapping[str, str]]): The headers for the request. + fail_fast (bool): Whether the lookup should fail fast (uses a short timeout and no retries). + Returns: + Optional[Mapping[str,list|str]]: A dictionary containing + "locations" as a list of allowed locations as strings and + "encodedLocations" as a hex string. + e.g: + { + "locations": [ + "us-central1", "us-east1", "europe-west1", "asia-east1" + ], + "encodedLocations": "0xA30" + } + """ + response_data = await _lookup_regional_access_boundary_request( + request, url, headers=headers, fail_fast=fail_fast + ) + if response_data is None: + # Error was already logged by _lookup_regional_access_boundary_request + return None + + if "encodedLocations" not in response_data: + client._LOGGER.error( + "Regional Access Boundary response malformed: missing 'encodedLocations' key in %s", + response_data, + ) + return None + return response_data + + +async def _lookup_regional_access_boundary_request( + request, url, can_retry=True, headers=None, fail_fast=False +): + """Makes a request to the Regional Access Boundary lookup endpoint. + + Args: + request (google.auth.aio.transport.Request): A callable used to make + HTTP requests. + url (str): The Regional Access Boundary lookup url. + can_retry (bool): Enable or disable request retry behavior. Defaults to true. + headers (Optional[Mapping[str, str]]): The headers for the request. + fail_fast (bool): Whether the lookup should fail fast (uses a short timeout and no retries). + + Returns: + Optional[Mapping[str, str]]: The JSON-decoded response data on success, or None on failure. + """ + ( + response_status_ok, + response_data, + retryable_error, + ) = await _lookup_regional_access_boundary_request_no_throw( + request, url, can_retry=can_retry, headers=headers, fail_fast=fail_fast + ) + if not response_status_ok: + client._LOGGER.warning( + "Regional Access Boundary HTTP request failed after retries: response_data=%s, retryable_error=%s", + response_data, + retryable_error, + ) + return None + return response_data + + +async def _lookup_regional_access_boundary_request_no_throw( + request, url, can_retry=True, headers=None, fail_fast=False +): + """Makes a request to the Regional Access Boundary lookup endpoint. This + function doesn't throw on response errors. + + Args: + request (google.auth.aio.transport.Request): A callable used to make + HTTP requests. + url (str): The Regional Access Boundary lookup url. + can_retry (bool): Enable or disable request retry behavior. Defaults to true. + headers (Optional[Mapping[str, str]]): The headers for the request. + fail_fast (bool): Whether the lookup should fail fast (uses a short timeout and no retries). + + Returns: + Tuple(bool, Mapping[str, str], Optional[bool]): A boolean indicating + if the request is successful, a mapping for the JSON-decoded response + data and in the case of an error a boolean indicating if the error + is retryable. + """ + + response_data = {} + retryable_error = False + + timeout = ( + client._BLOCKING_REGIONAL_ACCESS_BOUNDARY_LOOKUP_TIMEOUT if fail_fast else None + ) + total_attempts = 1 if fail_fast else 6 + retries = _exponential_backoff.AsyncExponentialBackoff( + total_attempts=total_attempts + ) + + async for _ in retries: + try: + if timeout: + response = await asyncio.wait_for( + request(method="GET", url=url, headers=headers), timeout=timeout + ) + else: + response = await request(method="GET", url=url, headers=headers) + except asyncio.TimeoutError: + return False, {}, False + + response_body1 = await response.content() + response_body = ( + response_body1.decode("utf-8") + if hasattr(response_body1, "decode") + else response_body1 + ) + + try: + response_data = json.loads(response_body) + except ValueError: + response_data = response_body + + if response.status == http_client.OK: + return True, response_data, None + + retryable_error = client._can_retry( + status_code=response.status, response_data=response_data + ) + if response.status == http_client.BAD_GATEWAY: + retryable_error = True + + if not can_retry or not retryable_error: + return False, response_data, retryable_error + + return False, response_data, retryable_error diff --git a/packages/google-auth/google/oauth2/_service_account_async.py b/packages/google-auth/google/oauth2/_service_account_async.py index fa6cfb7b7d7a..1872aaa58bc3 100644 --- a/packages/google-auth/google/oauth2/_service_account_async.py +++ b/packages/google-auth/google/oauth2/_service_account_async.py @@ -24,6 +24,7 @@ from google.auth import _credentials_async as credentials_async from google.auth import _helpers +from google.auth import _regional_access_boundary_utils from google.oauth2 import _client_async from google.oauth2 import service_account @@ -66,6 +67,12 @@ class Credentials( credentials = credentials.with_quota_project('myproject-123') """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._rab_manager.refresh_manager = ( + _regional_access_boundary_utils._AsyncRegionalAccessBoundaryRefreshManager() + ) + @_helpers.copy_docstring(credentials_async.Credentials) async def refresh(self, request): assertion = self._make_authorization_grant_assertion() @@ -75,12 +82,36 @@ async def refresh(self, request): self.token = access_token self.expiry = expiry + async def _lookup_regional_access_boundary(self, request, fail_fast=False): + """Calls the Regional Access Boundary lookup API to retrieve the Regional Access Boundary information. + + Args: + request (google.auth.aio.transport.Request): The object used to make + HTTP requests. + fail_fast (bool): Whether the lookup should fail fast. + + Returns: + Optional[Dict[str, str]]: The Regional Access Boundary information. + """ + url = self._build_regional_access_boundary_lookup_url(request=request) + if not url: + return None + + headers = {} + self._apply(headers) + self._rab_manager.apply_headers(headers) + + return await _client_async._lookup_regional_access_boundary( + request, url, headers=headers, fail_fast=fail_fast + ) + @_helpers.copy_docstring(credentials_async.Credentials) async def before_request(self, request, method, url, headers): - # Explicit override to bypass synchronous CredentialsWithRegionalAccessBoundary. await credentials_async.Credentials.before_request( self, request, method, url, headers ) + self._maybe_start_regional_access_boundary_refresh(request, url) + self._rab_manager.apply_headers(headers) class IDTokenCredentials( diff --git a/packages/google-auth/tests/compute_engine/test_credentials.py b/packages/google-auth/tests/compute_engine/test_credentials.py index 5a60ffd44145..864ddf6436df 100644 --- a/packages/google-auth/tests/compute_engine/test_credentials.py +++ b/packages/google-auth/tests/compute_engine/test_credentials.py @@ -306,8 +306,9 @@ def test_build_regional_access_boundary_lookup_url_default_email( url = creds._build_regional_access_boundary_lookup_url(request=mock_request) mock_get_service_account_info.assert_called_once_with(mock_request, "default") - expected_url = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/resolved-email@example.com/allowedLocations" - assert url == expected_url + expected_url_standard = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/resolved-email@example.com/allowedLocations" + expected_url_mtls = "https://iamcredentials.mtls.googleapis.com/v1/projects/-/serviceAccounts/resolved-email@example.com/allowedLocations" + assert url in (expected_url_standard, expected_url_mtls) @mock.patch("google.auth.compute_engine._metadata.get", autospec=True) def test_build_regional_access_boundary_lookup_url_http_client_request( @@ -323,8 +324,9 @@ def test_build_regional_access_boundary_lookup_url_http_client_request( url = creds._build_regional_access_boundary_lookup_url(request=req) - expected_url = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/resolved-email@example.com/allowedLocations" - assert url == expected_url + expected_url_standard = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/resolved-email@example.com/allowedLocations" + expected_url_mtls = "https://iamcredentials.mtls.googleapis.com/v1/projects/-/serviceAccounts/resolved-email@example.com/allowedLocations" + assert url in (expected_url_standard, expected_url_mtls) @mock.patch( "google.auth.compute_engine._metadata.get_service_account_info", autospec=True @@ -343,9 +345,9 @@ def test_build_regional_access_boundary_lookup_url_explicit_email( url = creds._build_regional_access_boundary_lookup_url() mock_get_service_account_info.assert_not_called() - assert url == ( - "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/foo@bar.com/allowedLocations" - ) + expected_url_standard = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/foo@bar.com/allowedLocations" + expected_url_mtls = "https://iamcredentials.mtls.googleapis.com/v1/projects/-/serviceAccounts/foo@bar.com/allowedLocations" + assert url in (expected_url_standard, expected_url_mtls) @mock.patch( "google.auth.compute_engine._metadata.get_universe_domain", autospec=True diff --git a/packages/google-auth/tests/oauth2/test_service_account.py b/packages/google-auth/tests/oauth2/test_service_account.py index f0d8f0759e50..4da25a65407a 100644 --- a/packages/google-auth/tests/oauth2/test_service_account.py +++ b/packages/google-auth/tests/oauth2/test_service_account.py @@ -230,13 +230,16 @@ def test_with_quota_project(self): def test_build_regional_access_boundary_lookup_url(self): credentials = self.make_credentials() - expected_url = ( - "https://iamcredentials.googleapis.com/v1/projects/-/" - "serviceAccounts/{}/allowedLocations".format( - credentials.service_account_email - ) + url = credentials._build_regional_access_boundary_lookup_url() + + expected_url_standard = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/{}/allowedLocations".format( + credentials.service_account_email ) - assert credentials._build_regional_access_boundary_lookup_url() == expected_url + expected_url_mtls = "https://iamcredentials.mtls.googleapis.com/v1/projects/-/serviceAccounts/{}/allowedLocations".format( + credentials.service_account_email + ) + + assert url in (expected_url_standard, expected_url_mtls) def test_with_token_uri(self): credentials = self.make_credentials() diff --git a/packages/google-auth/tests/test__regional_access_boundary_utils.py b/packages/google-auth/tests/test__regional_access_boundary_utils.py index ab6ec75fd9b8..cc634e2f293c 100644 --- a/packages/google-auth/tests/test__regional_access_boundary_utils.py +++ b/packages/google-auth/tests/test__regional_access_boundary_utils.py @@ -301,6 +301,24 @@ def test_serialization(self): assert unpickled.refresh_manager._lock is not None assert unpickled.refresh_manager._worker is None + def test_unpickle_old_credentials_without_rab(self): + creds = CredentialsImpl() + old_state = creds.__dict__.copy() + if "_rab_manager" in old_state: + del old_state["_rab_manager"] + if "_use_non_blocking_refresh" in old_state: + del old_state["_use_non_blocking_refresh"] + if "_refresh_worker" in old_state: + del old_state["_refresh_worker"] + + new_instance = CredentialsImpl.__new__(CredentialsImpl) + new_instance.__setstate__(old_state) + + assert hasattr(new_instance, "_rab_manager") + assert new_instance._rab_manager is not None + assert new_instance._use_non_blocking_refresh is False + assert new_instance._refresh_worker is not None + @mock.patch( "google.auth._regional_access_boundary_utils._RegionalAccessBoundaryRefreshManager.start_refresh" ) diff --git a/packages/google-auth/tests/test_credentials.py b/packages/google-auth/tests/test_credentials.py index e1528a3ce365..df2c4edac331 100644 --- a/packages/google-auth/tests/test_credentials.py +++ b/packages/google-auth/tests/test_credentials.py @@ -154,6 +154,18 @@ def test_before_request_with_regional_access_boundary(): assert headers["x-allowed-locations"] == DUMMY_BOUNDARY +def test_copy_regional_access_boundary_manager_preserves_type(): + class CustomRefreshManager(object): + pass + + creds = CredentialsImpl() + creds._rab_manager.refresh_manager = CustomRefreshManager() + + new_creds = creds._make_copy() + + assert isinstance(new_creds._rab_manager.refresh_manager, CustomRefreshManager) + + def test_before_request_metrics(): credentials = CredentialsImplWithMetrics() request = "token" diff --git a/packages/google-auth/tests/test_external_account.py b/packages/google-auth/tests/test_external_account.py index dc296f7a52ae..8469a2912fef 100644 --- a/packages/google-auth/tests/test_external_account.py +++ b/packages/google-auth/tests/test_external_account.py @@ -1729,13 +1729,21 @@ def test_before_request_expired(self, utcnow): def test_build_regional_access_boundary_lookup_url_workload(self): credentials = self.make_credentials() - expected_url = "https://iamcredentials.googleapis.com/v1/projects/123456/locations/global/workloadIdentityPools/POOL_ID/allowedLocations" - assert credentials._build_regional_access_boundary_lookup_url() == expected_url + url = credentials._build_regional_access_boundary_lookup_url() + + expected_url_standard = "https://iamcredentials.googleapis.com/v1/projects/123456/locations/global/workloadIdentityPools/POOL_ID/allowedLocations" + expected_url_mtls = "https://iamcredentials.mtls.googleapis.com/v1/projects/123456/locations/global/workloadIdentityPools/POOL_ID/allowedLocations" + + assert url in (expected_url_standard, expected_url_mtls) def test_build_regional_access_boundary_lookup_url_workforce(self): credentials = self.make_workforce_pool_credentials() - expected_url = "https://iamcredentials.googleapis.com/v1/locations/global/workforcePools/POOL_ID/allowedLocations" - assert credentials._build_regional_access_boundary_lookup_url() == expected_url + url = credentials._build_regional_access_boundary_lookup_url() + + expected_url_standard = "https://iamcredentials.googleapis.com/v1/locations/global/workforcePools/POOL_ID/allowedLocations" + expected_url_mtls = "https://iamcredentials.mtls.googleapis.com/v1/locations/global/workforcePools/POOL_ID/allowedLocations" + + assert url in (expected_url_standard, expected_url_mtls) @pytest.mark.parametrize( "audience", diff --git a/packages/google-auth/tests/test_external_account_authorized_user.py b/packages/google-auth/tests/test_external_account_authorized_user.py index 648966d924bf..83176d5bbf23 100644 --- a/packages/google-auth/tests/test_external_account_authorized_user.py +++ b/packages/google-auth/tests/test_external_account_authorized_user.py @@ -603,8 +603,12 @@ def test_from_file_full_options(self, tmpdir): def test_build_regional_access_boundary_lookup_url(self): credentials = self.make_credentials() - expected_url = "https://iamcredentials.googleapis.com/v1/locations/global/workforcePools/POOL_ID/allowedLocations" - assert credentials._build_regional_access_boundary_lookup_url() == expected_url + url = credentials._build_regional_access_boundary_lookup_url() + + expected_url_standard = "https://iamcredentials.googleapis.com/v1/locations/global/workforcePools/POOL_ID/allowedLocations" + expected_url_mtls = "https://iamcredentials.mtls.googleapis.com/v1/locations/global/workforcePools/POOL_ID/allowedLocations" + + assert url in (expected_url_standard, expected_url_mtls) @pytest.mark.parametrize( "audience", diff --git a/packages/google-auth/tests/test_iam.py b/packages/google-auth/tests/test_iam.py index 26a4c825a7b3..29949b926e34 100644 --- a/packages/google-auth/tests/test_iam.py +++ b/packages/google-auth/tests/test_iam.py @@ -15,6 +15,7 @@ import base64 import datetime import http.client as http_client +import importlib import json from unittest import mock @@ -113,3 +114,37 @@ def test_sign_bytes_retryable_failure(self, mock_time): with pytest.raises(exceptions.TransportError): signer.sign("123") request.call_count == 3 + + +def test_endpoint_constants_mtls(monkeypatch): + from google.auth.transport import _mtls_helper + + # Mock check_use_client_cert to return True (simulating mTLS environment) + monkeypatch.setattr(_mtls_helper, "check_use_client_cert", lambda: True) + + # Force a reload of the iam module to trigger the top-level domain computation + importlib.reload(iam) + + try: + # Verify it constructed the mTLS domain for ALL endpoints + assert ( + "iamcredentials.mtls.googleapis.com" + in iam._SERVICE_ACCOUNT_REGIONAL_ACCESS_BOUNDARY_LOOKUP_ENDPOINT + ) + assert ( + "iamcredentials.mtls.googleapis.com" + in iam._WORKFORCE_POOL_REGIONAL_ACCESS_BOUNDARY_LOOKUP_ENDPOINT + ) + assert ( + "iamcredentials.mtls.googleapis.com" + in iam._WORKLOAD_IDENTITY_POOL_REGIONAL_ACCESS_BOUNDARY_LOOKUP_ENDPOINT + ) + assert "iamcredentials.mtls.googleapis.com" in iam._IAM_ENDPOINT + assert "iamcredentials.mtls.googleapis.com" in iam._IAM_SIGN_ENDPOINT + assert "iamcredentials.mtls.googleapis.com" in iam._IAM_SIGNJWT_ENDPOINT + assert "iamcredentials.mtls.googleapis.com" in iam._IAM_IDTOKEN_ENDPOINT + + finally: + # Restore the original state for other tests by undoing the patch and reloading again + monkeypatch.undo() + importlib.reload(iam) diff --git a/packages/google-auth/tests/test_impersonated_credentials.py b/packages/google-auth/tests/test_impersonated_credentials.py index 500209f663d7..572d961cc3a9 100644 --- a/packages/google-auth/tests/test_impersonated_credentials.py +++ b/packages/google-auth/tests/test_impersonated_credentials.py @@ -719,11 +719,16 @@ def test_build_regional_access_boundary_lookup_url_no_email(self): def test_build_regional_access_boundary_lookup_url_success(self): credentials = self.make_credentials() - # Ensure service_account_email is properly set by default mock - expected_url = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/{}/allowedLocations".format( + url = credentials._build_regional_access_boundary_lookup_url() + + expected_url_standard = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/{}/allowedLocations".format( + credentials.service_account_email + ) + expected_url_mtls = "https://iamcredentials.mtls.googleapis.com/v1/projects/-/serviceAccounts/{}/allowedLocations".format( credentials.service_account_email ) - assert credentials._build_regional_access_boundary_lookup_url() == expected_url + + assert url in (expected_url_standard, expected_url_mtls) def test_with_scopes_provide_default_scopes(self): credentials = self.make_credentials() diff --git a/packages/google-auth/tests/test_jwt.py b/packages/google-auth/tests/test_jwt.py index 4c5988469494..9ed90cdf12b8 100644 --- a/packages/google-auth/tests/test_jwt.py +++ b/packages/google-auth/tests/test_jwt.py @@ -553,6 +553,44 @@ def test_before_request_refreshes(self): self.credentials.before_request(None, "GET", "http://example.com?a=1#3", {}) assert self.credentials.valid + def test_build_regional_access_boundary_lookup_url(self): + url = self.credentials._build_regional_access_boundary_lookup_url() + expected_url_standard = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/{}/allowedLocations".format( + self.SERVICE_ACCOUNT_EMAIL + ) + expected_url_mtls = "https://iamcredentials.mtls.googleapis.com/v1/projects/-/serviceAccounts/{}/allowedLocations".format( + self.SERVICE_ACCOUNT_EMAIL + ) + + assert url in (expected_url_standard, expected_url_mtls) + + def test_cloning_retains_rab_manager_data(self): + self.credentials._rab_manager._data = mock.sentinel.rab_data + + cloned_claims = self.credentials.with_claims(audience="new-audience") + cloned_quota = self.credentials.with_quota_project("new-quota") + + # Verify references to immutable boundary data are shared + assert cloned_claims._rab_manager._data == mock.sentinel.rab_data + assert cloned_quota._rab_manager._data == mock.sentinel.rab_data + + # Verify manager objects and lock properties are isolated to prevent race conditions + assert cloned_claims._rab_manager is not self.credentials._rab_manager + assert cloned_quota._rab_manager is not self.credentials._rab_manager + + def test_from_signing_credentials_copies_rab_state(self): + from google.oauth2 import service_account + + sa_creds = service_account.Credentials.from_service_account_info( + SERVICE_ACCOUNT_INFO + ) + sa_creds._rab_manager._data = mock.sentinel.rab_data + + jwt_creds = jwt.Credentials.from_signing_credentials(sa_creds, audience="aud") + + assert jwt_creds._rab_manager._data == mock.sentinel.rab_data + assert jwt_creds._rab_manager is not sa_creds._rab_manager + class TestOnDemandCredentials(object): SERVICE_ACCOUNT_EMAIL = "service-account@example.com" diff --git a/packages/google-auth/tests_async/oauth2/test__client_async.py b/packages/google-auth/tests_async/oauth2/test__client_async.py index 5ad9596cf85c..3cea1f0ee330 100644 --- a/packages/google-auth/tests_async/oauth2/test__client_async.py +++ b/packages/google-auth/tests_async/oauth2/test__client_async.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import datetime import http.client as http_client import json @@ -492,3 +493,68 @@ async def test__token_endpoint_request_no_throw_with_retry(can_retry): assert mock_request.call_count == 3 else: assert mock_request.call_count == 1 + + +@pytest.mark.asyncio +async def test__lookup_regional_access_boundary_success(): + request = make_request({"encodedLocations": "0xA30", "locations": ["us-central1"]}) + result = await _client._lookup_regional_access_boundary( + request, "http://example.com" + ) + assert result == {"encodedLocations": "0xA30", "locations": ["us-central1"]} + + +@pytest.mark.asyncio +async def test__lookup_regional_access_boundary_malformed(): + request = make_request({"locations": ["us-central1"]}) + result = await _client._lookup_regional_access_boundary( + request, "http://example.com" + ) + assert result is None + + +@pytest.mark.asyncio +@mock.patch("asyncio.wait_for", side_effect=asyncio.TimeoutError) +async def test__lookup_regional_access_boundary_request_no_throw_timeout(mock_wait_for): + request = mock.AsyncMock(spec=["transport.Request"]) + + ( + success, + data, + retryable, + ) = await _client._lookup_regional_access_boundary_request_no_throw( + request, "http://example.com", fail_fast=True + ) + + assert success is False + assert data == {} + assert retryable is False + + +@pytest.mark.asyncio +@mock.patch("asyncio.sleep", new_callable=mock.AsyncMock) +async def test__lookup_regional_access_boundary_request_no_throw_bad_gateway_retry( + mock_sleep, +): + bad_gateway_response = mock.AsyncMock(spec=["transport.Response"]) + bad_gateway_response.status = http_client.BAD_GATEWAY + bad_gateway_response.content = mock.AsyncMock(return_value=b"{}") + + ok_response = mock.AsyncMock(spec=["transport.Response"]) + ok_response.status = http_client.OK + ok_response.content = mock.AsyncMock(return_value=b'{"encodedLocations": "0xA30"}') + + request = mock.AsyncMock(spec=["transport.Request"]) + request.side_effect = [bad_gateway_response, ok_response] + + ( + success, + data, + retryable, + ) = await _client._lookup_regional_access_boundary_request_no_throw( + request, "http://example.com" + ) + + assert success is True + assert data == {"encodedLocations": "0xA30"} + assert request.call_count == 2 diff --git a/packages/google-auth/tests_async/oauth2/test_service_account_async.py b/packages/google-auth/tests_async/oauth2/test_service_account_async.py index 5a9a89fcaac2..d633f870e400 100644 --- a/packages/google-auth/tests_async/oauth2/test_service_account_async.py +++ b/packages/google-auth/tests_async/oauth2/test_service_account_async.py @@ -229,6 +229,73 @@ async def test_before_request_refreshes(self, jwt_grant): # Credentials should now be valid. assert credentials.valid + @mock.patch( + "google.oauth2._client_async._lookup_regional_access_boundary", autospec=True + ) + @pytest.mark.asyncio + async def test_before_request_triggers_rab_refresh(self, mock_lookup): + credentials = self.make_credentials() + credentials.token = "tok" + + mock_lookup.return_value = { + "locations": ["us-central1", "europe-west1"], + "encodedLocations": "0xA30", + } + + request = mock.AsyncMock(spec=["transport.Request"]) + headers1 = {} + + with mock.patch.object( + credentials, + "_is_regional_access_boundary_lookup_required", + return_value=True, + ): + # First request triggers background refresh, but proceeds without the header + await credentials.before_request( + request, "GET", "https://storage.googleapis.com/bucket", headers1 + ) + assert "x-allowed-locations" not in headers1 + + # Wait for the background task to finish and update the cache + await credentials._rab_manager.refresh_manager._worker_task + assert mock_lookup.called + + # Second request should now find the data in the cache and attach the header + headers2 = {} + await credentials.before_request( + request, "GET", "https://storage.googleapis.com/bucket", headers2 + ) + assert headers2["x-allowed-locations"] == "0xA30" + + @mock.patch( + "google.oauth2._client_async._lookup_regional_access_boundary", autospec=True + ) + @pytest.mark.asyncio + async def test_before_request_rab_refresh_failure_ignored(self, mock_lookup): + credentials = self.make_credentials() + credentials.token = "tok" + + mock_lookup.side_effect = Exception("Transport failed") + + request = mock.AsyncMock(spec=["transport.Request"]) + headers = {} + + with mock.patch.object( + credentials, + "_is_regional_access_boundary_lookup_required", + return_value=True, + ): + # The exception must be caught gracefully and not bubble up + await credentials.before_request( + request, "GET", "https://storage.googleapis.com/bucket", headers + ) + + # Wait for the background task to finish + await credentials._rab_manager.refresh_manager._worker_task + + assert mock_lookup.called + assert "x-allowed-locations" not in headers + class TestIDTokenCredentials(object): SERVICE_ACCOUNT_EMAIL = "service-account@example.com" diff --git a/packages/google-auth/tests_async/test__regional_access_boundary_utils.py b/packages/google-auth/tests_async/test__regional_access_boundary_utils.py new file mode 100644 index 000000000000..268ee37261c8 --- /dev/null +++ b/packages/google-auth/tests_async/test__regional_access_boundary_utils.py @@ -0,0 +1,84 @@ +# Copyright 2026 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from unittest import mock + +import pytest # type: ignore + +from google.auth import _regional_access_boundary_utils + + +@pytest.mark.asyncio +async def test_async_refresh_manager_start_refresh(): + credentials = mock.AsyncMock() + credentials._lookup_regional_access_boundary.return_value = { + "encodedLocations": "0xA30" + } + + request = mock.Mock() + rab_manager = mock.Mock() + + manager = ( + _regional_access_boundary_utils._AsyncRegionalAccessBoundaryRefreshManager() + ) + + manager.start_refresh(credentials, request, rab_manager) + + # Wait for the background task to finish + await manager._worker_task + + credentials._lookup_regional_access_boundary.assert_called_once_with(request) + rab_manager.process_regional_access_boundary_info.assert_called_once_with( + {"encodedLocations": "0xA30"} + ) + + +@pytest.mark.asyncio +async def test_async_refresh_manager_duplicate_refresh_prevented(): + credentials = mock.AsyncMock() + + # Use events to control the concurrency timing + lookup_started = asyncio.Event() + lookup_finish = asyncio.Event() + + async def controlled_lookup(*args, **kwargs): + lookup_started.set() # Signal that the background lookup has started. + await lookup_finish.wait() # Block until the test allows the lookup to complete. + return {"encodedLocations": "0xA30"} + + credentials._lookup_regional_access_boundary.side_effect = controlled_lookup + + request = mock.Mock() + rab_manager = mock.Mock() + + manager = ( + _regional_access_boundary_utils._AsyncRegionalAccessBoundaryRefreshManager() + ) + + # Start the initial refresh task in the background. + manager.start_refresh(credentials, request, rab_manager) + + # Wait until the background task has begun executing the lookup. + await lookup_started.wait() + + # Attempt a second refresh while the initial task is still in progress. + manager.start_refresh(credentials, request, rab_manager) + + # Unblock the initial task and wait for it to complete. + lookup_finish.set() + await manager._worker_task + + # Verify that the second refresh request was ignored and only one lookup occurred. + assert credentials._lookup_regional_access_boundary.call_count == 1