diff --git a/google/cloud/bigtable/client.py b/google/cloud/bigtable/client.py index be536f295..7249c0b35 100644 --- a/google/cloud/bigtable/client.py +++ b/google/cloud/bigtable/client.py @@ -67,6 +67,13 @@ READ_ONLY_SCOPE = "https://www.googleapis.com/auth/bigtable.data.readonly" """Scope for reading table data.""" +_GRPC_CHANNEL_OPTIONS = ( + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ("grpc.keepalive_time_ms", 30000), + ("grpc.keepalive_timeout_ms", 10000), +) + def _create_gapic_client(client_class, client_options=None, transport=None): def inner(self): @@ -195,11 +202,15 @@ def _get_scopes(self): return scopes def _emulator_channel(self, transport, options): - """ - Creates a channel using self._credentials in a similar way to grpc.secure_channel but - using grpc.local_channel_credentials() rather than grpc.ssh_channel_credentials() - to allow easy connection to a local emulator. - :return: grpc.Channel or grpc.aio.Channel + """Create a channel using self._credentials + + Works in a similar way to ``grpc.secure_channel`` but using + ``grpc.local_channel_credentials`` rather than + ``grpc.ssh_channel_credentials`` to allow easy connection to a + local emulator. + + Returns: + grpc.Channel or grpc.aio.Channel """ # TODO: Implement a special credentials type for emulator and use # "transport.create_channel" to create gRPC channels once google-auth @@ -219,8 +230,8 @@ def _emulator_channel(self, transport, options): ) def _local_composite_credentials(self): - """ - Creates the credentials for the local emulator channel + """Create credentials for the local emulator channel. + :return: grpc.ChannelCredentials """ credentials = google.auth.credentials.with_scopes_if_required( @@ -245,27 +256,24 @@ def _local_composite_credentials(self): ) def _create_gapic_client_channel(self, client_class, grpc_transport): - options = { - "grpc.max_send_message_length": -1, - "grpc.max_receive_message_length": -1, - "grpc.keepalive_time_ms": 30000, - "grpc.keepalive_timeout_ms": 10000, - }.items() - if self._client_options and self._client_options.api_endpoint: + if self._emulator_host is not None: + api_endpoint = self._emulator_host + elif self._client_options and self._client_options.api_endpoint: api_endpoint = self._client_options.api_endpoint else: api_endpoint = client_class.DEFAULT_ENDPOINT - channel = None if self._emulator_host is not None: - api_endpoint = self._emulator_host - channel = self._emulator_channel(grpc_transport, options) + channel = self._emulator_channel( + transport=grpc_transport, options=_GRPC_CHANNEL_OPTIONS, + ) else: channel = grpc_transport.create_channel( - host=api_endpoint, credentials=self._credentials, options=options, + host=api_endpoint, + credentials=self._credentials, + options=_GRPC_CHANNEL_OPTIONS, ) - transport = grpc_transport(channel=channel, host=api_endpoint) - return transport + return grpc_transport(channel=channel, host=api_endpoint) @property def project_path(self): diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index f6b8eb5bc..5c557763a 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -170,6 +170,7 @@ def test_constructor_both_admin_and_read_only(self): def test_constructor_with_emulator_host(self): from google.cloud.environment_vars import BIGTABLE_EMULATOR + from google.cloud.bigtable.client import _GRPC_CHANNEL_OPTIONS credentials = _make_credentials() emulator_host = "localhost:8081" @@ -183,13 +184,9 @@ def test_constructor_with_emulator_host(self): client.table_data_client self.assertEqual(client._emulator_host, emulator_host) - options = { - "grpc.max_send_message_length": -1, - "grpc.max_receive_message_length": -1, - "grpc.keepalive_time_ms": 30000, - "grpc.keepalive_timeout_ms": 10000, - }.items() - factory.assert_called_once_with(emulator_host, credentials, options=options) + factory.assert_called_once_with( + emulator_host, credentials, options=_GRPC_CHANNEL_OPTIONS, + ) def test__get_scopes_default(self): from google.cloud.bigtable.client import DATA_SCOPE @@ -215,6 +212,140 @@ def test__get_scopes_read_only(self): ) self.assertEqual(client._get_scopes(), (READ_ONLY_SCOPE,)) + def test__emulator_channel_sync(self): + emulator_host = "localhost:8081" + transport_name = "GrpcTransportTesting" + transport = mock.Mock(spec=["__name__"], __name__=transport_name) + options = mock.Mock(spec=[]) + client = self._make_one( + project=self.PROJECT, credentials=_make_credentials(), read_only=True + ) + client._emulator_host = emulator_host + lcc = client._local_composite_credentials = mock.Mock(spec=[]) + + with mock.patch("grpc.secure_channel") as patched: + channel = client._emulator_channel(transport, options) + + assert channel is patched.return_value + patched.assert_called_once_with( + emulator_host, lcc.return_value, options=options, + ) + + def test__emulator_channel_async(self): + emulator_host = "localhost:8081" + transport_name = "GrpcAsyncIOTransportTesting" + transport = mock.Mock(spec=["__name__"], __name__=transport_name) + options = mock.Mock(spec=[]) + client = self._make_one( + project=self.PROJECT, credentials=_make_credentials(), read_only=True + ) + client._emulator_host = emulator_host + lcc = client._local_composite_credentials = mock.Mock(spec=[]) + + with mock.patch("grpc.aio.secure_channel") as patched: + channel = client._emulator_channel(transport, options) + + assert channel is patched.return_value + patched.assert_called_once_with( + emulator_host, lcc.return_value, options=options, + ) + + def test__local_composite_credentials(self): + client = self._make_one( + project=self.PROJECT, credentials=_make_credentials(), read_only=True + ) + + wsir_patch = mock.patch("google.auth.credentials.with_scopes_if_required") + request_patch = mock.patch("google.auth.transport.requests.Request") + amp_patch = mock.patch("google.auth.transport.grpc.AuthMetadataPlugin") + grpc_patches = mock.patch.multiple( + "grpc", + metadata_call_credentials=mock.DEFAULT, + local_channel_credentials=mock.DEFAULT, + composite_channel_credentials=mock.DEFAULT, + ) + with wsir_patch as wsir_patched: + with request_patch as request_patched: + with amp_patch as amp_patched: + with grpc_patches as grpc_patched: + credentials = client._local_composite_credentials() + + grpc_mcc = grpc_patched["metadata_call_credentials"] + grpc_lcc = grpc_patched["local_channel_credentials"] + grpc_ccc = grpc_patched["composite_channel_credentials"] + + self.assertIs(credentials, grpc_ccc.return_value) + + wsir_patched.assert_called_once_with(client._credentials, None) + request_patched.assert_called_once_with() + amp_patched.assert_called_once_with( + wsir_patched.return_value, request_patched.return_value, + ) + grpc_mcc.assert_called_once_with(amp_patched.return_value) + grpc_lcc.assert_called_once_with() + grpc_ccc.assert_called_once_with(grpc_lcc.return_value, grpc_mcc.return_value) + + def _create_gapic_client_channel_helper( + self, endpoint=None, emulator_host=None, + ): + from google.cloud.bigtable.client import _GRPC_CHANNEL_OPTIONS + + client_class = mock.Mock(spec=["DEFAULT_ENDPOINT"]) + credentials = _make_credentials() + client = self._make_one(project=self.PROJECT, credentials=credentials) + + if endpoint is not None: + client._client_options = mock.Mock( + spec=["api_endpoint"], api_endpoint=endpoint, + ) + expected_host = endpoint + else: + expected_host = client_class.DEFAULT_ENDPOINT + + if emulator_host is not None: + client._emulator_host = emulator_host + client._emulator_channel = mock.Mock(spec=[]) + expected_host = emulator_host + + grpc_transport = mock.Mock(spec=["create_channel"]) + + transport = client._create_gapic_client_channel(client_class, grpc_transport) + + self.assertIs(transport, grpc_transport.return_value) + + if emulator_host is not None: + client._emulator_channel.assert_called_once_with( + transport=grpc_transport, options=_GRPC_CHANNEL_OPTIONS, + ) + grpc_transport.assert_called_once_with( + channel=client._emulator_channel.return_value, host=expected_host, + ) + else: + grpc_transport.create_channel.assert_called_once_with( + host=expected_host, + credentials=client._credentials, + options=_GRPC_CHANNEL_OPTIONS, + ) + grpc_transport.assert_called_once_with( + channel=grpc_transport.create_channel.return_value, host=expected_host, + ) + + def test__create_gapic_client_channel_w_defaults(self): + self._create_gapic_client_channel_helper() + + def test__create_gapic_client_channel_w_endpoint(self): + endpoint = "api.example.com" + self._create_gapic_client_channel_helper(endpoint=endpoint) + + def test__create_gapic_client_channel_w_emulator_host(self): + host = "api.example.com:1234" + self._create_gapic_client_channel_helper(emulator_host=host) + + def test__create_gapic_client_channel_w_endpoint_w_emulator_host(self): + endpoint = "api.example.com" + host = "other.example.com:1234" + self._create_gapic_client_channel_helper(endpoint=endpoint, emulator_host=host) + def test_project_path_property(self): credentials = _make_credentials() project = "PROJECT"