Skip to content

Commit e303c23

Browse files
feat: add helper func to for default encrypted cert (#514)
* feat: helper func to for default encrpted cert
1 parent 0e1ab39 commit e303c23

2 files changed

Lines changed: 70 additions & 0 deletions

File tree

packages/google-auth/google/auth/transport/mtls.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,45 @@ def callback():
5858
return cert_bytes, key_bytes
5959

6060
return callback
61+
62+
63+
def default_client_encrypted_cert_source(cert_path, key_path):
64+
"""Get a callback which returns the default encrpyted client SSL credentials.
65+
66+
Args:
67+
cert_path (str): The cert file path. The default client certificate will
68+
be written to this file when the returned callback is called.
69+
key_path (str): The key file path. The default encrypted client key will
70+
be written to this file when the returned callback is called.
71+
72+
Returns:
73+
Callable[[], [str, str, bytes]]: A callback which generates the default
74+
client certificate, encrpyted private key and passphrase. It writes
75+
the certificate and private key into the cert_path and key_path, and
76+
returns the cert_path, key_path and passphrase bytes.
77+
78+
Raises:
79+
google.auth.exceptions.DefaultClientCertSourceError: If any problem
80+
occurs when loading or saving the client certificate and key.
81+
"""
82+
if not has_default_client_cert_source():
83+
raise exceptions.MutualTLSChannelError(
84+
"Default client encrypted cert source doesn't exist"
85+
)
86+
87+
def callback():
88+
try:
89+
_, cert_bytes, key_bytes, passphrase_bytes = _mtls_helper.get_client_ssl_credentials(
90+
generate_encrypted_key=True
91+
)
92+
with open(cert_path, "wb") as cert_file:
93+
cert_file.write(cert_bytes)
94+
with open(key_path, "wb") as key_file:
95+
key_file.write(key_bytes)
96+
except (exceptions.ClientCertError, OSError) as caught_exc:
97+
new_exc = exceptions.MutualTLSChannelError(caught_exc)
98+
six.raise_from(new_exc, caught_exc)
99+
100+
return cert_path, key_path, passphrase_bytes
101+
102+
return callback

packages/google-auth/tests/transport/test_mtls.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,31 @@ def test_default_client_cert_source(
5353
callback = mtls.default_client_cert_source()
5454
with pytest.raises(exceptions.MutualTLSChannelError):
5555
callback()
56+
57+
58+
@mock.patch(
59+
"google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True
60+
)
61+
@mock.patch("google.auth.transport.mtls.has_default_client_cert_source", autospec=True)
62+
def test_default_client_encrypted_cert_source(
63+
has_default_client_cert_source, get_client_ssl_credentials
64+
):
65+
# Test default client cert source doesn't exist.
66+
has_default_client_cert_source.return_value = False
67+
with pytest.raises(exceptions.MutualTLSChannelError):
68+
mtls.default_client_encrypted_cert_source("cert_path", "key_path")
69+
70+
# The following tests will assume default client cert source exists.
71+
has_default_client_cert_source.return_value = True
72+
73+
# Test good callback.
74+
get_client_ssl_credentials.return_value = (True, b"cert", b"key", b"passphrase")
75+
callback = mtls.default_client_encrypted_cert_source("cert_path", "key_path")
76+
with mock.patch("{}.open".format(__name__), return_value=mock.MagicMock()):
77+
assert callback() == ("cert_path", "key_path", b"passphrase")
78+
79+
# Test bad callback which throws exception.
80+
get_client_ssl_credentials.side_effect = exceptions.ClientCertError()
81+
callback = mtls.default_client_encrypted_cert_source("cert_path", "key_path")
82+
with pytest.raises(exceptions.MutualTLSChannelError):
83+
callback()

0 commit comments

Comments
 (0)