Skip to content
This repository was archived by the owner on Mar 9, 2026. It is now read-only.

Commit d45bfbf

Browse files
fix: fix mtls issue in handwritten layer
1 parent ec8f5f2 commit d45bfbf

4 files changed

Lines changed: 46 additions & 22 deletions

File tree

google/cloud/pubsub_v1/publisher/client.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -130,15 +130,19 @@ def __init__(self, batch_settings=(), publisher_options=(), **kwargs):
130130
target=os.environ.get("PUBSUB_EMULATOR_HOST")
131131
)
132132

133+
# The GAPIC client has mTLS logic to determine the api endpoint and the
134+
# ssl credentials to use. Here we create a GAPIC client to help compute the
135+
# api endpoint and ssl credentials. The api endpoint will be used to set
136+
# `self._target`, and ssl credentials will be passed to
137+
# `grpc_helpers.create_channel` to establish a mTLS channel (if ssl
138+
# credentials is not None).
133139
client_options = kwargs.get("client_options", None)
134-
if (
135-
client_options
136-
and "api_endpoint" in client_options
137-
and isinstance(client_options["api_endpoint"], six.string_types)
138-
):
139-
self._target = client_options["api_endpoint"]
140-
else:
141-
self._target = publisher_client.PublisherClient.SERVICE_ADDRESS
140+
credentials = kwargs.get("credentials", None)
141+
client_for_mtls_info = publisher_client.PublisherClient(
142+
credentials=credentials, client_options=client_options
143+
)
144+
145+
self._target = client_for_mtls_info._transport._host
142146

143147
# Use a custom channel.
144148
# We need this in order to set appropriate default message size and
@@ -149,6 +153,7 @@ def __init__(self, batch_settings=(), publisher_options=(), **kwargs):
149153
channel = grpc_helpers.create_channel(
150154
credentials=kwargs.pop("credentials", None),
151155
target=self.target,
156+
ssl_credentials=client_for_mtls_info._transport._ssl_channel_credentials,
152157
scopes=publisher_client.PublisherClient._DEFAULT_SCOPES,
153158
options={
154159
"grpc.max_send_message_length": -1,

google/cloud/pubsub_v1/subscriber/client.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import os
1818
import pkg_resources
19-
import six
2019

2120
import grpc
2221

@@ -82,16 +81,19 @@ def __init__(self, **kwargs):
8281
target=os.environ.get("PUBSUB_EMULATOR_HOST")
8382
)
8483

85-
# api_endpoint wont be applied if 'transport' is passed in.
84+
# The GAPIC client has mTLS logic to determine the api endpoint and the
85+
# ssl credentials to use. Here we create a GAPIC client to help compute the
86+
# api endpoint and ssl credentials. The api endpoint will be used to set
87+
# `self._target`, and ssl credentials will be passed to
88+
# `grpc_helpers.create_channel` to establish a mTLS channel (if ssl
89+
# credentials is not None).
8690
client_options = kwargs.get("client_options", None)
87-
if (
88-
client_options
89-
and "api_endpoint" in client_options
90-
and isinstance(client_options["api_endpoint"], six.string_types)
91-
):
92-
self._target = client_options["api_endpoint"]
93-
else:
94-
self._target = subscriber_client.SubscriberClient.SERVICE_ADDRESS
91+
credentials = kwargs.get("credentials", None)
92+
client_for_mtls_info = subscriber_client.SubscriberClient(
93+
credentials=credentials, client_options=client_options
94+
)
95+
96+
self._target = client_for_mtls_info._transport._host
9597

9698
# Use a custom channel.
9799
# We need this in order to set appropriate default message size and
@@ -102,6 +104,7 @@ def __init__(self, **kwargs):
102104
channel = grpc_helpers.create_channel(
103105
credentials=kwargs.pop("credentials", None),
104106
target=self.target,
107+
ssl_credentials=client_for_mtls_info._transport._ssl_channel_credentials,
105108
scopes=subscriber_client.SubscriberClient._DEFAULT_SCOPES,
106109
options={
107110
"grpc.max_send_message_length": -1,

tests/unit/pubsub_v1/publisher/test_publisher_client.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import inspect
1919

2020
from google.auth import credentials
21+
import grpc
2122

2223
import mock
2324
import pytest
@@ -81,7 +82,7 @@ def test_init_w_api_endpoint():
8182
assert isinstance(client.api, publisher_client.PublisherClient)
8283
assert (client.api._transport.grpc_channel._channel.target()).decode(
8384
"utf-8"
84-
) == "testendpoint.google.com"
85+
) == "testendpoint.google.com:443"
8586

8687

8788
def test_init_w_unicode_api_endpoint():
@@ -91,7 +92,7 @@ def test_init_w_unicode_api_endpoint():
9192
assert isinstance(client.api, publisher_client.PublisherClient)
9293
assert (client.api._transport.grpc_channel._channel.target()).decode(
9394
"utf-8"
94-
) == "testendpoint.google.com"
95+
) == "testendpoint.google.com:443"
9596

9697

9798
def test_init_w_empty_client_options():
@@ -104,8 +105,13 @@ def test_init_w_empty_client_options():
104105

105106

106107
def test_init_client_options_pass_through():
108+
mock_ssl_creds = grpc.ssl_channel_credentials()
109+
107110
def init(self, *args, **kwargs):
108111
self.kwargs = kwargs
112+
self._transport = mock.Mock()
113+
self._transport._host = "testendpoint.google.com"
114+
self._transport._ssl_channel_credentials = mock_ssl_creds
109115

110116
with mock.patch.object(publisher_client.PublisherClient, "__init__", init):
111117
client = publisher.Client(
@@ -119,6 +125,8 @@ def init(self, *args, **kwargs):
119125
assert client_options.get("quota_project_id") == "42"
120126
assert client_options.get("scopes") == []
121127
assert client_options.get("credentials_file") == "file.json"
128+
assert client.target == "testendpoint.google.com"
129+
assert client.api.transport._ssl_channel_credentials == mock_ssl_creds
122130

123131

124132
def test_init_emulator(monkeypatch):

tests/unit/pubsub_v1/subscriber/test_subscriber_client.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
from google.auth import credentials
16+
import grpc
1617
import mock
1718

1819
from google.cloud.pubsub_v1 import subscriber
@@ -42,7 +43,7 @@ def test_init_w_api_endpoint():
4243
assert isinstance(client.api, subscriber_client.SubscriberClient)
4344
assert (client.api._transport.grpc_channel._channel.target()).decode(
4445
"utf-8"
45-
) == "testendpoint.google.com"
46+
) == "testendpoint.google.com:443"
4647

4748

4849
def test_init_w_unicode_api_endpoint():
@@ -52,7 +53,7 @@ def test_init_w_unicode_api_endpoint():
5253
assert isinstance(client.api, subscriber_client.SubscriberClient)
5354
assert (client.api._transport.grpc_channel._channel.target()).decode(
5455
"utf-8"
55-
) == "testendpoint.google.com"
56+
) == "testendpoint.google.com:443"
5657

5758

5859
def test_init_w_empty_client_options():
@@ -65,8 +66,13 @@ def test_init_w_empty_client_options():
6566

6667

6768
def test_init_client_options_pass_through():
69+
mock_ssl_creds = grpc.ssl_channel_credentials()
70+
6871
def init(self, *args, **kwargs):
6972
self.kwargs = kwargs
73+
self._transport = mock.Mock()
74+
self._transport._host = "testendpoint.google.com"
75+
self._transport._ssl_channel_credentials = mock_ssl_creds
7076

7177
with mock.patch.object(subscriber_client.SubscriberClient, "__init__", init):
7278
client = subscriber.Client(
@@ -80,6 +86,8 @@ def init(self, *args, **kwargs):
8086
assert client_options.get("quota_project_id") == "42"
8187
assert client_options.get("scopes") == []
8288
assert client_options.get("credentials_file") == "file.json"
89+
assert client.target == "testendpoint.google.com"
90+
assert client.api.transport._ssl_channel_credentials == mock_ssl_creds
8391

8492

8593
def test_init_emulator(monkeypatch):

0 commit comments

Comments
 (0)