diff --git a/.kokoro/samples/python3.14/common.cfg b/.kokoro/samples/python3.14/common.cfg new file mode 100644 index 000000000..c82a73a9e --- /dev/null +++ b/.kokoro/samples/python3.14/common.cfg @@ -0,0 +1,37 @@ +# Format: //devtools/kokoro/config/proto/build.proto + +# Build logs will be here +action { + define_artifacts { + regex: "**/*sponge_log.xml" + } +} + +# Specify which tests to run +env_vars: { + key: "RUN_TESTS_SESSION" + value: "unit-3.14" +} + +# Download trampoline resources. +gfile_resources: "/bigstore/cloud-devrel-kokoro-resources/trampoline" + +# Download resources for system tests (service account key, etc.) +gfile_resources: "/bigstore/cloud-devrel-kokoro-resources/google-auth-library-python" + +# Use the trampoline script to run in docker. +build_file: "google-auth-library-python/.kokoro/trampoline.sh" + +# Configure the docker image for kokoro-trampoline. +env_vars: { + key: "TRAMPOLINE_IMAGE" + value: "gcr.io/cloud-devrel-kokoro-resources/python-multi" +} +env_vars: { + key: "TRAMPOLINE_BUILD_FILE" + value: "github/google-auth-library-python/.kokoro/build.sh" +} +env_vars: { + key: "TRAMPOLINE_BUILD_FILE" + value: "github/google-auth-library-python/.kokoro/samples-test-setup.sh" +} \ No newline at end of file diff --git a/.kokoro/samples/python3.14/continuous.cfg b/.kokoro/samples/python3.14/continuous.cfg new file mode 100644 index 000000000..a1c8d9759 --- /dev/null +++ b/.kokoro/samples/python3.14/continuous.cfg @@ -0,0 +1,6 @@ +# Format: //devtools/kokoro/config/proto/build.proto + +env_vars: { + key: "INSTALL_LIBRARY_FROM_SOURCE" + value: "True" +} \ No newline at end of file diff --git a/.kokoro/samples/python3.14/periodic-head.cfg b/.kokoro/samples/python3.14/periodic-head.cfg new file mode 100644 index 000000000..83eace873 --- /dev/null +++ b/.kokoro/samples/python3.14/periodic-head.cfg @@ -0,0 +1,11 @@ +# Format: //devtools/kokoro/config/proto/build.proto + +env_vars: { + key: "INSTALL_LIBRARY_FROM_SOURCE" + value: "True" +} + +env_vars: { + key: "TRAMPOLINE_BUILD_FILE" + value: "github/google-auth-library-python/.kokoro/test-samples-against-head.sh" +} diff --git a/.kokoro/samples/python3.14/periodic.cfg b/.kokoro/samples/python3.14/periodic.cfg new file mode 100644 index 000000000..71cd1e597 --- /dev/null +++ b/.kokoro/samples/python3.14/periodic.cfg @@ -0,0 +1,6 @@ +# Format: //devtools/kokoro/config/proto/build.proto + +env_vars: { + key: "INSTALL_LIBRARY_FROM_SOURCE" + value: "False" +} diff --git a/.kokoro/samples/python3.14/presubmit.cfg b/.kokoro/samples/python3.14/presubmit.cfg new file mode 100644 index 000000000..a1c8d9759 --- /dev/null +++ b/.kokoro/samples/python3.14/presubmit.cfg @@ -0,0 +1,6 @@ +# Format: //devtools/kokoro/config/proto/build.proto + +env_vars: { + key: "INSTALL_LIBRARY_FROM_SOURCE" + value: "True" +} \ No newline at end of file diff --git a/.librarian/state.yaml b/.librarian/state.yaml index 9b7e2ca09..826e2646e 100644 --- a/.librarian/state.yaml +++ b/.librarian/state.yaml @@ -1,7 +1,7 @@ image: us-central1-docker.pkg.dev/cloud-sdk-librarian-prod/images-prod/python-librarian-generator:latest libraries: - id: google-auth - version: 2.43.0 + version: 2.44.0 last_generated_commit: 102d9f92ac6ed649a61efd9b208e4d1de278e9bb apis: [] source_roots: diff --git a/CHANGELOG.md b/CHANGELOG.md index a71cd68e3..51f0c787a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,24 @@ [1]: https://pypi.org/project/google-auth/#history +## [2.44.0](https://github.com/googleapis/google-auth-library-python/compare/v2.43.0...v2.44.0) (2025-12-13) + + +### Features + +* support Python 3.14 (#1822) ([0f7097e78f247665b6ef0287d482033f7be2ed6d](https://github.com/googleapis/google-auth-library-python/commit/0f7097e78f247665b6ef0287d482033f7be2ed6d)) +* add ecdsa p-384 support (#1872) ([39c381a5f6881b590025f36d333d12eff8dc60fc](https://github.com/googleapis/google-auth-library-python/commit/39c381a5f6881b590025f36d333d12eff8dc60fc)) +* MDS connections use mTLS (#1856) ([0387bb95713653d47e846cad3a010eb55ef2db4c](https://github.com/googleapis/google-auth-library-python/commit/0387bb95713653d47e846cad3a010eb55ef2db4c)) +* Implement token revocation in STS client and add revoke() metho… (#1849) ([d5638986ca03ee95bfffa9ad821124ed7e903e63](https://github.com/googleapis/google-auth-library-python/commit/d5638986ca03ee95bfffa9ad821124ed7e903e63)) +* Add shlex to correctly parse executable commands with spaces (#1855) ([cf6fc3cced78bc1362a7fe596c32ebc9ce03c26b](https://github.com/googleapis/google-auth-library-python/commit/cf6fc3cced78bc1362a7fe596c32ebc9ce03c26b)) + + +### Bug Fixes + +* Use public refresh method for source credentials in ImpersonatedCredentials (#1884) ([e0c3296f471747258f6d98d2d9bfde636358ecde](https://github.com/googleapis/google-auth-library-python/commit/e0c3296f471747258f6d98d2d9bfde636358ecde)) +* Add temporary patch to workload cert logic to accomodate Cloud Run mis-configuration (#1880) ([78de7907b8bdb7b5510e3c6fa8a3f3721e2436d7](https://github.com/googleapis/google-auth-library-python/commit/78de7907b8bdb7b5510e3c6fa8a3f3721e2436d7)) +* Delegate workload cert and key default lookup to helper function (#1877) ([b0993c7edaba505d0fb0628af28760c43034c959](https://github.com/googleapis/google-auth-library-python/commit/b0993c7edaba505d0fb0628af28760c43034c959)) + ## [2.43.0](https://github.com/googleapis/google-cloud-python/compare/google-auth-v2.42.1...google-auth-v2.43.0) (2025-11-05) diff --git a/google/auth/_service_account_info.py b/google/auth/_service_account_info.py index 6b64adcae..c432080a9 100644 --- a/google/auth/_service_account_info.py +++ b/google/auth/_service_account_info.py @@ -56,7 +56,7 @@ def from_dict(data, require=None, use_rsa_signer=True): if use_rsa_signer: signer = crypt.RSASigner.from_service_account_info(data) else: - signer = crypt.ES256Signer.from_service_account_info(data) + signer = crypt.EsSigner.from_service_account_info(data) return signer diff --git a/google/auth/compute_engine/_metadata.py b/google/auth/compute_engine/_metadata.py index ddbe8ac2f..96f1ff526 100644 --- a/google/auth/compute_engine/_metadata.py +++ b/google/auth/compute_engine/_metadata.py @@ -24,15 +24,23 @@ import os from urllib.parse import urljoin +import requests + from google.auth import _helpers from google.auth import environment_vars from google.auth import exceptions from google.auth import metrics from google.auth import transport from google.auth._exponential_backoff import ExponentialBackoff +from google.auth.compute_engine import _mtls + _LOGGER = logging.getLogger(__name__) +_GCE_DEFAULT_MDS_IP = "169.254.169.254" +_GCE_DEFAULT_HOST = "metadata.google.internal" +_GCE_DEFAULT_MDS_HOSTS = [_GCE_DEFAULT_HOST, _GCE_DEFAULT_MDS_IP] + # Environment variable GCE_METADATA_HOST is originally named # GCE_METADATA_ROOT. For compatibility reasons, here it checks # the new variable first; if not set, the system falls back @@ -40,15 +48,48 @@ _GCE_METADATA_HOST = os.getenv(environment_vars.GCE_METADATA_HOST, None) if not _GCE_METADATA_HOST: _GCE_METADATA_HOST = os.getenv( - environment_vars.GCE_METADATA_ROOT, "metadata.google.internal" + environment_vars.GCE_METADATA_ROOT, _GCE_DEFAULT_HOST + ) + + +def _validate_gce_mds_configured_environment(): + """Validates the GCE metadata server environment configuration for mTLS. + + mTLS is only supported when connecting to the default metadata server hosts. + If we are in strict mode (which requires mTLS), ensure that the metadata host + has not been overridden to a custom value (which means mTLS will fail). + + Raises: + google.auth.exceptions.MutualTLSChannelError: if the environment + configuration is invalid for mTLS. + """ + mode = _mtls._parse_mds_mode() + if mode == _mtls.MdsMtlsMode.STRICT: + # mTLS is only supported when connecting to the default metadata host. + # Raise an exception if we are in strict mode (which requires mTLS) + # but the metadata host has been overridden to a custom MDS. (which means mTLS will fail) + if _GCE_METADATA_HOST not in _GCE_DEFAULT_MDS_HOSTS: + raise exceptions.MutualTLSChannelError( + "Mutual TLS is required, but the metadata host has been overridden. " + "mTLS is only supported when connecting to the default metadata host." + ) + + +def _get_metadata_root(use_mtls: bool): + """Returns the metadata server root URL.""" + + scheme = "https" if use_mtls else "http" + return "{}://{}/computeMetadata/v1/".format(scheme, _GCE_METADATA_HOST) + + +def _get_metadata_ip_root(use_mtls: bool): + """Returns the metadata server IP root URL.""" + scheme = "https" if use_mtls else "http" + return "{}://{}".format( + scheme, os.getenv(environment_vars.GCE_METADATA_IP, _GCE_DEFAULT_MDS_IP) ) -_METADATA_ROOT = "http://{}/computeMetadata/v1/".format(_GCE_METADATA_HOST) -# This is used to ping the metadata server, it avoids the cost of a DNS -# lookup. -_METADATA_IP_ROOT = "http://{}".format( - os.getenv(environment_vars.GCE_METADATA_IP, "169.254.169.254") -) + _METADATA_FLAVOR_HEADER = "metadata-flavor" _METADATA_FLAVOR_VALUE = "Google" _METADATA_HEADERS = {_METADATA_FLAVOR_HEADER: _METADATA_FLAVOR_VALUE} @@ -102,6 +143,33 @@ def detect_gce_residency_linux(): return content.startswith(_GOOGLE) +def _prepare_request_for_mds(request, use_mtls=False) -> None: + """Prepares a request for the metadata server. + + This will check if mTLS should be used and mount the mTLS adapter if needed. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + use_mtls (bool): Whether to use mTLS for the request. + + Returns: + google.auth.transport.Request: A request object to use. + If mTLS is enabled, the request will have the mTLS adapter mounted. + Otherwise, the original request will be returned unchanged. + """ + # Only modify the request if mTLS is enabled. + if use_mtls: + # Ensure the request has a session to mount the adapter to. + if not request.session: + request.session = requests.Session() + + adapter = _mtls.MdsMtlsAdapter() + # Mount the adapter for all default GCE metadata hosts. + for host in _GCE_DEFAULT_MDS_HOSTS: + request.session.mount(f"https://{host}/", adapter) + + def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3): """Checks to see if the metadata server is available. @@ -115,6 +183,8 @@ def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3): Returns: bool: True if the metadata server is reachable, False otherwise. """ + use_mtls = _mtls.should_use_mds_mtls() + _prepare_request_for_mds(request, use_mtls=use_mtls) # NOTE: The explicit ``timeout`` is a workaround. The underlying # issue is that resolving an unknown host on some networks will take # 20-30 seconds; making this timeout short fixes the issue, but @@ -129,7 +199,10 @@ def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3): for attempt in backoff: try: response = request( - url=_METADATA_IP_ROOT, method="GET", headers=headers, timeout=timeout + url=_get_metadata_ip_root(use_mtls), + method="GET", + headers=headers, + timeout=timeout, ) metadata_flavor = response.headers.get(_METADATA_FLAVOR_HEADER) @@ -153,7 +226,7 @@ def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3): def get( request, path, - root=_METADATA_ROOT, + root=None, params=None, recursive=False, retry_count=5, @@ -168,7 +241,8 @@ def get( HTTP requests. path (str): The resource to retrieve. For example, ``'instance/service-accounts/default'``. - root (str): The full path to the metadata server root. + root (Optional[str]): The full path to the metadata server root. If not + provided, the default root will be used. params (Optional[Mapping[str, str]]): A mapping of query parameter keys to values. recursive (bool): Whether to do a recursive query of metadata. See @@ -189,7 +263,24 @@ def get( Raises: google.auth.exceptions.TransportError: if an error occurred while retrieving metadata. + google.auth.exceptions.MutualTLSChannelError: if using mtls and the environment + configuration is invalid for mTLS (for example, the metadata host + has been overridden in strict mTLS mode). + """ + use_mtls = _mtls.should_use_mds_mtls() + # Prepare the request object for mTLS if needed. + # This will create a new request object with the mTLS session. + _prepare_request_for_mds(request, use_mtls=use_mtls) + + if root is None: + root = _get_metadata_root(use_mtls) + + # mTLS is only supported when connecting to the default metadata host. + # If we are in strict mode (which requires mTLS), ensure that the metadata host + # has not been overridden to a non-default host value (which means mTLS will fail). + _validate_gce_mds_configured_environment() + base_url = urljoin(root, path) query_params = {} if params is None else params diff --git a/google/auth/compute_engine/_mtls.py b/google/auth/compute_engine/_mtls.py new file mode 100644 index 000000000..6525dd03e --- /dev/null +++ b/google/auth/compute_engine/_mtls.py @@ -0,0 +1,164 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Mutual TLS for Google Compute Engine metadata server.""" + +from dataclasses import dataclass, field +import enum +import logging +import os +from pathlib import Path +import ssl +from urllib.parse import urlparse, urlunparse + +import requests +from requests.adapters import HTTPAdapter + +from google.auth import environment_vars, exceptions + + +_LOGGER = logging.getLogger(__name__) + +_WINDOWS_OS_NAME = "nt" + +# MDS mTLS certificate paths based on OS. +# Documentation to well known locations can be found at: +# https://cloud.google.com/compute/docs/metadata/overview#https-mds-certificates +_WINDOWS_MTLS_COMPONENTS_BASE_PATH = Path("C:/ProgramData/Google/ComputeEngine") +_MTLS_COMPONENTS_BASE_PATH = Path("/run/google-mds-mtls") + + +def _get_mds_root_crt_path(): + if os.name == _WINDOWS_OS_NAME: + return _WINDOWS_MTLS_COMPONENTS_BASE_PATH / "mds-mtls-root.crt" + else: + return _MTLS_COMPONENTS_BASE_PATH / "root.crt" + + +def _get_mds_client_combined_cert_path(): + if os.name == _WINDOWS_OS_NAME: + return _WINDOWS_MTLS_COMPONENTS_BASE_PATH / "mds-mtls-client.key" + else: + return _MTLS_COMPONENTS_BASE_PATH / "client.key" + + +@dataclass +class MdsMtlsConfig: + ca_cert_path: Path = field( + default_factory=_get_mds_root_crt_path + ) # path to CA certificate + client_combined_cert_path: Path = field( + default_factory=_get_mds_client_combined_cert_path + ) # path to file containing client certificate and key + + +def _certs_exist(mds_mtls_config: MdsMtlsConfig): + """Checks if the mTLS certificates exist.""" + return os.path.exists(mds_mtls_config.ca_cert_path) and os.path.exists( + mds_mtls_config.client_combined_cert_path + ) + + +class MdsMtlsMode(enum.Enum): + """MDS mTLS mode. Used to configure connection behavior when connecting to MDS. + + STRICT: Always use HTTPS/mTLS. If certificates are not found locally, an error will be returned. + NONE: Never use mTLS. Requests will use regular HTTP. + DEFAULT: Use mTLS if certificates are found locally, otherwise use regular HTTP. + """ + + STRICT = "strict" + NONE = "none" + DEFAULT = "default" + + +def _parse_mds_mode(): + """Parses the GCE_METADATA_MTLS_MODE environment variable.""" + mode_str = os.environ.get( + environment_vars.GCE_METADATA_MTLS_MODE, "default" + ).lower() + try: + return MdsMtlsMode(mode_str) + except ValueError: + raise ValueError( + "Invalid value for GCE_METADATA_MTLS_MODE. Must be one of 'strict', 'none', or 'default'." + ) + + +def should_use_mds_mtls(mds_mtls_config: MdsMtlsConfig = MdsMtlsConfig()): + """Determines if mTLS should be used for the metadata server.""" + mode = _parse_mds_mode() + if mode == MdsMtlsMode.STRICT: + if not _certs_exist(mds_mtls_config): + raise exceptions.MutualTLSChannelError( + "mTLS certificates not found in strict mode." + ) + return True + elif mode == MdsMtlsMode.NONE: + return False + else: # Default mode + return _certs_exist(mds_mtls_config) + + +class MdsMtlsAdapter(HTTPAdapter): + """An HTTP adapter that uses mTLS for the metadata server.""" + + def __init__( + self, mds_mtls_config: MdsMtlsConfig = MdsMtlsConfig(), *args, **kwargs + ): + self.ssl_context = ssl.create_default_context() + self.ssl_context.load_verify_locations(cafile=mds_mtls_config.ca_cert_path) + self.ssl_context.load_cert_chain( + certfile=mds_mtls_config.client_combined_cert_path + ) + super(MdsMtlsAdapter, self).__init__(*args, **kwargs) + + def init_poolmanager(self, *args, **kwargs): + kwargs["ssl_context"] = self.ssl_context + return super(MdsMtlsAdapter, self).init_poolmanager(*args, **kwargs) + + def proxy_manager_for(self, *args, **kwargs): + kwargs["ssl_context"] = self.ssl_context + return super(MdsMtlsAdapter, self).proxy_manager_for(*args, **kwargs) + + def send(self, request, **kwargs): + # If we are in strict mode, always use mTLS (no HTTP fallback) + if _parse_mds_mode() == MdsMtlsMode.STRICT: + return super(MdsMtlsAdapter, self).send(request, **kwargs) + + # In default mode, attempt mTLS first, then fallback to HTTP on failure + try: + response = super(MdsMtlsAdapter, self).send(request, **kwargs) + response.raise_for_status() + return response + except ( + ssl.SSLError, + requests.exceptions.SSLError, + requests.exceptions.HTTPError, + ) as e: + _LOGGER.warning( + "mTLS connection to Compute Engine Metadata server failed. " + "Falling back to standard HTTP. Reason: %s", + e, + ) + # Fallback to standard HTTP + parsed_original_url = urlparse(request.url) + http_fallback_url = urlunparse(parsed_original_url._replace(scheme="http")) + request.url = http_fallback_url + + # Use a standard HTTPAdapter for the fallback + http_adapter = HTTPAdapter() + return http_adapter.send(request, **kwargs) diff --git a/google/auth/crypt/__init__.py b/google/auth/crypt/__init__.py index 6d147e706..59519b475 100644 --- a/google/auth/crypt/__init__.py +++ b/google/auth/crypt/__init__.py @@ -40,13 +40,19 @@ from google.auth.crypt import base from google.auth.crypt import rsa +# google.auth.crypt.es depends on the crytpography module which may not be +# successfully imported depending on the system. try: + from google.auth.crypt import es from google.auth.crypt import es256 except ImportError: # pragma: NO COVER + es = None # type: ignore es256 = None # type: ignore -if es256 is not None: # pragma: NO COVER +if es is not None and es256 is not None: # pragma: NO COVER __all__ = [ + "EsSigner", + "EsVerifier", "ES256Signer", "ES256Verifier", "RSASigner", @@ -54,6 +60,11 @@ "Signer", "Verifier", ] + + EsSigner = es.EsSigner + EsVerifier = es.EsVerifier + ES256Signer = es256.ES256Signer + ES256Verifier = es256.ES256Verifier else: # pragma: NO COVER __all__ = ["RSASigner", "RSAVerifier", "Signer", "Verifier"] @@ -65,10 +76,6 @@ RSASigner = rsa.RSASigner RSAVerifier = rsa.RSAVerifier -if es256 is not None: # pragma: NO COVER - ES256Signer = es256.ES256Signer - ES256Verifier = es256.ES256Verifier - def verify_signature(message, signature, certs, verifier_cls=rsa.RSAVerifier): """Verify an RSA or ECDSA cryptographic signature. diff --git a/google/auth/crypt/es.py b/google/auth/crypt/es.py new file mode 100644 index 000000000..f9466af3c --- /dev/null +++ b/google/auth/crypt/es.py @@ -0,0 +1,221 @@ +# Copyright 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ECDSA verifier and signer that use the ``cryptography`` library. +""" + +from dataclasses import dataclass +from typing import Any, Dict, Optional, Union + +import cryptography.exceptions +from cryptography.hazmat import backends +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.asymmetric import padding +from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature +from cryptography.hazmat.primitives.asymmetric.utils import encode_dss_signature +import cryptography.x509 + +from google.auth import _helpers +from google.auth.crypt import base + + +_CERTIFICATE_MARKER = b"-----BEGIN CERTIFICATE-----" +_BACKEND = backends.default_backend() +_PADDING = padding.PKCS1v15() + + +@dataclass +class _ESAttributes: + """A class that models ECDSA attributes. + + Attributes: + rs_size (int): Size for ASN.1 r and s size. + sha_algo (hashes.HashAlgorithm): Hash algorithm. + algorithm (str): Algorithm name. + """ + + rs_size: int + sha_algo: hashes.HashAlgorithm + algorithm: str + + @classmethod + def from_key( + cls, key: Union[ec.EllipticCurvePublicKey, ec.EllipticCurvePrivateKey] + ): + return cls.from_curve(key.curve) + + @classmethod + def from_curve(cls, curve: ec.EllipticCurve): + # ECDSA raw signature has (r||s) format where r,s are two + # integers of size 32 bytes for P-256 curve and 48 bytes + # for P-384 curve. For P-256 curve, we use SHA256 hash algo, + # and for P-384 curve we use SHA384 algo. + if isinstance(curve, ec.SECP384R1): + return cls(48, hashes.SHA384(), "ES384") + else: + # default to ES256 + return cls(32, hashes.SHA256(), "ES256") + + +class EsVerifier(base.Verifier): + """Verifies ECDSA cryptographic signatures using public keys. + + Args: + public_key ( + cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey): + The public key used to verify signatures. + """ + + def __init__(self, public_key: ec.EllipticCurvePublicKey) -> None: + self._pubkey = public_key + self._attributes = _ESAttributes.from_key(public_key) + + @_helpers.copy_docstring(base.Verifier) + def verify(self, message: bytes, signature: bytes) -> bool: + # First convert (r||s) raw signature to ASN1 encoded signature. + sig_bytes = _helpers.to_bytes(signature) + if len(sig_bytes) != self._attributes.rs_size * 2: + return False + r = int.from_bytes(sig_bytes[: self._attributes.rs_size], byteorder="big") + s = int.from_bytes(sig_bytes[self._attributes.rs_size :], byteorder="big") + asn1_sig = encode_dss_signature(r, s) + + message = _helpers.to_bytes(message) + try: + self._pubkey.verify(asn1_sig, message, ec.ECDSA(self._attributes.sha_algo)) + return True + except (ValueError, cryptography.exceptions.InvalidSignature): + return False + + @classmethod + def from_string(cls, public_key: Union[str, bytes]) -> "EsVerifier": + """Construct an Verifier instance from a public key or public + certificate string. + + Args: + public_key (Union[str, bytes]): The public key in PEM format or the + x509 public key certificate. + + Returns: + Verifier: The constructed verifier. + + Raises: + ValueError: If the public key can't be parsed. + """ + public_key_data = _helpers.to_bytes(public_key) + + if _CERTIFICATE_MARKER in public_key_data: + cert = cryptography.x509.load_pem_x509_certificate( + public_key_data, _BACKEND + ) + pubkey = cert.public_key() # type: Any + + else: + pubkey = serialization.load_pem_public_key(public_key_data, _BACKEND) + + if not isinstance(pubkey, ec.EllipticCurvePublicKey): + raise TypeError("Expected public key of type EllipticCurvePublicKey") + + return cls(pubkey) + + +class EsSigner(base.Signer, base.FromServiceAccountMixin): + """Signs messages with an ECDSA private key. + + Args: + private_key ( + cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey): + The private key to sign with. + key_id (str): Optional key ID used to identify this private key. This + can be useful to associate the private key with its associated + public key or certificate. + """ + + def __init__( + self, private_key: ec.EllipticCurvePrivateKey, key_id: Optional[str] = None + ) -> None: + self._key = private_key + self._key_id = key_id + self._attributes = _ESAttributes.from_key(private_key) + + @property + def algorithm(self) -> str: + """Name of the algorithm used to sign messages. + Returns: + str: The algorithm name. + """ + return self._attributes.algorithm + + @property # type: ignore + @_helpers.copy_docstring(base.Signer) + def key_id(self) -> Optional[str]: + return self._key_id + + @_helpers.copy_docstring(base.Signer) + def sign(self, message: bytes) -> bytes: + message = _helpers.to_bytes(message) + asn1_signature = self._key.sign(message, ec.ECDSA(self._attributes.sha_algo)) + + # Convert ASN1 encoded signature to (r||s) raw signature. + (r, s) = decode_dss_signature(asn1_signature) + return r.to_bytes(self._attributes.rs_size, byteorder="big") + s.to_bytes( + self._attributes.rs_size, byteorder="big" + ) + + @classmethod + def from_string( + cls, key: Union[bytes, str], key_id: Optional[str] = None + ) -> "EsSigner": + """Construct a RSASigner from a private key in PEM format. + + Args: + key (Union[bytes, str]): Private key in PEM format. + key_id (str): An optional key id used to identify the private key. + + Returns: + google.auth.crypt._cryptography_rsa.RSASigner: The + constructed signer. + + Raises: + ValueError: If ``key`` is not ``bytes`` or ``str`` (unicode). + UnicodeDecodeError: If ``key`` is ``bytes`` but cannot be decoded + into a UTF-8 ``str``. + ValueError: If ``cryptography`` "Could not deserialize key data." + """ + key_bytes = _helpers.to_bytes(key) + private_key = serialization.load_pem_private_key( + key_bytes, password=None, backend=_BACKEND + ) + + if not isinstance(private_key, ec.EllipticCurvePrivateKey): + raise TypeError("Expected private key of type EllipticCurvePrivateKey") + + return cls(private_key, key_id=key_id) + + def __getstate__(self) -> Dict[str, Any]: + """Pickle helper that serializes the _key attribute.""" + state = self.__dict__.copy() + state["_key"] = self._key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + return state + + def __setstate__(self, state: Dict[str, Any]) -> None: + """Pickle helper that deserializes the _key attribute.""" + state["_key"] = serialization.load_pem_private_key(state["_key"], None) + self.__dict__.update(state) diff --git a/google/auth/crypt/es256.py b/google/auth/crypt/es256.py index 820e4becc..e7bda5d3f 100644 --- a/google/auth/crypt/es256.py +++ b/google/auth/crypt/es256.py @@ -15,93 +15,22 @@ """ECDSA (ES256) verifier and signer that use the ``cryptography`` library. """ -from cryptography import utils # type: ignore -import cryptography.exceptions -from cryptography.hazmat import backends -from cryptography.hazmat.primitives import hashes -from cryptography.hazmat.primitives import serialization -from cryptography.hazmat.primitives.asymmetric import ec -from cryptography.hazmat.primitives.asymmetric import padding -from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature -from cryptography.hazmat.primitives.asymmetric.utils import encode_dss_signature -import cryptography.x509 +from google.auth.crypt.es import EsSigner +from google.auth.crypt.es import EsVerifier -from google.auth import _helpers -from google.auth.crypt import base - -_CERTIFICATE_MARKER = b"-----BEGIN CERTIFICATE-----" -_BACKEND = backends.default_backend() -_PADDING = padding.PKCS1v15() - - -class ES256Verifier(base.Verifier): +class ES256Verifier(EsVerifier): """Verifies ECDSA cryptographic signatures using public keys. Args: - public_key ( - cryptography.hazmat.primitives.asymmetric.ec.ECDSAPublicKey): - The public key used to verify signatures. + public_key (cryptography.hazmat.primitives.asymmetric.ec.ECDSAPublicKey): The public key used to verify + signatures. """ - def __init__(self, public_key): - self._pubkey = public_key - - @_helpers.copy_docstring(base.Verifier) - def verify(self, message, signature): - # First convert (r||s) raw signature to ASN1 encoded signature. - sig_bytes = _helpers.to_bytes(signature) - if len(sig_bytes) != 64: - return False - r = ( - int.from_bytes(sig_bytes[:32], byteorder="big") - if _helpers.is_python_3() - else utils.int_from_bytes(sig_bytes[:32], byteorder="big") - ) - s = ( - int.from_bytes(sig_bytes[32:], byteorder="big") - if _helpers.is_python_3() - else utils.int_from_bytes(sig_bytes[32:], byteorder="big") - ) - asn1_sig = encode_dss_signature(r, s) - - message = _helpers.to_bytes(message) - try: - self._pubkey.verify(asn1_sig, message, ec.ECDSA(hashes.SHA256())) - return True - except (ValueError, cryptography.exceptions.InvalidSignature): - return False - - @classmethod - def from_string(cls, public_key): - """Construct an Verifier instance from a public key or public - certificate string. - - Args: - public_key (Union[str, bytes]): The public key in PEM format or the - x509 public key certificate. - - Returns: - Verifier: The constructed verifier. - - Raises: - ValueError: If the public key can't be parsed. - """ - public_key_data = _helpers.to_bytes(public_key) - - if _CERTIFICATE_MARKER in public_key_data: - cert = cryptography.x509.load_pem_x509_certificate( - public_key_data, _BACKEND - ) - pubkey = cert.public_key() - - else: - pubkey = serialization.load_pem_public_key(public_key_data, _BACKEND) + pass - return cls(pubkey) - -class ES256Signer(base.Signer, base.FromServiceAccountMixin): +class ES256Signer(EsSigner): """Signs messages with an ECDSA private key. Args: @@ -113,63 +42,4 @@ class ES256Signer(base.Signer, base.FromServiceAccountMixin): public key or certificate. """ - def __init__(self, private_key, key_id=None): - self._key = private_key - self._key_id = key_id - - @property # type: ignore - @_helpers.copy_docstring(base.Signer) - def key_id(self): - return self._key_id - - @_helpers.copy_docstring(base.Signer) - def sign(self, message): - message = _helpers.to_bytes(message) - asn1_signature = self._key.sign(message, ec.ECDSA(hashes.SHA256())) - - # Convert ASN1 encoded signature to (r||s) raw signature. - (r, s) = decode_dss_signature(asn1_signature) - return ( - (r.to_bytes(32, byteorder="big") + s.to_bytes(32, byteorder="big")) - if _helpers.is_python_3() - else (utils.int_to_bytes(r, 32) + utils.int_to_bytes(s, 32)) - ) - - @classmethod - def from_string(cls, key, key_id=None): - """Construct a RSASigner from a private key in PEM format. - - Args: - key (Union[bytes, str]): Private key in PEM format. - key_id (str): An optional key id used to identify the private key. - - Returns: - google.auth.crypt._cryptography_rsa.RSASigner: The - constructed signer. - - Raises: - ValueError: If ``key`` is not ``bytes`` or ``str`` (unicode). - UnicodeDecodeError: If ``key`` is ``bytes`` but cannot be decoded - into a UTF-8 ``str``. - ValueError: If ``cryptography`` "Could not deserialize key data." - """ - key = _helpers.to_bytes(key) - private_key = serialization.load_pem_private_key( - key, password=None, backend=_BACKEND - ) - return cls(private_key, key_id=key_id) - - def __getstate__(self): - """Pickle helper that serializes the _key attribute.""" - state = self.__dict__.copy() - state["_key"] = self._key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption(), - ) - return state - - def __setstate__(self, state): - """Pickle helper that deserializes the _key attribute.""" - state["_key"] = serialization.load_pem_private_key(state["_key"], None) - self.__dict__.update(state) + pass diff --git a/google/auth/environment_vars.py b/google/auth/environment_vars.py index e5f3598e8..5da3a7382 100644 --- a/google/auth/environment_vars.py +++ b/google/auth/environment_vars.py @@ -60,6 +60,12 @@ """Environment variable providing an alternate ip:port to be used for ip-only GCE metadata requests.""" +GCE_METADATA_MTLS_MODE = "GCE_METADATA_MTLS_MODE" +"""Environment variable controlling the mTLS behavior for GCE metadata requests. + +Can be one of "strict", "none", or "default". +""" + GOOGLE_API_USE_CLIENT_CERTIFICATE = "GOOGLE_API_USE_CLIENT_CERTIFICATE" """Environment variable controlling whether to use client certificate or not. diff --git a/google/auth/external_account_authorized_user.py b/google/auth/external_account_authorized_user.py index f8fbf950b..2594e048f 100644 --- a/google/auth/external_account_authorized_user.py +++ b/google/auth/external_account_authorized_user.py @@ -321,6 +321,30 @@ def _build_trust_boundary_lookup_url(self): universe_domain=self._universe_domain, pool_id=pool_id ) + def revoke(self, request): + """Revokes the refresh token. + + Args: + request (google.auth.transport.Request): The object used to make + HTTP requests. + + Raises: + google.auth.exceptions.OAuthError: If the token could not be + revoked. + """ + if not self._revoke_url or not self._refresh_token_val: + raise exceptions.OAuthError( + "The credentials do not contain the necessary fields to " + "revoke the refresh token. You must specify revoke_url and " + "refresh_token." + ) + + self._sts_client.revoke_token( + request, self._refresh_token_val, "refresh_token", self._revoke_url + ) + self.token = None + self._refresh_token = None + @_helpers.copy_docstring(credentials.Credentials) def get_cred_info(self): if self._cred_file_path: diff --git a/google/auth/impersonated_credentials.py b/google/auth/impersonated_credentials.py index 334573428..e2724382a 100644 --- a/google/auth/impersonated_credentials.py +++ b/google/auth/impersonated_credentials.py @@ -286,7 +286,7 @@ def _refresh_token(self, request): self._source_credentials.token_state == credentials.TokenState.STALE or self._source_credentials.token_state == credentials.TokenState.INVALID ): - self._source_credentials._refresh_token(request) + self._source_credentials.refresh(request) body = { "delegates": self._delegates, diff --git a/google/auth/jwt.py b/google/auth/jwt.py index 1ebd565d4..9b79f173b 100644 --- a/google/auth/jwt.py +++ b/google/auth/jwt.py @@ -59,17 +59,18 @@ import google.auth.credentials try: - from google.auth.crypt import es256 + from google.auth.crypt import es except ImportError: # pragma: NO COVER - es256 = None # type: ignore + es = None # type: ignore _DEFAULT_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds _DEFAULT_MAX_CACHE_SIZE = 10 _ALGORITHM_TO_VERIFIER_CLASS = {"RS256": crypt.RSAVerifier} -_CRYPTOGRAPHY_BASED_ALGORITHMS = frozenset(["ES256"]) +_CRYPTOGRAPHY_BASED_ALGORITHMS = frozenset(["ES256", "ES384"]) -if es256 is not None: # pragma: NO COVER - _ALGORITHM_TO_VERIFIER_CLASS["ES256"] = es256.ES256Verifier # type: ignore +if es is not None: # pragma: NO COVER + _ALGORITHM_TO_VERIFIER_CLASS["ES256"] = es.EsVerifier # type: ignore + _ALGORITHM_TO_VERIFIER_CLASS["ES384"] = es.EsVerifier # type: ignore def encode(signer, payload, header=None, key_id=None): @@ -95,8 +96,8 @@ def encode(signer, payload, header=None, key_id=None): header.update({"typ": "JWT"}) if "alg" not in header: - if es256 is not None and isinstance(signer, es256.ES256Signer): - header.update({"alg": "ES256"}) + if es is not None and isinstance(signer, es.EsSigner): + header.update({"alg": signer.algorithm}) else: header.update({"alg": "RS256"}) diff --git a/google/auth/pluggable.py b/google/auth/pluggable.py index fd349537d..730a72c28 100644 --- a/google/auth/pluggable.py +++ b/google/auth/pluggable.py @@ -37,6 +37,7 @@ from collections import Mapping # type: ignore import json import os +import shlex import subprocess import sys import time @@ -220,7 +221,7 @@ def retrieve_subject_token(self, request): exe_stderr = sys.stdout if self.interactive else subprocess.STDOUT result = subprocess.run( - self._credential_source_executable_command.split(), + shlex.split(self._credential_source_executable_command), timeout=exe_timeout, stdin=exe_stdin, stdout=exe_stdout, @@ -273,7 +274,7 @@ def revoke(self, request): # Run executable result = subprocess.run( - self._credential_source_executable_command.split(), + shlex.split(self._credential_source_executable_command), timeout=self._credential_source_executable_interactive_timeout_millis / 1000, stdout=subprocess.PIPE, diff --git a/google/auth/transport/_mtls_helper.py b/google/auth/transport/_mtls_helper.py index 7740f2fe8..f5d6b6724 100644 --- a/google/auth/transport/_mtls_helper.py +++ b/google/auth/transport/_mtls_helper.py @@ -47,6 +47,20 @@ b"-----BEGIN PASSPHRASE-----(.+)-----END PASSPHRASE-----", re.DOTALL ) +# Temporary patch to accomodate incorrect cert config in Cloud Run prod environment. +_WELL_KNOWN_CLOUD_RUN_CERT_PATH = ( + "/var/run/secrets/workload-spiffe-credentials/certificates.pem" +) +_WELL_KNOWN_CLOUD_RUN_KEY_PATH = ( + "/var/run/secrets/workload-spiffe-credentials/private_key.pem" +) +_INCORRECT_CLOUD_RUN_CERT_PATH = ( + "/var/lib/volumes/certificate/workload-certificates/certificates.pem" +) +_INCORRECT_CLOUD_RUN_KEY_PATH = ( + "/var/lib/volumes/certificate/workload-certificates/private_key.pem" +) + def _check_config_path(config_path): """Checks for config file path. If it exists, returns the absolute path with user expansion; @@ -183,6 +197,25 @@ def _get_workload_cert_and_key_paths(config_path): ) key_path = workload["key_path"] + # == BEGIN Temporary Cloud Run PATCH == + # See https://github.com/googleapis/google-auth-library-python/issues/1881 + if (cert_path == _INCORRECT_CLOUD_RUN_CERT_PATH) and ( + key_path == _INCORRECT_CLOUD_RUN_KEY_PATH + ): + if not path.exists(cert_path) and not path.exists(key_path): + _LOGGER.debug( + "Applying Cloud Run certificate path patch. " + "Configured paths not found: %s, %s. " + "Using well-known paths: %s, %s", + cert_path, + key_path, + _WELL_KNOWN_CLOUD_RUN_CERT_PATH, + _WELL_KNOWN_CLOUD_RUN_KEY_PATH, + ) + cert_path = _WELL_KNOWN_CLOUD_RUN_CERT_PATH + key_path = _WELL_KNOWN_CLOUD_RUN_KEY_PATH + # == END Temporary Cloud Run PATCH == + return cert_path, key_path @@ -279,7 +312,7 @@ def _run_cert_provider_command(command, expect_encrypted_key=False): def get_client_ssl_credentials( generate_encrypted_key=False, context_aware_metadata_path=CONTEXT_AWARE_METADATA_PATH, - certificate_config_path=CERTIFICATE_CONFIGURATION_DEFAULT_PATH, + certificate_config_path=None, ): """Returns the client side certificate, private key and passphrase. @@ -306,13 +339,10 @@ def get_client_ssl_credentials( the cert, key and passphrase. """ - # 1. Check for certificate config json. - cert_config_path = _check_config_path(certificate_config_path) - if cert_config_path: - # Attempt to retrieve X.509 Workload cert and key. - cert, key = _get_workload_cert_and_key(cert_config_path) - if cert and key: - return True, cert, key, None + # 1. Attempt to retrieve X.509 Workload cert and key. + cert, key = _get_workload_cert_and_key(certificate_config_path) + if cert and key: + return True, cert, key, None # 2. Check for context aware metadata json metadata_path = _check_config_path(context_aware_metadata_path) diff --git a/google/auth/version.py b/google/auth/version.py index 20f2c8c0a..80d1360d3 100644 --- a/google/auth/version.py +++ b/google/auth/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2.43.0" +__version__ = "2.44.0" diff --git a/google/oauth2/sts.py b/google/oauth2/sts.py index ad3962735..60d6f83c4 100644 --- a/google/oauth2/sts.py +++ b/google/oauth2/sts.py @@ -57,7 +57,7 @@ def __init__(self, token_exchange_endpoint, client_authentication=None): super(Client, self).__init__(client_authentication) self._token_exchange_endpoint = token_exchange_endpoint - def _make_request(self, request, headers, request_body): + def _make_request(self, request, headers, request_body, url=None): # Initialize request headers. request_headers = _URLENCODED_HEADERS.copy() @@ -69,9 +69,12 @@ def _make_request(self, request, headers, request_body): # Apply OAuth client authentication. self.apply_client_authentication_options(request_headers, request_body) + # Use default token exchange endpoint if no url is provided. + url = url or self._token_exchange_endpoint + # Execute request. response = request( - url=self._token_exchange_endpoint, + url=url, method="POST", headers=request_headers, body=urllib.parse.urlencode(request_body).encode("utf-8"), @@ -87,10 +90,12 @@ def _make_request(self, request, headers, request_body): if response.status != http_client.OK: utils.handle_error_response(response_body) - response_data = json.loads(response_body) + # A successful token revocation returns an empty response body. + if not response_body: + return {} - # Return successful response. - return response_data + # Other successful responses should be valid JSON. + return json.loads(response_body) def exchange_token( self, @@ -174,3 +179,23 @@ def refresh_token(self, request, refresh_token): None, {"grant_type": "refresh_token", "refresh_token": refresh_token}, ) + + def revoke_token(self, request, token, token_type_hint, revoke_url): + """Revokes the provided token based on the RFC7009 spec. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + token (str): The OAuth 2.0 token to revoke. + token_type_hint (str): Hint for the type of token being revoked. + revoke_url (str): The STS endpoint URL for revoking tokens. + + Raises: + google.auth.exceptions.OAuthError: If the token revocation endpoint + returned an error. + """ + request_body = {"token": token} + if token_type_hint: + request_body["token_type_hint"] = token_type_hint + + return self._make_request(request, None, request_body, revoke_url) diff --git a/noxfile.py b/noxfile.py index 728e8c7cc..11f677a3b 100644 --- a/noxfile.py +++ b/noxfile.py @@ -36,7 +36,7 @@ DEFAULT_PYTHON_VERSION = "3.10" # TODO(https://github.com/googleapis/google-auth-library-python/issues/1787): # Remove or restore testing for Python 3.7/3.8 -UNIT_TEST_PYTHON_VERSIONS = ["3.9", "3.10", "3.11", "3.12", "3.13"] +UNIT_TEST_PYTHON_VERSIONS = ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] # Error if a python version is missing nox.options.error_on_missing_interpreters = True @@ -53,6 +53,7 @@ "unit-3.11", "unit-3.12", "unit-3.13", + "unit-3.14", # cover must be last to avoid error `No data to report` "cover", "docs", diff --git a/samples/cloud-client/snippets/custom_aws_supplier.py b/samples/cloud-client/snippets/custom_aws_supplier.py new file mode 100644 index 000000000..ec5bf8a10 --- /dev/null +++ b/samples/cloud-client/snippets/custom_aws_supplier.py @@ -0,0 +1,117 @@ +# Copyright 2025 Google LLC +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import sys + +import boto3 +from dotenv import load_dotenv +from google.auth.aws import Credentials as AwsCredentials +from google.auth.aws import AwsSecurityCredentials, AwsSecurityCredentialsSupplier +from google.auth.exceptions import GoogleAuthError +from google.auth.transport.requests import AuthorizedSession + +load_dotenv() + + +class CustomAwsSupplier(AwsSecurityCredentialsSupplier): + """Custom AWS Security Credentials Supplier.""" + + def __init__(self): + """Initializes the Boto3 session, prioritizing environment variables for region.""" + # Explicitly read the region from the environment first. This ensures that + # a value from a .env file is picked up reliably for local testing. + region = os.getenv("AWS_REGION") or os.getenv("AWS_DEFAULT_REGION") + + # If region is None, Boto3's discovery chain will be used when needed. + self.session = boto3.Session(region_name=region) + self._cached_region = None + print(f"[INFO] CustomAwsSupplier initialized. Region from env: {region}") + + def get_aws_region(self, context, request) -> str: + """Returns the AWS region using Boto3's default provider chain.""" + if self._cached_region: + return self._cached_region + + # Accessing region_name will use the value from the constructor if provided, + # otherwise it triggers Boto3's lazy-loading discovery (e.g., metadata service). + self._cached_region = self.session.region_name + + if not self._cached_region: + print("[ERROR] Boto3 was unable to resolve an AWS region.", file=sys.stderr) + raise GoogleAuthError("Boto3 was unable to resolve an AWS region.") + + print(f"[INFO] Boto3 resolved AWS Region: {self._cached_region}") + return self._cached_region + + def get_aws_security_credentials(self, context, request=None) -> AwsSecurityCredentials: + """Retrieves AWS security credentials using Boto3's default provider chain.""" + aws_credentials = self.session.get_credentials() + if not aws_credentials: + print("[ERROR] Unable to resolve AWS credentials.", file=sys.stderr) + raise GoogleAuthError("Unable to resolve AWS credentials from the provider chain.") + + print(f"[INFO] Resolved AWS Access Key ID: {aws_credentials.access_key}") + + return AwsSecurityCredentials( + access_key_id=aws_credentials.access_key, + secret_access_key=aws_credentials.secret_key, + session_token=aws_credentials.token, + ) + + +def main(): + """Main function to demonstrate the custom AWS supplier.""" + print("--- Starting Script ---") + + gcp_audience = os.getenv("GCP_WORKLOAD_AUDIENCE") + sa_impersonation_url = os.getenv("GCP_SERVICE_ACCOUNT_IMPERSONATION_URL") + gcs_bucket_name = os.getenv("GCS_BUCKET_NAME") + + print(f"GCP_WORKLOAD_AUDIENCE: {gcp_audience}") + print(f"GCS_BUCKET_NAME: {gcs_bucket_name}") + + if not all([gcp_audience, sa_impersonation_url, gcs_bucket_name]): + print("[ERROR] Missing required environment variables.", file=sys.stderr) + raise GoogleAuthError("Missing required environment variables.") + + custom_supplier = CustomAwsSupplier() + + credentials = AwsCredentials( + audience=gcp_audience, + subject_token_type="urn:ietf:params:aws:token-type:aws4_request", + service_account_impersonation_url=sa_impersonation_url, + aws_security_credentials_supplier=custom_supplier, + scopes=['https://www.googleapis.com/auth/devstorage.read_write'], + ) + + bucket_url = f"https://storage.googleapis.com/storage/v1/b/{gcs_bucket_name}" + print(f"Request URL: {bucket_url}") + + authed_session = AuthorizedSession(credentials) + try: + print("Attempting to make authenticated request to Google Cloud Storage...") + res = authed_session.get(bucket_url) + res.raise_for_status() + print("\n--- SUCCESS! ---") + print("Successfully authenticated and retrieved bucket data:") + print(json.dumps(res.json(), indent=2)) + except Exception as e: + print("--- FAILED --- ", file=sys.stderr) + print(e, file=sys.stderr) + exit(1) + + +if __name__ == "__main__": + main() diff --git a/samples/cloud-client/snippets/custom_okta_supplier.py b/samples/cloud-client/snippets/custom_okta_supplier.py new file mode 100644 index 000000000..12f83dcfa --- /dev/null +++ b/samples/cloud-client/snippets/custom_okta_supplier.py @@ -0,0 +1,190 @@ +# Copyright 2025 Google LLC +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import urllib.parse +import os +import time + +import requests +from dotenv import load_dotenv +from google.auth.exceptions import GoogleAuthError +from google.auth.identity_pool import Credentials as IdentityPoolClient +from google.auth.transport.requests import AuthorizedSession + +load_dotenv() + +# Workload Identity Pool Configuration +GCP_WORKLOAD_AUDIENCE = os.getenv("GCP_WORKLOAD_AUDIENCE") +SERVICE_ACCOUNT_IMPERSONATION_URL = os.getenv("GCP_SERVICE_ACCOUNT_IMPERSONATION_URL") +GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME") + +# Okta Configuration +OKTA_DOMAIN = os.getenv("OKTA_DOMAIN") +OKTA_CLIENT_ID = os.getenv("OKTA_CLIENT_ID") +OKTA_CLIENT_SECRET = os.getenv("OKTA_CLIENT_SECRET") + +# Constants +TOKEN_URL = "https://sts.googleapis.com/v1/token" +SUBJECT_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:jwt" + + +class OktaClientCredentialsSupplier: + """A custom SubjectTokenSupplier that authenticates with Okta. + + This supplier uses the Client Credentials grant flow for machine-to-machine + (M2M) authentication with Okta. + """ + + def __init__(self, domain, client_id, client_secret): + self.okta_token_url = f"{domain}/oauth2/default/v1/token" + self.client_id = client_id + self.client_secret = client_secret + self.access_token = None + self.expiry_time = 0 + print("OktaClientCredentialsSupplier initialized.") + + def get_subject_token(self, context, request=None) -> str: + """Fetches a new token if the current one is expired or missing. + + Args: + context: The context object, not used in this implementation. + + Returns: + The Okta Access token. + """ + # Check if the current token is still valid (with a 60-second buffer). + is_token_valid = self.access_token and time.time() < self.expiry_time - 60 + + if is_token_valid: + print("[Supplier] Returning cached Okta Access token.") + return self.access_token + + print( + "[Supplier] Token is missing or expired. Fetching new Okta Access token..." + ) + self._fetch_okta_access_token() + return self.access_token + + def _fetch_okta_access_token(self): + """Performs the Client Credentials grant flow with Okta.""" + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json", + } + data = { + "grant_type": "client_credentials", + "scope": "gcp.test.read", + } + encoded_data = urllib.parse.urlencode(data) + + try: + response = requests.post( + self.okta_token_url, + headers=headers, + data=encoded_data, + auth=(self.client_id, self.client_secret), + ) + response.raise_for_status() + token_data = response.json() + + if "access_token" in token_data and "expires_in" in token_data: + self.access_token = token_data["access_token"] + self.expiry_time = time.time() + token_data["expires_in"] + print( + f"[Supplier] Successfully received Access Token from Okta. " + f"Expires in {token_data['expires_in']} seconds." + ) + else: + raise GoogleAuthError( + "Access token or expires_in not found in Okta response." + ) + except requests.exceptions.RequestException as e: + print(f"[Supplier] Error fetching token from Okta: {e}") + if e.response: + print(f"[Supplier] Okta response: {e.response.text}") + raise GoogleAuthError( + "Failed to authenticate with Okta using Client Credentials grant." + ) from e + + +def main(): + """Main function to demonstrate the custom Okta supplier. + + TODO(Developer): + 1. Before running this sample, set up your environment variables. You can do + this by creating a .env file in the same directory as this script and + populating it with the following variables: + - GCP_WORKLOAD_AUDIENCE: The audience for the GCP workload identity pool. + - GCP_SERVICE_ACCOUNT_IMPERSONATION_URL: The URL for service account impersonation (optional). + - GCS_BUCKET_NAME: The name of the GCS bucket to access. + - OKTA_DOMAIN: Your Okta domain (e.g., https://dev-12345.okta.com). + - OKTA_CLIENT_ID: The Client ID of your Okta M2M application. + - OKTA_CLIENT_SECRET: The Client Secret of your Okta M2M application. + """ + if not all( + [ + GCP_WORKLOAD_AUDIENCE, + GCS_BUCKET_NAME, + OKTA_DOMAIN, + OKTA_CLIENT_ID, + OKTA_CLIENT_SECRET, + ] + ): + raise GoogleAuthError( + "Missing required environment variables. Please check your .env file." + ) + + # 1. Instantiate the custom supplier with Okta credentials. + okta_supplier = OktaClientCredentialsSupplier( + OKTA_DOMAIN, OKTA_CLIENT_ID, OKTA_CLIENT_SECRET + ) + + # 2. Instantiate an IdentityPoolClient. + client = IdentityPoolClient( + audience=GCP_WORKLOAD_AUDIENCE, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + subject_token_supplier=okta_supplier, + # If you choose to provide explicit scopes: use the `scopes` parameter. + default_scopes=['https://www.googleapis.com/auth/cloud-platform'], + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + ) + + # 3. Construct the URL for the Cloud Storage JSON API. + bucket_url = f"https://storage.googleapis.com/storage/v1/b/{GCS_BUCKET_NAME}" + print(f"[Test] Getting metadata for bucket: {GCS_BUCKET_NAME}...") + print(f"[Test] Request URL: {bucket_url}") + + # 4. Use the client to make an authenticated request. + authed_session = AuthorizedSession(client) + try: + res = authed_session.get(bucket_url) + res.raise_for_status() + print("\n--- SUCCESS! ---") + print("Successfully authenticated and retrieved bucket data:") + print(json.dumps(res.json(), indent=2)) + except requests.exceptions.RequestException as e: + print("\n--- FAILED ---") + print(f"Request failed: {e}") + if e.response: + print(f"Response: {e.response.text}") + exit(1) + except GoogleAuthError as e: + print("\n--- FAILED ---") + print(f"Authentication or request failed: {e}") + exit(1) + + +if __name__ == "__main__": + main() diff --git a/samples/cloud-client/snippets/noxfile.py b/samples/cloud-client/snippets/noxfile.py index c21466d4f..3cdf3cf3b 100644 --- a/samples/cloud-client/snippets/noxfile.py +++ b/samples/cloud-client/snippets/noxfile.py @@ -60,7 +60,7 @@ ] -@nox.session(python=["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]) +@nox.session(python=["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "3.13", "3.14"]) def unit(session): # constraints_path = str( # CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt" diff --git a/samples/cloud-client/snippets/requirements.txt b/samples/cloud-client/snippets/requirements.txt index 416c56b94..b5c5cea30 100644 --- a/samples/cloud-client/snippets/requirements.txt +++ b/samples/cloud-client/snippets/requirements.txt @@ -1,4 +1,7 @@ google-cloud-compute==1.5.1 google-cloud-storage==3.1.0 -google-auth==2.38.0 -pytest==7.1.2 +google-auth==2.41.1 +pytest==8.4.2 +boto3>=1.26.0 +requests==2.32.3 +python-dotenv==1.1.1 diff --git a/scripts/decrypt-secrets.sh b/scripts/decrypt-secrets.sh index f0ef994ed..7e7f03bdc 100755 --- a/scripts/decrypt-secrets.sh +++ b/scripts/decrypt-secrets.sh @@ -20,6 +20,10 @@ ROOT=$( dirname "$DIR" ) # Work from the project root. cd $ROOT +# Create working directory if not exists. system_tests/data is not tracked by +# Git to prevent the secrets from being leaked online. +mkdir -p system_tests/data + gcloud kms decrypt \ --location=global \ --keyring=ci \ diff --git a/scripts/encrypt-secrets.sh b/scripts/encrypt-secrets.sh index b6521e8f5..fba27fba0 100755 --- a/scripts/encrypt-secrets.sh +++ b/scripts/encrypt-secrets.sh @@ -29,4 +29,6 @@ gcloud kms encrypt \ --plaintext-file=system_tests/secrets.tar \ --ciphertext-file=system_tests/secrets.tar.enc -rm system_tests/secrets.tar \ No newline at end of file +rm system_tests/secrets.tar + +rm system_tests/data \ No newline at end of file diff --git a/setup.py b/setup.py index 20f79ce66..014b32a95 100644 --- a/setup.py +++ b/setup.py @@ -129,6 +129,7 @@ "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", "Development Status :: 5 - Production/Stable", "Intended Audience :: Developers", "License :: OSI Approved :: Apache Software License", diff --git a/system_tests/secrets.tar.enc b/system_tests/secrets.tar.enc index b45028153..d89763db0 100644 Binary files a/system_tests/secrets.tar.enc and b/system_tests/secrets.tar.enc differ diff --git a/system_tests/system_tests_async/test_default.py b/system_tests/system_tests_async/test_default.py index 32299c059..dfffba28f 100644 --- a/system_tests/system_tests_async/test_default.py +++ b/system_tests/system_tests_async/test_default.py @@ -16,8 +16,11 @@ import pytest from google.auth import _default_async +from google.auth.exceptions import RefreshError + +EXPECT_PROJECT_ID = os.getenv("EXPECT_PROJECT_ID") +CREDENTIALS = os.getenv("GOOGLE_APPLICATION_CREDENTIALS", "") -EXPECT_PROJECT_ID = os.environ.get("EXPECT_PROJECT_ID") @pytest.mark.asyncio async def test_application_default_credentials(verify_refresh): @@ -26,4 +29,10 @@ async def test_application_default_credentials(verify_refresh): if EXPECT_PROJECT_ID is not None: assert project_id is not None - await verify_refresh(credentials) + try: + await verify_refresh(credentials) + except RefreshError as e: + # allow expired credentials for explicit_authorized_user tests + # TODO: https://github.com/googleapis/google-auth-library-python/issues/1882 + if not CREDENTIALS.endswith("authorized_user.json") or "Token has been expired or revoked" not in str(e): + raise diff --git a/system_tests/system_tests_sync/test_default.py b/system_tests/system_tests_sync/test_default.py index 560ab3284..322c57b62 100644 --- a/system_tests/system_tests_sync/test_default.py +++ b/system_tests/system_tests_sync/test_default.py @@ -15,8 +15,10 @@ import os import google.auth +from google.auth.exceptions import RefreshError -EXPECT_PROJECT_ID = os.environ.get("EXPECT_PROJECT_ID") +EXPECT_PROJECT_ID = os.getenv("EXPECT_PROJECT_ID") +CREDENTIALS = os.getenv("GOOGLE_APPLICATION_CREDENTIALS", "") def test_application_default_credentials(verify_refresh): @@ -25,4 +27,10 @@ def test_application_default_credentials(verify_refresh): if EXPECT_PROJECT_ID is not None: assert project_id is not None - verify_refresh(credentials) + try: + verify_refresh(credentials) + except RefreshError as e: + # allow expired credentials for explicit_authorized_user tests + # TODO: https://github.com/googleapis/google-auth-library-python/issues/1882 + if not CREDENTIALS.endswith("authorized_user.json") or "Token has been expired or revoked" not in str(e): + raise diff --git a/testing/constraints-3.14.txt b/testing/constraints-3.14.txt new file mode 100644 index 000000000..e69de29bb diff --git a/tests/compute_engine/test__metadata.py b/tests/compute_engine/test__metadata.py index c90bc603a..adb63f667 100644 --- a/tests/compute_engine/test__metadata.py +++ b/tests/compute_engine/test__metadata.py @@ -20,12 +20,14 @@ import mock import pytest # type: ignore +import requests from google.auth import _helpers from google.auth import environment_vars from google.auth import exceptions from google.auth import transport from google.auth.compute_engine import _metadata +from google.auth.transport import requests as google_auth_requests PATH = "instance/service-accounts/default" @@ -104,7 +106,7 @@ def test_ping_success(mock_metrics_header_value): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_IP_ROOT, + url="http://169.254.169.254", headers=MDS_PING_REQUEST_HEADER, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -118,7 +120,7 @@ def test_ping_success_retry(mock_metrics_header_value): request.assert_called_with( method="GET", - url=_metadata._METADATA_IP_ROOT, + url="http://169.254.169.254", headers=MDS_PING_REQUEST_HEADER, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -172,7 +174,7 @@ def test_get_success_json(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH, + url="http://metadata.google.internal/computeMetadata/v1/" + PATH, headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -191,7 +193,7 @@ def test_get_success_json_content_type_charset(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH, + url="http://metadata.google.internal/computeMetadata/v1/" + PATH, headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -211,7 +213,7 @@ def test_get_success_retry(mock_sleep): request.assert_called_with( method="GET", - url=_metadata._METADATA_ROOT + PATH, + url="http://metadata.google.internal/computeMetadata/v1/" + PATH, headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -227,7 +229,7 @@ def test_get_success_text(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH, + url="http://metadata.google.internal/computeMetadata/v1/" + PATH, headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -243,7 +245,9 @@ def test_get_success_params(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH + "?recursive=true", + url="http://metadata.google.internal/computeMetadata/v1/" + + PATH + + "?recursive=true", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -258,7 +262,9 @@ def test_get_success_recursive_and_params(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH + "?recursive=true", + url="http://metadata.google.internal/computeMetadata/v1/" + + PATH + + "?recursive=true", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -273,7 +279,9 @@ def test_get_success_recursive(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH + "?recursive=true", + url="http://metadata.google.internal/computeMetadata/v1/" + + PATH + + "?recursive=true", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -322,6 +330,21 @@ def test_get_success_custom_root_old_variable(): ) +def test_get_success_custom_root(): + request = make_request("{}", headers={"content-type": "application/json"}) + + fake_root = "http://another.metadata.service" + + _metadata.get(request, PATH, root=fake_root) + + request.assert_called_once_with( + method="GET", + url="{}/{}".format(fake_root, PATH), + headers=_metadata._METADATA_HEADERS, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, + ) + + @mock.patch("time.sleep", return_value=None) def test_get_failure(mock_sleep): request = make_request("Metadata error", status=http_client.NOT_FOUND) @@ -333,7 +356,7 @@ def test_get_failure(mock_sleep): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH, + url="http://metadata.google.internal/computeMetadata/v1/" + PATH, headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -346,7 +369,7 @@ def test_get_return_none_for_not_found_error(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH, + url="http://metadata.google.internal/computeMetadata/v1/" + PATH, headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -366,7 +389,7 @@ def test_get_failure_connection_failed(mock_sleep): request.assert_called_with( method="GET", - url=_metadata._METADATA_ROOT + PATH, + url="http://metadata.google.internal/computeMetadata/v1/" + PATH, headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -385,7 +408,7 @@ def test_get_too_many_requests_retryable_error_failure(): request.assert_called_with( method="GET", - url=_metadata._METADATA_ROOT + PATH, + url="http://metadata.google.internal/computeMetadata/v1/" + PATH, headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -402,7 +425,7 @@ def test_get_failure_bad_json(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH, + url="http://metadata.google.internal/computeMetadata/v1/" + PATH, headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -416,7 +439,7 @@ def test_get_project_id(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + "project/project-id", + url="http://metadata.google.internal/computeMetadata/v1/project/project-id", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -432,7 +455,7 @@ def test_get_universe_domain_success(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + "universe/universe-domain", + url="http://metadata.google.internal/computeMetadata/v1/universe/universe-domain", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -446,7 +469,7 @@ def test_get_universe_domain_success_empty_response(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + "universe/universe-domain", + url="http://metadata.google.internal/computeMetadata/v1/universe/universe-domain", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -462,7 +485,7 @@ def test_get_universe_domain_not_found(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + "universe/universe-domain", + url="http://metadata.google.internal/computeMetadata/v1/universe/universe-domain", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -483,7 +506,7 @@ def test_get_universe_domain_retryable_error_failure(): request.assert_called_with( method="GET", - url=_metadata._METADATA_ROOT + "universe/universe-domain", + url="http://metadata.google.internal/computeMetadata/v1/universe/universe-domain", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -526,13 +549,13 @@ def request(self, *args, **kwargs): request_error.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + "universe/universe-domain", + url="http://metadata.google.internal/computeMetadata/v1/universe/universe-domain", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) request_ok.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + "universe/universe-domain", + url="http://metadata.google.internal/computeMetadata/v1/universe/universe-domain", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -552,7 +575,7 @@ def test_get_universe_domain_other_error(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + "universe/universe-domain", + url="http://metadata.google.internal/computeMetadata/v1/universe/universe-domain", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -574,7 +597,7 @@ def test_get_service_account_token(utcnow, mock_metrics_header_value): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH + "/token", + url="http://metadata.google.internal/computeMetadata/v1/" + PATH + "/token", headers={ "metadata-flavor": "Google", "x-goog-api-client": ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE, @@ -601,7 +624,10 @@ def test_get_service_account_token_with_scopes_list(utcnow, mock_metrics_header_ request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH + "/token" + "?scopes=foo%2Cbar", + url="http://metadata.google.internal/computeMetadata/v1/" + + PATH + + "/token" + + "?scopes=foo%2Cbar", headers={ "metadata-flavor": "Google", "x-goog-api-client": ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE, @@ -630,7 +656,10 @@ def test_get_service_account_token_with_scopes_string( request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH + "/token" + "?scopes=foo%2Cbar", + url="http://metadata.google.internal/computeMetadata/v1/" + + PATH + + "/token" + + "?scopes=foo%2Cbar", headers={ "metadata-flavor": "Google", "x-goog-api-client": ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE, @@ -651,9 +680,144 @@ def test_get_service_account_info(): request.assert_called_once_with( method="GET", - url=_metadata._METADATA_ROOT + PATH + "/?recursive=true", + url="http://metadata.google.internal/computeMetadata/v1/" + + PATH + + "/?recursive=true", headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) assert info[key] == value + + +def test__get_metadata_root_mtls(): + assert ( + _metadata._get_metadata_root(use_mtls=True) + == "https://metadata.google.internal/computeMetadata/v1/" + ) + + +def test__get_metadata_root_no_mtls(): + assert ( + _metadata._get_metadata_root(use_mtls=False) + == "http://metadata.google.internal/computeMetadata/v1/" + ) + + +def test__get_metadata_ip_root_mtls(): + assert _metadata._get_metadata_ip_root(use_mtls=True) == "https://169.254.169.254" + + +def test__get_metadata_ip_root_no_mtls(): + assert _metadata._get_metadata_ip_root(use_mtls=False) == "http://169.254.169.254" + + +@mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter") +def test__prepare_request_for_mds_mtls(mock_mds_mtls_adapter): + request = google_auth_requests.Request(mock.create_autospec(requests.Session)) + _metadata._prepare_request_for_mds(request, use_mtls=True) + mock_mds_mtls_adapter.assert_called_once() + assert request.session.mount.call_count == len(_metadata._GCE_DEFAULT_MDS_HOSTS) + + +def test__prepare_request_for_mds_no_mtls(): + request = mock.Mock() + _metadata._prepare_request_for_mds(request, use_mtls=False) + request.session.mount.assert_not_called() + + +@mock.patch("google.auth.metrics.mds_ping", return_value=MDS_PING_METRICS_HEADER_VALUE) +@mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter") +@mock.patch("google.auth.compute_engine._mtls.should_use_mds_mtls", return_value=True) +@mock.patch("google.auth.transport.requests.Request") +def test_ping_mtls( + mock_request, mock_should_use_mtls, mock_mds_mtls_adapter, mock_metrics_header_value +): + response = mock.create_autospec(transport.Response, instance=True) + response.status = http_client.OK + response.headers = _metadata._METADATA_HEADERS + mock_request.return_value = response + + assert _metadata.ping(mock_request) + + mock_should_use_mtls.assert_called_once() + mock_mds_mtls_adapter.assert_called_once() + mock_request.assert_called_once_with( + url="https://169.254.169.254", + method="GET", + headers=MDS_PING_REQUEST_HEADER, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, + ) + + +@mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter") +@mock.patch("google.auth.compute_engine._mtls.should_use_mds_mtls", return_value=True) +@mock.patch("google.auth.transport.requests.Request") +def test_get_mtls(mock_request, mock_should_use_mtls, mock_mds_mtls_adapter): + response = mock.create_autospec(transport.Response, instance=True) + response.status = http_client.OK + response.data = _helpers.to_bytes("{}") + response.headers = {"content-type": "application/json"} + mock_request.return_value = response + + _metadata.get(mock_request, "some/path") + + mock_should_use_mtls.assert_called_once() + mock_mds_mtls_adapter.assert_called_once() + mock_request.assert_called_once_with( + url="https://metadata.google.internal/computeMetadata/v1/some/path", + method="GET", + headers=_metadata._METADATA_HEADERS, + timeout=_metadata._METADATA_DEFAULT_TIMEOUT, + ) + + +@pytest.mark.parametrize( + "mds_mode, metadata_host, expect_exception", + [ + (_metadata._mtls.MdsMtlsMode.STRICT, _metadata._GCE_DEFAULT_HOST, False), + (_metadata._mtls.MdsMtlsMode.STRICT, _metadata._GCE_DEFAULT_MDS_IP, False), + (_metadata._mtls.MdsMtlsMode.STRICT, "custom.host", True), + (_metadata._mtls.MdsMtlsMode.NONE, "custom.host", False), + (_metadata._mtls.MdsMtlsMode.DEFAULT, _metadata._GCE_DEFAULT_HOST, False), + (_metadata._mtls.MdsMtlsMode.DEFAULT, _metadata._GCE_DEFAULT_MDS_IP, False), + ], +) +@mock.patch("google.auth.compute_engine._mtls._parse_mds_mode") +def test_validate_gce_mds_configured_environment( + mock_parse_mds_mode, mds_mode, metadata_host, expect_exception +): + mock_parse_mds_mode.return_value = mds_mode + with mock.patch( + "google.auth.compute_engine._metadata._GCE_METADATA_HOST", new=metadata_host + ): + if expect_exception: + with pytest.raises(exceptions.MutualTLSChannelError): + _metadata._validate_gce_mds_configured_environment() + else: + _metadata._validate_gce_mds_configured_environment() + mock_parse_mds_mode.assert_called_once() + + +@mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter") +def test__prepare_request_for_mds_mtls_session_exists(mock_mds_mtls_adapter): + mock_session = mock.create_autospec(requests.Session) + request = google_auth_requests.Request(mock_session) + _metadata._prepare_request_for_mds(request, use_mtls=True) + + mock_mds_mtls_adapter.assert_called_once() + assert mock_session.mount.call_count == len(_metadata._GCE_DEFAULT_MDS_HOSTS) + + +@mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter") +def test__prepare_request_for_mds_mtls_no_session(mock_mds_mtls_adapter): + request = google_auth_requests.Request(None) + # Explicitly set session to None to avoid a session being created in the Request constructor. + request.session = None + + with mock.patch("requests.Session") as mock_session_class: + _metadata._prepare_request_for_mds(request, use_mtls=True) + + mock_session_class.assert_called_once() + mock_mds_mtls_adapter.assert_called_once() + assert request.session.mount.call_count == len(_metadata._GCE_DEFAULT_MDS_HOSTS) diff --git a/tests/compute_engine/test__mtls.py b/tests/compute_engine/test__mtls.py new file mode 100644 index 000000000..fdd61a07d --- /dev/null +++ b/tests/compute_engine/test__mtls.py @@ -0,0 +1,288 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pathlib import Path + +import mock +import pytest # type: ignore +import requests + +from google.auth import environment_vars, exceptions +from google.auth.compute_engine import _mtls + + +@pytest.fixture +def mock_mds_mtls_config(): + return _mtls.MdsMtlsConfig( + ca_cert_path=Path("/fake/ca.crt"), + client_combined_cert_path=Path("/fake/client.key"), + ) + + +@mock.patch("os.name", "nt") +def test__MdsMtlsConfig_windows_defaults(): + config = _mtls.MdsMtlsConfig() + assert ( + str(config.ca_cert_path) + == "C:/ProgramData/Google/ComputeEngine/mds-mtls-root.crt" + ) + assert ( + str(config.client_combined_cert_path) + == "C:/ProgramData/Google/ComputeEngine/mds-mtls-client.key" + ) + + +@mock.patch("os.name", "posix") +def test__MdsMtlsConfig_non_windows_defaults(): + config = _mtls.MdsMtlsConfig() + assert str(config.ca_cert_path) == "/run/google-mds-mtls/root.crt" + assert str(config.client_combined_cert_path) == "/run/google-mds-mtls/client.key" + + +def test__parse_mds_mode_default(monkeypatch): + monkeypatch.delenv(environment_vars.GCE_METADATA_MTLS_MODE, raising=False) + assert _mtls._parse_mds_mode() == _mtls.MdsMtlsMode.DEFAULT + + +@pytest.mark.parametrize( + "mode_str, expected_mode", + [ + ("strict", _mtls.MdsMtlsMode.STRICT), + ("none", _mtls.MdsMtlsMode.NONE), + ("default", _mtls.MdsMtlsMode.DEFAULT), + ("STRICT", _mtls.MdsMtlsMode.STRICT), + ], +) +def test__parse_mds_mode_valid(monkeypatch, mode_str, expected_mode): + monkeypatch.setenv(environment_vars.GCE_METADATA_MTLS_MODE, mode_str) + assert _mtls._parse_mds_mode() == expected_mode + + +def test__parse_mds_mode_invalid(monkeypatch): + monkeypatch.setenv(environment_vars.GCE_METADATA_MTLS_MODE, "invalid_mode") + with pytest.raises(ValueError): + _mtls._parse_mds_mode() + + +@mock.patch("os.path.exists") +def test__certs_exist_true(mock_exists, mock_mds_mtls_config): + mock_exists.return_value = True + assert _mtls._certs_exist(mock_mds_mtls_config) is True + + +@mock.patch("os.path.exists") +def test__certs_exist_false(mock_exists, mock_mds_mtls_config): + mock_exists.return_value = False + assert _mtls._certs_exist(mock_mds_mtls_config) is False + + +@pytest.mark.parametrize( + "mtls_mode, certs_exist, expected_result", + [ + ("strict", True, True), + ("strict", False, exceptions.MutualTLSChannelError), + ("none", True, False), + ("none", False, False), + ("default", True, True), + ("default", False, False), + ], +) +@mock.patch("os.path.exists") +def test_should_use_mds_mtls( + mock_exists, monkeypatch, mtls_mode, certs_exist, expected_result +): + monkeypatch.setenv(environment_vars.GCE_METADATA_MTLS_MODE, mtls_mode) + mock_exists.return_value = certs_exist + + if isinstance(expected_result, type) and issubclass(expected_result, Exception): + with pytest.raises(expected_result): + _mtls.should_use_mds_mtls() + else: + assert _mtls.should_use_mds_mtls() is expected_result + + +@mock.patch("ssl.create_default_context") +def test_mds_mtls_adapter_init(mock_ssl_context, mock_mds_mtls_config): + adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) + mock_ssl_context.assert_called_once() + adapter.ssl_context.load_verify_locations.assert_called_once_with( + cafile=mock_mds_mtls_config.ca_cert_path + ) + adapter.ssl_context.load_cert_chain.assert_called_once_with( + certfile=mock_mds_mtls_config.client_combined_cert_path + ) + + +@mock.patch("ssl.create_default_context") +@mock.patch("requests.adapters.HTTPAdapter.init_poolmanager") +def test_mds_mtls_adapter_init_poolmanager( + mock_init_poolmanager, mock_ssl_context, mock_mds_mtls_config +): + adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) + mock_init_poolmanager.assert_called_with( + 10, 10, block=False, ssl_context=adapter.ssl_context + ) + + +@mock.patch("ssl.create_default_context") +@mock.patch("requests.adapters.HTTPAdapter.proxy_manager_for") +def test_mds_mtls_adapter_proxy_manager_for( + mock_proxy_manager_for, mock_ssl_context, mock_mds_mtls_config +): + adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) + adapter.proxy_manager_for("test_proxy") + mock_proxy_manager_for.assert_called_once_with( + "test_proxy", ssl_context=adapter.ssl_context + ) + + +@mock.patch("requests.adapters.HTTPAdapter.send") # Patch the PARENT class method +@mock.patch("ssl.create_default_context") +def test_mds_mtls_adapter_session_request( + mock_ssl_context, mock_super_send, mock_mds_mtls_config +): + adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) + session = requests.Session() + session.mount("https://", adapter) + + # Setup the parent class send return value + response = requests.Response() + response.status_code = 200 + mock_super_send.return_value = response + + response = session.get("https://fake-mds.com") + + # Assert that the request was successful + assert response.status_code == 200 + mock_super_send.assert_called_once() + + +@mock.patch("requests.adapters.HTTPAdapter.send") +@mock.patch("google.auth.compute_engine._mtls._parse_mds_mode") +@mock.patch("ssl.create_default_context") +def test_mds_mtls_adapter_send_success( + mock_ssl_context, mock_parse_mds_mode, mock_super_send, mock_mds_mtls_config +): + """Test the explicit 'happy path' where mTLS succeeds without error.""" + mock_parse_mds_mode.return_value = _mtls.MdsMtlsMode.DEFAULT + adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) + + # Setup the parent class send return value to be successful (200 OK) + mock_response = requests.Response() + mock_response.status_code = 200 + mock_super_send.return_value = mock_response + + request = requests.Request(method="GET", url="https://fake-mds.com").prepare() + + # Call send directly + response = adapter.send(request) + + # Verify we got the response back and no fallback happened + assert response == mock_response + mock_super_send.assert_called_once() + + +@mock.patch("google.auth.compute_engine._mtls.HTTPAdapter") +@mock.patch("google.auth.compute_engine._mtls._parse_mds_mode") +@mock.patch("ssl.create_default_context") +def test_mds_mtls_adapter_send_fallback_default_mode( + mock_ssl_context, mock_parse_mds_mode, mock_http_adapter_class, mock_mds_mtls_config +): + mock_parse_mds_mode.return_value = _mtls.MdsMtlsMode.DEFAULT + adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) + + mock_fallback_send = mock.Mock() + mock_http_adapter_class.return_value.send = mock_fallback_send + + # Simulate SSLError on the super().send() call + with mock.patch( + "requests.adapters.HTTPAdapter.send", side_effect=requests.exceptions.SSLError + ): + request = requests.Request(method="GET", url="https://fake-mds.com").prepare() + adapter.send(request) + + # Check that fallback to HTTPAdapter.send occurred + mock_http_adapter_class.assert_called_once() + mock_fallback_send.assert_called_once() + fallback_request = mock_fallback_send.call_args[0][0] + assert fallback_request.url == "http://fake-mds.com/" + + +@mock.patch("google.auth.compute_engine._mtls.HTTPAdapter") +@mock.patch("google.auth.compute_engine._mtls._parse_mds_mode") +@mock.patch("ssl.create_default_context") +def test_mds_mtls_adapter_send_fallback_http_error( + mock_ssl_context, mock_parse_mds_mode, mock_http_adapter_class, mock_mds_mtls_config +): + mock_parse_mds_mode.return_value = _mtls.MdsMtlsMode.DEFAULT + adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) + + mock_fallback_send = mock.Mock() + mock_http_adapter_class.return_value.send = mock_fallback_send + + # Simulate HTTPError on the super().send() call + mock_mtls_response = requests.Response() + mock_mtls_response.status_code = 404 + with mock.patch( + "requests.adapters.HTTPAdapter.send", return_value=mock_mtls_response + ): + request = requests.Request(method="GET", url="https://fake-mds.com").prepare() + adapter.send(request) + + # Check that fallback to HTTPAdapter.send occurred + mock_http_adapter_class.assert_called_once() + mock_fallback_send.assert_called_once() + fallback_request = mock_fallback_send.call_args[0][0] + assert fallback_request.url == "http://fake-mds.com/" + + +@mock.patch("requests.adapters.HTTPAdapter.send") +@mock.patch("google.auth.compute_engine._mtls._parse_mds_mode") +@mock.patch("ssl.create_default_context") +def test_mds_mtls_adapter_send_no_fallback_other_exception( + mock_ssl_context, mock_parse_mds_mode, mock_http_adapter_send, mock_mds_mtls_config +): + mock_parse_mds_mode.return_value = _mtls.MdsMtlsMode.DEFAULT + adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) + + # Simulate HTTP exception + with mock.patch( + "requests.adapters.HTTPAdapter.send", + side_effect=requests.exceptions.ConnectionError, + ): + request = requests.Request(method="GET", url="https://fake-mds.com").prepare() + with pytest.raises(requests.exceptions.ConnectionError): + adapter.send(request) + + mock_http_adapter_send.assert_not_called() + + +@mock.patch("google.auth.compute_engine._mtls._parse_mds_mode") +@mock.patch("ssl.create_default_context") +def test_mds_mtls_adapter_send_no_fallback_strict_mode( + mock_ssl_context, mock_parse_mds_mode, mock_mds_mtls_config +): + mock_parse_mds_mode.return_value = _mtls.MdsMtlsMode.STRICT + adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) + + # Simulate SSLError on the super().send() call + with mock.patch( + "requests.adapters.HTTPAdapter.send", side_effect=requests.exceptions.SSLError + ): + request = requests.Request(method="GET", url="https://fake-mds.com").prepare() + with pytest.raises(requests.exceptions.SSLError): + adapter.send(request) diff --git a/tests/crypt/test_es.py b/tests/crypt/test_es.py new file mode 100644 index 000000000..3a62c1413 --- /dev/null +++ b/tests/crypt/test_es.py @@ -0,0 +1,173 @@ +# Copyright 2016 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import json +import os +import pickle + +from cryptography.hazmat.primitives.asymmetric import ec +import pytest # type: ignore + +from google.auth import _helpers +from google.auth.crypt import base +from google.auth.crypt import es + + +DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") + +# To generate es384_privatekey.pem, es384_privatekey.pub, and +# es384_public_cert.pem: +# $ openssl ecparam -genkey -name secp384r1 -noout -out es384_privatekey.pem +# $ openssl ec -in es384_privatekey.pem -pubout -out es384_publickey.pem +# $ openssl req -new -x509 -key es384_privatekey.pem -out \ +# > es384_public_cert.pem + +with open(os.path.join(DATA_DIR, "es384_privatekey.pem"), "rb") as fh: + PRIVATE_KEY_BYTES = fh.read() + PKCS1_KEY_BYTES = PRIVATE_KEY_BYTES + +with open(os.path.join(DATA_DIR, "es384_publickey.pem"), "rb") as fh: + PUBLIC_KEY_BYTES = fh.read() + +with open(os.path.join(DATA_DIR, "es384_public_cert.pem"), "rb") as fh: + PUBLIC_CERT_BYTES = fh.read() + +# RSA keys used to test for type errors in EsVerifier and EsSigner. +with open(os.path.join(DATA_DIR, "privatekey.pem"), "rb") as fh: + RSA_PRIVATE_KEY_BYTES = fh.read() + RSA_PKCS1_KEY_BYTES = RSA_PRIVATE_KEY_BYTES + +with open(os.path.join(DATA_DIR, "privatekey.pub"), "rb") as fh: + RSA_PUBLIC_KEY_BYTES = fh.read() + +SERVICE_ACCOUNT_JSON_FILE = os.path.join(DATA_DIR, "es384_service_account.json") + +with open(SERVICE_ACCOUNT_JSON_FILE, "rb") as fh: + SERVICE_ACCOUNT_INFO = json.load(fh) + + +class TestEsVerifier(object): + def test_verify_success(self): + to_sign = b"foo" + signer = es.EsSigner.from_string(PRIVATE_KEY_BYTES) + actual_signature = signer.sign(to_sign) + + verifier = es.EsVerifier.from_string(PUBLIC_KEY_BYTES) + assert verifier.verify(to_sign, actual_signature) + + def test_verify_unicode_success(self): + to_sign = u"foo" + signer = es.EsSigner.from_string(PRIVATE_KEY_BYTES) + actual_signature = signer.sign(to_sign) + + verifier = es.EsVerifier.from_string(PUBLIC_KEY_BYTES) + assert verifier.verify(to_sign, actual_signature) + + def test_verify_failure(self): + verifier = es.EsVerifier.from_string(PUBLIC_KEY_BYTES) + bad_signature1 = b"" + assert not verifier.verify(b"foo", bad_signature1) + bad_signature2 = b"a" + assert not verifier.verify(b"foo", bad_signature2) + + def test_verify_failure_with_wrong_raw_signature(self): + to_sign = b"foo" + + # This signature has a wrong "r" value in the "(r,s)" raw signature. + wrong_signature = base64.urlsafe_b64decode( + b"m7oaRxUDeYqjZ8qiMwo0PZLTMZWKJLFQREpqce1StMIa_yXQQ-C5WgeIRHW7OqlYSDL0XbUrj_uAw9i-QhfOJQ==" + ) + + verifier = es.EsVerifier.from_string(PUBLIC_KEY_BYTES) + assert not verifier.verify(to_sign, wrong_signature) + + def test_from_string_pub_key(self): + verifier = es.EsVerifier.from_string(PUBLIC_KEY_BYTES) + assert isinstance(verifier, es.EsVerifier) + assert isinstance(verifier._pubkey, ec.EllipticCurvePublicKey) + + def test_from_string_pub_key_unicode(self): + public_key = _helpers.from_bytes(PUBLIC_KEY_BYTES) + verifier = es.EsVerifier.from_string(public_key) + assert isinstance(verifier, es.EsVerifier) + assert isinstance(verifier._pubkey, ec.EllipticCurvePublicKey) + + def test_from_string_pub_cert(self): + verifier = es.EsVerifier.from_string(PUBLIC_CERT_BYTES) + assert isinstance(verifier, es.EsVerifier) + assert isinstance(verifier._pubkey, ec.EllipticCurvePublicKey) + + def test_from_string_pub_cert_unicode(self): + public_cert = _helpers.from_bytes(PUBLIC_CERT_BYTES) + verifier = es.EsVerifier.from_string(public_cert) + assert isinstance(verifier, es.EsVerifier) + assert isinstance(verifier._pubkey, ec.EllipticCurvePublicKey) + + def test_from_string_type_error(self): + with pytest.raises(TypeError): + es.EsVerifier.from_string(RSA_PUBLIC_KEY_BYTES) + + +class TestEsSigner(object): + def test_from_string_pkcs1(self): + signer = es.EsSigner.from_string(PKCS1_KEY_BYTES) + assert isinstance(signer, es.EsSigner) + assert isinstance(signer._key, ec.EllipticCurvePrivateKey) + + def test_from_string_pkcs1_unicode(self): + key_bytes = _helpers.from_bytes(PKCS1_KEY_BYTES) + signer = es.EsSigner.from_string(key_bytes) + assert isinstance(signer, es.EsSigner) + assert isinstance(signer._key, ec.EllipticCurvePrivateKey) + + def test_from_string_bogus_key(self): + key_bytes = "bogus-key" + with pytest.raises(ValueError): + es.EsSigner.from_string(key_bytes) + + def test_from_string_type_error(self): + key_bytes = _helpers.from_bytes(RSA_PKCS1_KEY_BYTES) + with pytest.raises(TypeError): + es.EsSigner.from_string(key_bytes) + + def test_from_service_account_info(self): + signer = es.EsSigner.from_service_account_info(SERVICE_ACCOUNT_INFO) + + assert signer.key_id == SERVICE_ACCOUNT_INFO[base._JSON_FILE_PRIVATE_KEY_ID] + assert isinstance(signer._key, ec.EllipticCurvePrivateKey) + + def test_from_service_account_info_missing_key(self): + with pytest.raises(ValueError) as excinfo: + es.EsSigner.from_service_account_info({}) + + assert excinfo.match(base._JSON_FILE_PRIVATE_KEY) + + def test_from_service_account_file(self): + signer = es.EsSigner.from_service_account_file(SERVICE_ACCOUNT_JSON_FILE) + + assert signer.key_id == SERVICE_ACCOUNT_INFO[base._JSON_FILE_PRIVATE_KEY_ID] + assert isinstance(signer._key, ec.EllipticCurvePrivateKey) + + def test_pickle(self): + signer = es.EsSigner.from_service_account_file(SERVICE_ACCOUNT_JSON_FILE) + + assert signer.key_id == SERVICE_ACCOUNT_INFO[base._JSON_FILE_PRIVATE_KEY_ID] + assert isinstance(signer._key, ec.EllipticCurvePrivateKey) + + pickled_signer = pickle.dumps(signer) + signer = pickle.loads(pickled_signer) + + assert signer.key_id == SERVICE_ACCOUNT_INFO[base._JSON_FILE_PRIVATE_KEY_ID] + assert isinstance(signer._key, ec.EllipticCurvePrivateKey) diff --git a/tests/data/es384_privatekey.pem b/tests/data/es384_privatekey.pem new file mode 100644 index 000000000..12ff96291 --- /dev/null +++ b/tests/data/es384_privatekey.pem @@ -0,0 +1,6 @@ +-----BEGIN EC PRIVATE KEY----- +MIGkAgEBBDBz1wKJNXd2Rzy52A7F3f9LmLp6KaMUTbL1IT3JaDx1kOp4CUFpI9Zs +rdEx7b7kKQGgBwYFK4EEACKhZANiAATRLiEHuOwLr8bjJnJdYG2mrlWtMEPBHOrm +n7RukR80nV5uAcqt+M319T2togP0tQIe621FUpJq7+Hq0vJJbtI1MPuFSDtpZG04 +5se7BVAw63IPV1EdO6vGXxd5Fay88uU= +-----END EC PRIVATE KEY----- diff --git a/tests/data/es384_public_cert.pem b/tests/data/es384_public_cert.pem new file mode 100644 index 000000000..e8d5d4c68 --- /dev/null +++ b/tests/data/es384_public_cert.pem @@ -0,0 +1,15 @@ +-----BEGIN CERTIFICATE----- +MIICYzCCAeqgAwIBAgIUeYyowQBkomEoMj72pNh754QlGvAwCgYIKoZIzj0EAwIw +aTELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAkNBMRYwFAYDVQQHDA1Nb3VudGFpbiBW +aWV3MQ8wDQYDVQQKDAZHb29nbGUxDzANBgNVBAsMBkdvb2dsZTETMBEGA1UEAwwK +Z29vZ2xlLmNvbTAeFw0yNTExMTEwMDQzMTlaFw0yNTEyMTEwMDQzMTlaMGkxCzAJ +BgNVBAYTAlVTMQswCQYDVQQIDAJDQTEWMBQGA1UEBwwNTW91bnRhaW4gVmlldzEP +MA0GA1UECgwGR29vZ2xlMQ8wDQYDVQQLDAZHb29nbGUxEzARBgNVBAMMCmdvb2ds +ZS5jb20wdjAQBgcqhkjOPQIBBgUrgQQAIgNiAATRLiEHuOwLr8bjJnJdYG2mrlWt +MEPBHOrmn7RukR80nV5uAcqt+M319T2togP0tQIe621FUpJq7+Hq0vJJbtI1MPuF +SDtpZG045se7BVAw63IPV1EdO6vGXxd5Fay88uWjUzBRMB0GA1UdDgQWBBSRZkxR +63/X4JotxKDRWCI4PwIElDAfBgNVHSMEGDAWgBSRZkxR63/X4JotxKDRWCI4PwIE +lDAPBgNVHRMBAf8EBTADAQH/MAoGCCqGSM49BAMCA2cAMGQCMAU+2yy/luLTa+T6 +Jm86i9GiH/lPYdYwZFvwKJFTdj8FJpv7ySN0J80qzWxtBZTCMQIwZO0ZRdv8s7V3 +022yISIujmsPmgj7lvPuDZZaVn1DVYMG3YmBB+cTp+JTqF3x7lN+ +-----END CERTIFICATE----- diff --git a/tests/data/es384_publickey.pem b/tests/data/es384_publickey.pem new file mode 100644 index 000000000..e78ac0f49 --- /dev/null +++ b/tests/data/es384_publickey.pem @@ -0,0 +1,5 @@ +-----BEGIN PUBLIC KEY----- +MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAE0S4hB7jsC6/G4yZyXWBtpq5VrTBDwRzq +5p+0bpEfNJ1ebgHKrfjN9fU9raID9LUCHuttRVKSau/h6tLySW7SNTD7hUg7aWRt +OObHuwVQMOtyD1dRHTurxl8XeRWsvPLl +-----END PUBLIC KEY----- diff --git a/tests/data/es384_service_account.json b/tests/data/es384_service_account.json new file mode 100644 index 000000000..8302344b1 --- /dev/null +++ b/tests/data/es384_service_account.json @@ -0,0 +1,9 @@ +{ + "type":"gdch_service_account", + "format_version":"1", + "project":"mytest", + "private_key_id":"1234567890", + "private_key":"-----BEGIN EC PRIVATE KEY-----\nMIGkAgEBBDAyqgUeNwuUOMCC9Bzyf4uT2rfZyISJFMq3ByfE+ytUbveUd6RtvoCT\nS9cYbmuj06OgBwYFK4EEACKhZANiAATrUB670cjyRUcarD//92jO52Rqo+jKi0x7\nkscWALlC8bx9zED5zpy948FrQhQgb/TLPhunkyTwWe22CzafS8ik5pCZKkWfiJRV\n9IBMJDTMyocCR013qDXKHZOpJ57wAUw=\n-----END EC PRIVATE KEY-----\n", + "name":"mytest", + "token_uri":"https://service-accounts.org.google.com/authenticate" +} diff --git a/tests/oauth2/test_sts.py b/tests/oauth2/test_sts.py index e0fb4ae23..e9075e406 100644 --- a/tests/oauth2/test_sts.py +++ b/tests/oauth2/test_sts.py @@ -41,6 +41,9 @@ class TestStsClient(object): ACTOR_TOKEN = "HEADER.ACTOR_TOKEN_PAYLOAD.SIGNATURE" ACTOR_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:jwt" TOKEN_EXCHANGE_ENDPOINT = "https://example.com/token.oauth2" + REVOKE_URL = "https://example.com/revoke.oauth2" + TOKEN_TO_REVOKE = "TOKEN_TO_REVOKE" + TOKEN_TYPE_HINT = "refresh_token" ADDON_HEADERS = {"x-client-version": "0.1.2"} ADDON_OPTIONS = {"additional": {"non-standard": ["options"], "other": "some-value"}} SUCCESS_RESPONSE = { @@ -72,10 +75,13 @@ def make_client(cls, client_auth=None): return sts.Client(cls.TOKEN_EXCHANGE_ENDPOINT, client_auth) @classmethod - def make_mock_request(cls, data, status=http_client.OK): + def make_mock_request(cls, data, status=http_client.OK, use_json=True): response = mock.create_autospec(transport.Response, instance=True) response.status = status - response.data = json.dumps(data).encode("utf-8") + if use_json: + response.data = json.dumps(data).encode("utf-8") + else: + response.data = data.encode("utf-8") request = mock.create_autospec(transport.Request) request.return_value = response @@ -83,10 +89,10 @@ def make_mock_request(cls, data, status=http_client.OK): return request @classmethod - def assert_request_kwargs(cls, request_kwargs, headers, request_data): - """Asserts the request was called with the expected parameters. - """ - assert request_kwargs["url"] == cls.TOKEN_EXCHANGE_ENDPOINT + def assert_request_kwargs(cls, request_kwargs, headers, request_data, url=None): + """Asserts the request was called with the expected parameters.""" + url = url or cls.TOKEN_EXCHANGE_ENDPOINT + assert request_kwargs["url"] == url assert request_kwargs["method"] == "POST" assert request_kwargs["headers"] == headers assert request_kwargs["body"] is not None @@ -447,6 +453,63 @@ def test_refresh_token_failure(self): r"Error code invalid_request: Invalid subject token - https://tools.ietf.org/html/rfc6749" ) + def test_revoke_token_success(self): + """Test revoke token with successful response.""" + client = self.make_client(self.CLIENT_AUTH_BASIC) + request = self.make_mock_request(data="", status=http_client.OK, use_json=False) + + response = client.revoke_token( + request, self.TOKEN_TO_REVOKE, self.TOKEN_TYPE_HINT, self.REVOKE_URL + ) + + headers = { + "Authorization": "Basic {}".format(BASIC_AUTH_ENCODING), + "Content-Type": "application/x-www-form-urlencoded", + } + request_data = { + "token": self.TOKEN_TO_REVOKE, + "token_type_hint": self.TOKEN_TYPE_HINT, + } + self.assert_request_kwargs( + request.call_args[1], headers, request_data, url=self.REVOKE_URL + ) + assert response == {} + + def test_revoke_token_success_no_hint(self): + """Test revoke token with successful response.""" + client = self.make_client(self.CLIENT_AUTH_BASIC) + request = self.make_mock_request(data="", status=http_client.OK, use_json=False) + + response = client.revoke_token( + request, self.TOKEN_TO_REVOKE, None, self.REVOKE_URL + ) + + headers = { + "Authorization": "Basic {}".format(BASIC_AUTH_ENCODING), + "Content-Type": "application/x-www-form-urlencoded", + } + request_data = {"token": self.TOKEN_TO_REVOKE} + self.assert_request_kwargs( + request.call_args[1], headers, request_data, url=self.REVOKE_URL + ) + assert response == {} + + def test_revoke_token_failure(self): + """Test revoke token with failure response.""" + client = self.make_client(self.CLIENT_AUTH_BASIC) + request = self.make_mock_request( + status=http_client.BAD_REQUEST, data=self.ERROR_RESPONSE + ) + + with pytest.raises(exceptions.OAuthError) as excinfo: + client.revoke_token( + request, self.TOKEN_TO_REVOKE, self.TOKEN_TYPE_HINT, self.REVOKE_URL + ) + + assert excinfo.match( + r"Error code invalid_request: Invalid subject token - https://tools.ietf.org/html/rfc6749" + ) + def test__make_request_success(self): """Test base method with successful response.""" client = self.make_client(self.CLIENT_AUTH_BASIC) @@ -478,3 +541,12 @@ def test_make_request_failure(self): assert excinfo.match( r"Error code invalid_request: Invalid subject token - https://tools.ietf.org/html/rfc6749" ) + + def test__make_request_empty_response(self): + """Test _make_request with a successful but empty response body.""" + client = self.make_client() + request = self.make_mock_request(data="", status=http_client.OK, use_json=False) + + response = client._make_request(request, {}, {}) + + assert response == {} diff --git a/tests/test__service_account_info.py b/tests/test__service_account_info.py index be2657074..7e836861e 100644 --- a/tests/test__service_account_info.py +++ b/tests/test__service_account_info.py @@ -23,13 +23,21 @@ DATA_DIR = os.path.join(os.path.dirname(__file__), "data") SERVICE_ACCOUNT_JSON_FILE = os.path.join(DATA_DIR, "service_account.json") -GDCH_SERVICE_ACCOUNT_JSON_FILE = os.path.join(DATA_DIR, "gdch_service_account.json") +GDCH_SERVICE_ACCOUNT_ES256_JSON_FILE = os.path.join( + DATA_DIR, "gdch_service_account.json" +) +GDCH_SERVICE_ACCOUNT_ES384_JSON_FILE = os.path.join( + DATA_DIR, "es384_service_account.json" +) with open(SERVICE_ACCOUNT_JSON_FILE, "r") as fh: SERVICE_ACCOUNT_INFO = json.load(fh) -with open(GDCH_SERVICE_ACCOUNT_JSON_FILE, "r") as fh: - GDCH_SERVICE_ACCOUNT_INFO = json.load(fh) +with open(GDCH_SERVICE_ACCOUNT_ES256_JSON_FILE, "r") as fh: + GDCH_SERVICE_ACCOUNT_ES256_INFO = json.load(fh) + +with open(GDCH_SERVICE_ACCOUNT_ES384_JSON_FILE, "r") as fh: + GDCH_SERVICE_ACCOUNT_ES384_INFO = json.load(fh) def test_from_dict(): @@ -40,10 +48,19 @@ def test_from_dict(): def test_from_dict_es256_signer(): signer = _service_account_info.from_dict( - GDCH_SERVICE_ACCOUNT_INFO, use_rsa_signer=False + GDCH_SERVICE_ACCOUNT_ES256_INFO, use_rsa_signer=False + ) + assert isinstance(signer, crypt.EsSigner) + assert signer.key_id == GDCH_SERVICE_ACCOUNT_ES256_INFO["private_key_id"] + + +def test_from_dict_es384_signer(): + signer = _service_account_info.from_dict( + GDCH_SERVICE_ACCOUNT_ES384_INFO, use_rsa_signer=False ) - assert isinstance(signer, crypt.ES256Signer) - assert signer.key_id == GDCH_SERVICE_ACCOUNT_INFO["private_key_id"] + assert isinstance(signer, crypt.EsSigner) + assert signer.key_id == GDCH_SERVICE_ACCOUNT_ES384_INFO["private_key_id"] + assert signer.algorithm == "ES384" def test_from_dict_bad_private_key(): @@ -75,8 +92,18 @@ def test_from_filename(): def test_from_filename_es256_signer(): _, signer = _service_account_info.from_filename( - GDCH_SERVICE_ACCOUNT_JSON_FILE, use_rsa_signer=False + GDCH_SERVICE_ACCOUNT_ES256_JSON_FILE, use_rsa_signer=False + ) + + assert isinstance(signer, crypt.EsSigner) + assert signer.key_id == GDCH_SERVICE_ACCOUNT_ES256_INFO["private_key_id"] + + +def test_from_filename_es384_signer(): + _, signer = _service_account_info.from_filename( + GDCH_SERVICE_ACCOUNT_ES384_JSON_FILE, use_rsa_signer=False ) - assert isinstance(signer, crypt.ES256Signer) - assert signer.key_id == GDCH_SERVICE_ACCOUNT_INFO["private_key_id"] + assert isinstance(signer, crypt.EsSigner) + assert signer.key_id == GDCH_SERVICE_ACCOUNT_ES384_INFO["private_key_id"] + assert signer.algorithm == "ES384" diff --git a/tests/test_external_account_authorized_user.py b/tests/test_external_account_authorized_user.py index a4e121781..0a54af56d 100644 --- a/tests/test_external_account_authorized_user.py +++ b/tests/test_external_account_authorized_user.py @@ -349,6 +349,50 @@ def test_refresh_without_client_secret(self): request.assert_not_called() + def test_revoke_auth_success(self): + request = self.make_mock_request(status=http_client.OK, data={}) + creds = self.make_credentials(revoke_url=REVOKE_URL) + + creds.revoke(request) + + request.assert_called_once_with( + url=REVOKE_URL, + method="POST", + headers={ + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": "Basic " + BASIC_AUTH_ENCODING, + }, + body=("token=" + REFRESH_TOKEN + "&token_type_hint=refresh_token").encode( + "utf-8" + ), + ) + assert creds.token is None + assert creds._refresh_token is None + + def test_revoke_without_revoke_url(self): + request = self.make_mock_request() + creds = self.make_credentials(token=ACCESS_TOKEN) + + with pytest.raises(exceptions.OAuthError) as excinfo: + creds.revoke(request) + + assert excinfo.match( + r"The credentials do not contain the necessary fields to revoke the refresh token. You must specify revoke_url and refresh_token." + ) + + def test_revoke_without_refresh_token(self): + request = self.make_mock_request() + creds = self.make_credentials( + refresh_token=None, token=ACCESS_TOKEN, revoke_url=REVOKE_URL + ) + + with pytest.raises(exceptions.OAuthError) as excinfo: + creds.revoke(request) + + assert excinfo.match( + r"The credentials do not contain the necessary fields to revoke the refresh token. You must specify revoke_url and refresh_token." + ) + def test_info(self): creds = self.make_credentials() info = creds.info diff --git a/tests/test_jwt.py b/tests/test_jwt.py index 28660ea33..a5a904d7d 100644 --- a/tests/test_jwt.py +++ b/tests/test_jwt.py @@ -43,6 +43,12 @@ with open(os.path.join(DATA_DIR, "es256_public_cert.pem"), "rb") as fh: EC_PUBLIC_CERT_BYTES = fh.read() +with open(os.path.join(DATA_DIR, "es384_privatekey.pem"), "rb") as fh: + EC384_PRIVATE_KEY_BYTES = fh.read() + +with open(os.path.join(DATA_DIR, "es384_public_cert.pem"), "rb") as fh: + EC384_PUBLIC_CERT_BYTES = fh.read() + SERVICE_ACCOUNT_JSON_FILE = os.path.join(DATA_DIR, "service_account.json") with open(SERVICE_ACCOUNT_JSON_FILE, "rb") as fh: @@ -84,6 +90,11 @@ def es256_signer(): return crypt.ES256Signer.from_string(EC_PRIVATE_KEY_BYTES, "1") +@pytest.fixture +def es384_signer(): + return crypt.EsSigner.from_string(EC384_PRIVATE_KEY_BYTES, "1") + + def test_encode_basic_es256(es256_signer): test_payload = {"test": "value"} encoded = jwt.encode(es256_signer, test_payload) @@ -92,9 +103,19 @@ def test_encode_basic_es256(es256_signer): assert header == {"typ": "JWT", "alg": "ES256", "kid": es256_signer.key_id} +def test_encode_basic_es384(es384_signer): + test_payload = {"test": "value"} + encoded = jwt.encode(es384_signer, test_payload) + header, payload, _, _ = jwt._unverified_decode(encoded) + assert payload == test_payload + assert header == {"typ": "JWT", "alg": "ES384", "kid": es384_signer.key_id} + + @pytest.fixture -def token_factory(signer, es256_signer): - def factory(claims=None, key_id=None, use_es256_signer=False): +def token_factory(signer, es256_signer, es384_signer): + def factory( + claims=None, key_id=None, use_es256_signer=False, use_es384_signer=False + ): now = _helpers.datetime_to_secs(_helpers.utcnow()) payload = { "aud": "audience@example.com", @@ -113,6 +134,8 @@ def factory(claims=None, key_id=None, use_es256_signer=False): if use_es256_signer: return jwt.encode(es256_signer, payload, key_id=key_id) + elif use_es384_signer: + return jwt.encode(es384_signer, payload, key_id=key_id) else: return jwt.encode(signer, payload, key_id=key_id) @@ -158,6 +181,15 @@ def test_decode_valid_es256(token_factory): assert payload["metadata"]["meta"] == "data" +def test_decode_valid_es384(token_factory): + payload = jwt.decode( + token_factory(use_es384_signer=True), certs=EC384_PUBLIC_CERT_BYTES + ) + assert payload["aud"] == "audience@example.com" + assert payload["user"] == "billy bob" + assert payload["metadata"]["meta"] == "data" + + def test_decode_valid_with_audience(token_factory): payload = jwt.decode( token_factory(), certs=PUBLIC_CERT_BYTES, audience="audience@example.com" diff --git a/tests/test_pluggable.py b/tests/test_pluggable.py index 066920b22..d15ebb88b 100644 --- a/tests/test_pluggable.py +++ b/tests/test_pluggable.py @@ -1239,6 +1239,36 @@ def test_retrieve_subject_token_python_2(self): assert excinfo.match(r"Pluggable auth is only supported for python 3.7+") + @mock.patch.dict(os.environ, {"GOOGLE_EXTERNAL_ACCOUNT_ALLOW_EXECUTABLES": "1"}) + def test_retrieve_subject_token_with_quoted_command(self): + command_with_spaces = '"/path/with spaces/to/executable" "arg with spaces"' + credential_source = { + "executable": {"command": command_with_spaces, "timeout_millis": 30000} + } + + with mock.patch( + "subprocess.run", + return_value=subprocess.CompletedProcess( + args=[], + stdout=json.dumps( + self.EXECUTABLE_SUCCESSFUL_OIDC_RESPONSE_ID_TOKEN + ).encode("UTF-8"), + returncode=0, + ), + ) as mock_run: + credentials = self.make_pluggable(credential_source=credential_source) + subject_token = credentials.retrieve_subject_token(None) + + assert subject_token == self.EXECUTABLE_OIDC_TOKEN + mock_run.assert_called_once_with( + ["/path/with spaces/to/executable", "arg with spaces"], + timeout=30.0, + stdin=None, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=mock.ANY, + ) + @mock.patch.dict(os.environ, {"GOOGLE_EXTERNAL_ACCOUNT_ALLOW_EXECUTABLES": "1"}) def test_revoke_subject_token_python_2(self): with mock.patch("sys.version_info", (2, 7)): diff --git a/tests/transport/aio/test_sessions.py b/tests/transport/aio/test_sessions.py index c91a7c40a..742f863d0 100644 --- a/tests/transport/aio/test_sessions.py +++ b/tests/transport/aio/test_sessions.py @@ -32,8 +32,13 @@ @pytest.fixture -async def simple_async_task(): - return True +def simple_async_task(): + # Wrap async fixture within a synchronous fixture to suppress pytest.PytestRemovedIn9Warning + # See https://docs.pytest.org/en/stable/deprecations.html#sync-test-depending-on-async-fixture + async def inner_fixture(): + return True + + return inner_fixture() class MockRequest(Request): @@ -151,10 +156,15 @@ class TestAsyncAuthorizedSession(object): credentials = AnonymousCredentials() @pytest.fixture - async def mocked_content(self): - content = [b"Cavefish ", b"have ", b"no ", b"sight."] - for chunk in content: - yield chunk + def mocked_content(self): + # Wrap async fixture within a synchronous fixture to suppress pytest.PytestRemovedIn9Warning + # See https://docs.pytest.org/en/stable/deprecations.html#sync-test-depending-on-async-fixture + async def inner_fixture(): + content = [b"Cavefish ", b"have ", b"no ", b"sight."] + for chunk in content: + yield chunk + + return inner_fixture() @pytest.mark.asyncio async def test_constructor_with_default_auth_request(self): diff --git a/tests/transport/test__mtls_helper.py b/tests/transport/test__mtls_helper.py index 01d5e3a40..2a7a524b1 100644 --- a/tests/transport/test__mtls_helper.py +++ b/tests/transport/test__mtls_helper.py @@ -334,9 +334,102 @@ def test_success_with_certificate_config( assert key == pytest.private_key_bytes assert passphrase is None + @mock.patch( + "google.auth.transport._mtls_helper._read_cert_and_key_files", autospec=True + ) + @mock.patch( + "google.auth.transport._mtls_helper._get_cert_config_path", autospec=True + ) + @mock.patch("google.auth.transport._mtls_helper._load_json_file", autospec=True) @mock.patch("google.auth.transport._mtls_helper._check_config_path", autospec=True) - def test_success_without_metadata(self, mock_check_config_path): + def test_success_with_certificate_config_cloud_run_patch( + self, + mock_check_config_path, + mock_load_json_file, + mock_get_cert_config_path, + mock_read_cert_and_key_files, + ): + cert_config_path = "/path/to/config" + mock_check_config_path.return_value = cert_config_path + mock_load_json_file.return_value = { + "cert_configs": { + "workload": { + "cert_path": _mtls_helper._INCORRECT_CLOUD_RUN_CERT_PATH, + "key_path": _mtls_helper._INCORRECT_CLOUD_RUN_KEY_PATH, + } + } + } + mock_get_cert_config_path.return_value = cert_config_path + mock_read_cert_and_key_files.return_value = ( + pytest.public_cert_bytes, + pytest.private_key_bytes, + ) + + has_cert, cert, key, passphrase = _mtls_helper.get_client_ssl_credentials() + assert has_cert + assert cert == pytest.public_cert_bytes + assert key == pytest.private_key_bytes + assert passphrase is None + + mock_read_cert_and_key_files.assert_called_once_with( + _mtls_helper._WELL_KNOWN_CLOUD_RUN_CERT_PATH, + _mtls_helper._WELL_KNOWN_CLOUD_RUN_KEY_PATH, + ) + + @mock.patch("os.path.exists", autospec=True) + @mock.patch( + "google.auth.transport._mtls_helper._read_cert_and_key_files", autospec=True + ) + @mock.patch( + "google.auth.transport._mtls_helper._get_cert_config_path", autospec=True + ) + @mock.patch("google.auth.transport._mtls_helper._load_json_file", autospec=True) + @mock.patch("google.auth.transport._mtls_helper._check_config_path", autospec=True) + def test_success_with_certificate_config_cloud_run_patch_skipped_if_cert_exists( + self, + mock_check_config_path, + mock_load_json_file, + mock_get_cert_config_path, + mock_read_cert_and_key_files, + mock_os_path_exists, + ): + cert_config_path = "/path/to/config" + mock_check_config_path.return_value = cert_config_path + mock_os_path_exists.return_value = True + mock_load_json_file.return_value = { + "cert_configs": { + "workload": { + "cert_path": _mtls_helper._INCORRECT_CLOUD_RUN_CERT_PATH, + "key_path": _mtls_helper._INCORRECT_CLOUD_RUN_KEY_PATH, + } + } + } + mock_get_cert_config_path.return_value = cert_config_path + mock_read_cert_and_key_files.return_value = ( + pytest.public_cert_bytes, + pytest.private_key_bytes, + ) + + has_cert, cert, key, passphrase = _mtls_helper.get_client_ssl_credentials() + assert has_cert + assert cert == pytest.public_cert_bytes + assert key == pytest.private_key_bytes + assert passphrase is None + + mock_read_cert_and_key_files.assert_called_once_with( + _mtls_helper._INCORRECT_CLOUD_RUN_CERT_PATH, + _mtls_helper._INCORRECT_CLOUD_RUN_KEY_PATH, + ) + + @mock.patch( + "google.auth.transport._mtls_helper._get_workload_cert_and_key", autospec=True + ) + @mock.patch("google.auth.transport._mtls_helper._check_config_path", autospec=True) + def test_success_without_metadata( + self, mock_check_config_path, mock_get_workload_cert_and_key + ): mock_check_config_path.return_value = False + mock_get_workload_cert_and_key.return_value = (None, None) has_cert, cert, key, passphrase = _mtls_helper.get_client_ssl_credentials() assert not has_cert assert cert is None @@ -395,12 +488,17 @@ def test_missing_cert_command( ) @mock.patch("google.auth.transport._mtls_helper._load_json_file", autospec=True) @mock.patch("google.auth.transport._mtls_helper._check_config_path", autospec=True) + @mock.patch( + "google.auth.transport._mtls_helper._get_workload_cert_and_key", autospec=True + ) def test_customize_context_aware_metadata_path( self, + mock_get_workload_cert_and_key, mock_check_config_path, mock_load_json_file, mock_run_cert_provider_command, ): + mock_get_workload_cert_and_key.return_value = (None, None) context_aware_metadata_path = "/path/to/metata/data" mock_check_config_path.return_value = context_aware_metadata_path mock_load_json_file.return_value = {"cert_provider_command": ["command"]} diff --git a/tests_async/transport/test_aiohttp_requests.py b/tests_async/transport/test_aiohttp_requests.py index d00955a7d..e910779a6 100644 --- a/tests_async/transport/test_aiohttp_requests.py +++ b/tests_async/transport/test_aiohttp_requests.py @@ -115,10 +115,11 @@ def make_with_parameter_request(self): http = aiohttp.ClientSession(auto_decompress=False) return aiohttp_requests.Request(http) - def test_unsupported_session(self): + @pytest.mark.asyncio + async def test_unsupported_session(self): http = aiohttp.ClientSession(auto_decompress=True) with pytest.raises(ValueError): - aiohttp_requests.Request(http) + await aiohttp_requests.Request(http) def test_timeout(self): http = mock.create_autospec( @@ -144,11 +145,13 @@ class TestAuthorizedSession(object): TEST_URL = "http://example.com/" method = "GET" - def test_constructor(self): + @pytest.mark.asyncio + async def test_constructor(self): authed_session = aiohttp_requests.AuthorizedSession(mock.sentinel.credentials) assert authed_session.credentials == mock.sentinel.credentials - def test_constructor_with_auth_request(self): + @pytest.mark.asyncio + async def test_constructor_with_auth_request(self): http = mock.create_autospec( aiohttp.ClientSession, instance=True, _auto_decompress=False )