diff --git a/pubsub/google/cloud/pubsub_v1/subscriber/_protocol/streaming_pull_manager.py b/pubsub/google/cloud/pubsub_v1/subscriber/_protocol/streaming_pull_manager.py index b393cbfd5ec6..c4f64a1aa8c5 100644 --- a/pubsub/google/cloud/pubsub_v1/subscriber/_protocol/streaming_pull_manager.py +++ b/pubsub/google/cloud/pubsub_v1/subscriber/_protocol/streaming_pull_manager.py @@ -44,6 +44,7 @@ exceptions.GatewayTimeout, exceptions.Aborted, ) +_TERMINATING_STREAM_ERRORS = (exceptions.Cancelled,) _MAX_LOAD = 1.0 """The load threshold above which to pause the incoming message stream.""" @@ -394,6 +395,7 @@ def open(self, callback, on_callback_error): start_rpc=self._client.api.streaming_pull, initial_request=get_initial_request, should_recover=self._should_recover, + should_terminate=self._should_terminate, throttle_reopen=True, ) self._rpc.add_done_callback(self._on_rpc_done) @@ -582,6 +584,26 @@ def _should_recover(self, exception): _LOGGER.info("Observed non-recoverable stream error %s", exception) return False + def _should_terminate(self, exception): + """Determine if an error on the RPC stream should be terminated. + + If the exception is one of the terminating exceptions, this will signal + to the consumer thread that it should terminate. + + This will cause the stream to exit when it returns :data:`True`. + + Returns: + bool: Indicates if the caller should terminate or attempt recovery. + Will be :data:`True` if the ``exception`` is "acceptable", i.e. + in a list of terminating exceptions. + """ + exception = _maybe_wrap_exception(exception) + if isinstance(exception, _TERMINATING_STREAM_ERRORS): + _LOGGER.info("Observed terminating stream error %s", exception) + return True + _LOGGER.info("Observed non-terminating stream error %s", exception) + return False + def _on_rpc_done(self, future): """Triggered whenever the underlying RPC terminates without recovery. diff --git a/pubsub/tests/unit/pubsub_v1/subscriber/test_streaming_pull_manager.py b/pubsub/tests/unit/pubsub_v1/subscriber/test_streaming_pull_manager.py index 352b09ba83fc..1b72f48dc9f6 100644 --- a/pubsub/tests/unit/pubsub_v1/subscriber/test_streaming_pull_manager.py +++ b/pubsub/tests/unit/pubsub_v1/subscriber/test_streaming_pull_manager.py @@ -433,6 +433,7 @@ def test_open(heartbeater, dispatcher, leaser, background_consumer, resumable_bi start_rpc=manager._client.api.streaming_pull, initial_request=mock.ANY, should_recover=manager._should_recover, + should_terminate=manager._should_terminate, throttle_reopen=True, ) initial_request_arg = resumable_bidi_rpc.call_args.kwargs["initial_request"] @@ -726,6 +727,23 @@ def test__should_recover_false(): assert manager._should_recover(exc) is False +def test__should_terminate_true(): + manager = make_manager() + + details = "Cancelled. Go away, before I taunt you a second time." + exc = exceptions.Cancelled(details) + + assert manager._should_terminate(exc) is True + + +def test__should_terminate_false(): + manager = make_manager() + + exc = TypeError("wahhhhhh") + + assert manager._should_terminate(exc) is False + + @mock.patch("threading.Thread", autospec=True) def test__on_rpc_done(thread): manager = make_manager()