diff --git a/google/cloud/pubsub_v1/subscriber/_protocol/dispatcher.py b/google/cloud/pubsub_v1/subscriber/_protocol/dispatcher.py index 7a8950844..382c5c38a 100644 --- a/google/cloud/pubsub_v1/subscriber/_protocol/dispatcher.py +++ b/google/cloud/pubsub_v1/subscriber/_protocol/dispatcher.py @@ -99,9 +99,6 @@ def dispatch_callback(self, items): ValueError: If ``action`` isn't one of the expected actions "ack", "drop", "lease", "modify_ack_deadline" or "nack". """ - if not self._manager.is_active: - return - batched_commands = collections.defaultdict(list) for item in items: diff --git a/google/cloud/pubsub_v1/subscriber/_protocol/heartbeater.py b/google/cloud/pubsub_v1/subscriber/_protocol/heartbeater.py index 9cd84a1e2..fef158965 100644 --- a/google/cloud/pubsub_v1/subscriber/_protocol/heartbeater.py +++ b/google/cloud/pubsub_v1/subscriber/_protocol/heartbeater.py @@ -35,10 +35,11 @@ def __init__(self, manager, period=_DEFAULT_PERIOD): self._period = period def heartbeat(self): - """Periodically send heartbeats.""" - while self._manager.is_active and not self._stop_event.is_set(): - self._manager.heartbeat() - _LOGGER.debug("Sent heartbeat.") + """Periodically send streaming pull heartbeats. + """ + while not self._stop_event.is_set(): + if self._manager.heartbeat(): + _LOGGER.debug("Sent heartbeat.") self._stop_event.wait(timeout=self._period) _LOGGER.info("%s exiting.", _HEARTBEAT_WORKER_NAME) diff --git a/google/cloud/pubsub_v1/subscriber/_protocol/leaser.py b/google/cloud/pubsub_v1/subscriber/_protocol/leaser.py index 5830680da..4a19792fc 100644 --- a/google/cloud/pubsub_v1/subscriber/_protocol/leaser.py +++ b/google/cloud/pubsub_v1/subscriber/_protocol/leaser.py @@ -126,7 +126,7 @@ def maintain_leases(self): ack IDs, then waits for most of that time (but with jitter), and repeats. """ - while self._manager.is_active and not self._stop_event.is_set(): + while not self._stop_event.is_set(): # Determine the appropriate duration for the lease. This is # based off of how long previous messages have taken to ack, with # a sensible default and within the ranges allowed by Pub/Sub. diff --git a/google/cloud/pubsub_v1/subscriber/_protocol/streaming_pull_manager.py b/google/cloud/pubsub_v1/subscriber/_protocol/streaming_pull_manager.py index e8a4a8caf..8a3ff0e87 100644 --- a/google/cloud/pubsub_v1/subscriber/_protocol/streaming_pull_manager.py +++ b/google/cloud/pubsub_v1/subscriber/_protocol/streaming_pull_manager.py @@ -16,6 +16,7 @@ import collections import functools +import itertools import logging import threading import uuid @@ -113,10 +114,6 @@ class StreamingPullManager(object): scheduler will be used. """ - _UNARY_REQUESTS = True - """If set to True, this class will make requests over a separate unary - RPC instead of over the streaming RPC.""" - def __init__( self, client, @@ -292,6 +289,9 @@ def activate_ordering_keys(self, ordering_keys): activate. May be empty. """ with self._pause_resume_lock: + if self._scheduler is None: + return # We are shutting down, don't try to dispatch any more messages. + self._messages_on_hold.activate_ordering_keys( ordering_keys, self._schedule_message_on_hold ) @@ -421,37 +421,36 @@ def send(self, request): If a RetryError occurs, the manager shutdown is triggered, and the error is re-raised. """ - if self._UNARY_REQUESTS: - try: - self._send_unary_request(request) - except exceptions.GoogleAPICallError: - _LOGGER.debug( - "Exception while sending unary RPC. This is typically " - "non-fatal as stream requests are best-effort.", - exc_info=True, - ) - except exceptions.RetryError as exc: - _LOGGER.debug( - "RetryError while sending unary RPC. Waiting on a transient " - "error resolution for too long, will now trigger shutdown.", - exc_info=False, - ) - # The underlying channel has been suffering from a retryable error - # for too long, time to give up and shut the streaming pull down. - self._on_rpc_done(exc) - raise - - else: - self._rpc.send(request) + try: + self._send_unary_request(request) + except exceptions.GoogleAPICallError: + _LOGGER.debug( + "Exception while sending unary RPC. This is typically " + "non-fatal as stream requests are best-effort.", + exc_info=True, + ) + except exceptions.RetryError as exc: + _LOGGER.debug( + "RetryError while sending unary RPC. Waiting on a transient " + "error resolution for too long, will now trigger shutdown.", + exc_info=False, + ) + # The underlying channel has been suffering from a retryable error + # for too long, time to give up and shut the streaming pull down. + self._on_rpc_done(exc) + raise def heartbeat(self): """Sends an empty request over the streaming pull RPC. - This always sends over the stream, regardless of if - ``self._UNARY_REQUESTS`` is set or not. + Returns: + bool: If a heartbeat request has actually been sent. """ if self._rpc is not None and self._rpc.is_active: self._rpc.send(gapic_types.StreamingPullRequest()) + return True + + return False def open(self, callback, on_callback_error): """Begin consuming messages. @@ -513,7 +512,7 @@ def open(self, callback, on_callback_error): # Start the stream heartbeater thread. self._heartbeater.start() - def close(self, reason=None): + def close(self, reason=None, await_msg_callbacks=False): """Stop consuming messages and shutdown all helper threads. This method is idempotent. Additional calls will have no effect. @@ -522,6 +521,15 @@ def close(self, reason=None): reason (Any): The reason to close this. If None, this is considered an "intentional" shutdown. This is passed to the callbacks specified via :meth:`add_close_callback`. + + await_msg_callbacks (bool): + If ``True``, the method will wait until all scheduler threads terminate + and only then proceed with the shutdown with the remaining shutdown + tasks, + + If ``False`` (default), the method will shut down the scheduler in a + non-blocking fashion, i.e. it will not wait for the currently executing + scheduler threads to terminate. """ with self._closing: if self._closed: @@ -535,7 +543,9 @@ def close(self, reason=None): # Shutdown all helper threads _LOGGER.debug("Stopping scheduler.") - self._scheduler.shutdown() + dropped_messages = self._scheduler.shutdown( + await_msg_callbacks=await_msg_callbacks + ) self._scheduler = None # Leaser and dispatcher reference each other through the shared @@ -549,11 +559,23 @@ def close(self, reason=None): # because the consumer gets shut down first. _LOGGER.debug("Stopping leaser.") self._leaser.stop() + + total = len(dropped_messages) + len( + self._messages_on_hold._messages_on_hold + ) + _LOGGER.debug(f"NACK-ing all not-yet-dispatched messages (total: {total}).") + messages_to_nack = itertools.chain( + dropped_messages, self._messages_on_hold._messages_on_hold + ) + for msg in messages_to_nack: + msg.nack() + _LOGGER.debug("Stopping dispatcher.") self._dispatcher.stop() self._dispatcher = None # dispatcher terminated, OK to dispose the leaser reference now self._leaser = None + _LOGGER.debug("Stopping heartbeater.") self._heartbeater.stop() self._heartbeater = None diff --git a/google/cloud/pubsub_v1/subscriber/futures.py b/google/cloud/pubsub_v1/subscriber/futures.py index f9fdd76ab..cefe1aa91 100644 --- a/google/cloud/pubsub_v1/subscriber/futures.py +++ b/google/cloud/pubsub_v1/subscriber/futures.py @@ -43,12 +43,23 @@ def _on_close_callback(self, manager, result): else: self.set_exception(result) - def cancel(self): + def cancel(self, await_msg_callbacks=False): """Stops pulling messages and shutdowns the background thread consuming messages. + + Args: + await_msg_callbacks (bool): + If ``True``, the method will block until the background stream and its + helper threads have has been terminated, as well as all currently + executing message callbacks are done processing. + + If ``False`` (default), the method returns immediately after the + background stream and its helper threads have has been terminated, but + some of the message callback threads might still be running at that + point. """ self._cancelled = True - return self._manager.close() + return self._manager.close(await_msg_callbacks=await_msg_callbacks) def cancelled(self): """ diff --git a/google/cloud/pubsub_v1/subscriber/scheduler.py b/google/cloud/pubsub_v1/subscriber/scheduler.py index ef2ef59cb..2690c1fc6 100644 --- a/google/cloud/pubsub_v1/subscriber/scheduler.py +++ b/google/cloud/pubsub_v1/subscriber/scheduler.py @@ -20,7 +20,6 @@ import abc import concurrent.futures -import sys import six from six.moves import queue @@ -58,19 +57,29 @@ def schedule(self, callback, *args, **kwargs): raise NotImplementedError @abc.abstractmethod - def shutdown(self): + def shutdown(self, await_msg_callbacks=False): """Shuts down the scheduler and immediately end all pending callbacks. + + Args: + await_msg_callbacks (bool): + If ``True``, the method will block until all currently executing + callbacks are done processing. If ``False`` (default), the + method will not wait for the currently running callbacks to complete. + + Returns: + List[pubsub_v1.subscriber.message.Message]: + The messages submitted to the scheduler that were not yet dispatched + to their callbacks. + It is assumed that each message was submitted to the scheduler as the + first positional argument to the provided callback. """ raise NotImplementedError def _make_default_thread_pool_executor(): - # Python 2.7 and 3.6+ have the thread_name_prefix argument, which is useful - # for debugging. - executor_kwargs = {} - if sys.version_info[:2] == (2, 7) or sys.version_info >= (3, 6): - executor_kwargs["thread_name_prefix"] = "ThreadPoolExecutor-ThreadScheduler" - return concurrent.futures.ThreadPoolExecutor(max_workers=10, **executor_kwargs) + return concurrent.futures.ThreadPoolExecutor( + max_workers=10, thread_name_prefix="ThreadPoolExecutor-ThreadScheduler" + ) class ThreadScheduler(Scheduler): @@ -110,15 +119,35 @@ def schedule(self, callback, *args, **kwargs): """ self._executor.submit(callback, *args, **kwargs) - def shutdown(self): - """Shuts down the scheduler and immediately end all pending callbacks. + def shutdown(self, await_msg_callbacks=False): + """Shut down the scheduler and immediately end all pending callbacks. + + Args: + await_msg_callbacks (bool): + If ``True``, the method will block until all currently executing + executor threads are done processing. If ``False`` (default), the + method will not wait for the currently running threads to complete. + + Returns: + List[pubsub_v1.subscriber.message.Message]: + The messages submitted to the scheduler that were not yet dispatched + to their callbacks. + It is assumed that each message was submitted to the scheduler as the + first positional argument to the provided callback. """ - # Drop all pending item from the executor. Without this, the executor - # will block until all pending items are complete, which is - # undesirable. + dropped_messages = [] + + # Drop all pending item from the executor. Without this, the executor will also + # try to process any pending work items before termination, which is undesirable. + # + # TODO: Replace the logic below by passing `cancel_futures=True` to shutdown() + # once we only need to support Python 3.9+. try: while True: - self._executor._work_queue.get(block=False) + work_item = self._executor._work_queue.get(block=False) + dropped_messages.append(work_item.args[0]) except queue.Empty: pass - self._executor.shutdown() + + self._executor.shutdown(wait=await_msg_callbacks) + return dropped_messages diff --git a/tests/system.py b/tests/system.py index bbedd9a11..05a91a420 100644 --- a/tests/system.py +++ b/tests/system.py @@ -14,6 +14,7 @@ from __future__ import absolute_import +import concurrent.futures import datetime import itertools import operator as op @@ -609,6 +610,78 @@ def test_streaming_pull_max_messages( finally: subscription_future.cancel() # trigger clean shutdown + def test_streaming_pull_blocking_shutdown( + self, publisher, topic_path, subscriber, subscription_path, cleanup + ): + # Make sure the topic and subscription get deleted. + cleanup.append((publisher.delete_topic, (), {"topic": topic_path})) + cleanup.append( + (subscriber.delete_subscription, (), {"subscription": subscription_path}) + ) + + # The ACK-s are only persisted if *all* messages published in the same batch + # are ACK-ed. We thus publish each message in its own batch so that the backend + # treats all messages' ACKs independently of each other. + publisher.create_topic(name=topic_path) + subscriber.create_subscription(name=subscription_path, topic=topic_path) + _publish_messages(publisher, topic_path, batch_sizes=[1] * 10) + + # Artificially delay message processing, gracefully shutdown the streaming pull + # in the meantime, then verify that those messages were nevertheless processed. + processed_messages = [] + + def callback(message): + time.sleep(15) + processed_messages.append(message.data) + message.ack() + + # Flow control limits should exceed the number of worker threads, so that some + # of the messages will be blocked on waiting for free scheduler threads. + flow_control = pubsub_v1.types.FlowControl(max_messages=5) + executor = concurrent.futures.ThreadPoolExecutor(max_workers=3) + scheduler = pubsub_v1.subscriber.scheduler.ThreadScheduler(executor=executor) + subscription_future = subscriber.subscribe( + subscription_path, + callback=callback, + flow_control=flow_control, + scheduler=scheduler, + ) + + try: + subscription_future.result(timeout=10) # less than the sleep in callback + except exceptions.TimeoutError: + subscription_future.cancel(await_msg_callbacks=True) + + # The shutdown should have waited for the already executing callbacks to finish. + assert len(processed_messages) == 3 + + # The messages that were not processed should have been NACK-ed and we should + # receive them again quite soon. + all_done = threading.Barrier(7 + 1, timeout=5) # +1 because of the main thread + remaining = [] + + def callback2(message): + remaining.append(message.data) + message.ack() + all_done.wait() + + subscription_future = subscriber.subscribe( + subscription_path, callback=callback2 + ) + + try: + all_done.wait() + except threading.BrokenBarrierError: # PRAGMA: no cover + pytest.fail("The remaining messages have not been re-delivered in time.") + finally: + subscription_future.cancel(await_msg_callbacks=False) + + # There should be 7 messages left that were not yet processed and none of them + # should be a message that should have already been sucessfully processed in the + # first streaming pull. + assert len(remaining) == 7 + assert not (set(processed_messages) & set(remaining)) # no re-delivery + @pytest.mark.skipif( "KOKORO_GFILE_DIR" not in os.environ, @@ -790,8 +863,8 @@ def _publish_messages(publisher, topic_path, batch_sizes): publish_futures = [] msg_counter = itertools.count(start=1) - for batch_size in batch_sizes: - msg_batch = _make_messages(count=batch_size) + for batch_num, batch_size in enumerate(batch_sizes, start=1): + msg_batch = _make_messages(count=batch_size, batch_num=batch_num) for msg in msg_batch: future = publisher.publish(topic_path, msg, seq_num=str(next(msg_counter))) publish_futures.append(future) @@ -802,9 +875,10 @@ def _publish_messages(publisher, topic_path, batch_sizes): future.result(timeout=30) -def _make_messages(count): +def _make_messages(count, batch_num): messages = [ - "message {}/{}".format(i, count).encode("utf-8") for i in range(1, count + 1) + f"message {i}/{count} of batch {batch_num}".encode("utf-8") + for i in range(1, count + 1) ] return messages diff --git a/tests/unit/pubsub_v1/subscriber/test_dispatcher.py b/tests/unit/pubsub_v1/subscriber/test_dispatcher.py index 288e4bd18..47c62bab6 100644 --- a/tests/unit/pubsub_v1/subscriber/test_dispatcher.py +++ b/tests/unit/pubsub_v1/subscriber/test_dispatcher.py @@ -29,14 +29,14 @@ @pytest.mark.parametrize( "item,method_name", [ - (requests.AckRequest(0, 0, 0, ""), "ack"), - (requests.DropRequest(0, 0, ""), "drop"), - (requests.LeaseRequest(0, 0, ""), "lease"), - (requests.ModAckRequest(0, 0), "modify_ack_deadline"), - (requests.NackRequest(0, 0, ""), "nack"), + (requests.AckRequest("0", 0, 0, ""), "ack"), + (requests.DropRequest("0", 0, ""), "drop"), + (requests.LeaseRequest("0", 0, ""), "lease"), + (requests.ModAckRequest("0", 0), "modify_ack_deadline"), + (requests.NackRequest("0", 0, ""), "nack"), ], ) -def test_dispatch_callback(item, method_name): +def test_dispatch_callback_active_manager(item, method_name): manager = mock.create_autospec( streaming_pull_manager.StreamingPullManager, instance=True ) @@ -50,16 +50,29 @@ def test_dispatch_callback(item, method_name): method.assert_called_once_with([item]) -def test_dispatch_callback_inactive(): +@pytest.mark.parametrize( + "item,method_name", + [ + (requests.AckRequest("0", 0, 0, ""), "ack"), + (requests.DropRequest("0", 0, ""), "drop"), + (requests.LeaseRequest("0", 0, ""), "lease"), + (requests.ModAckRequest("0", 0), "modify_ack_deadline"), + (requests.NackRequest("0", 0, ""), "nack"), + ], +) +def test_dispatch_callback_inactive_manager(item, method_name): manager = mock.create_autospec( streaming_pull_manager.StreamingPullManager, instance=True ) manager.is_active = False dispatcher_ = dispatcher.Dispatcher(manager, mock.sentinel.queue) - dispatcher_.dispatch_callback([requests.AckRequest(0, 0, 0, "")]) + items = [item] - manager.send.assert_not_called() + with mock.patch.object(dispatcher_, method_name) as method: + dispatcher_.dispatch_callback(items) + + method.assert_called_once_with([item]) def test_ack(): diff --git a/tests/unit/pubsub_v1/subscriber/test_futures_subscriber.py b/tests/unit/pubsub_v1/subscriber/test_futures_subscriber.py index 909337cc8..62a3ea1da 100644 --- a/tests/unit/pubsub_v1/subscriber/test_futures_subscriber.py +++ b/tests/unit/pubsub_v1/subscriber/test_futures_subscriber.py @@ -69,10 +69,18 @@ def test__on_close_callback_future_already_done(self): result = future.result() assert result == "foo" # on close callback was a no-op - def test_cancel(self): + def test_cancel_default_nonblocking_manager_shutdown(self): future = self.make_future() future.cancel() - future._manager.close.assert_called_once() + future._manager.close.assert_called_once_with(await_msg_callbacks=False) + assert future.cancelled() + + def test_cancel_blocking_manager_shutdown(self): + future = self.make_future() + + future.cancel(await_msg_callbacks=True) + + future._manager.close.assert_called_once_with(await_msg_callbacks=True) assert future.cancelled() diff --git a/tests/unit/pubsub_v1/subscriber/test_heartbeater.py b/tests/unit/pubsub_v1/subscriber/test_heartbeater.py index 8f5049691..1a52af231 100644 --- a/tests/unit/pubsub_v1/subscriber/test_heartbeater.py +++ b/tests/unit/pubsub_v1/subscriber/test_heartbeater.py @@ -22,22 +22,44 @@ import pytest -def test_heartbeat_inactive(caplog): - caplog.set_level(logging.INFO) +def test_heartbeat_inactive_manager_active_rpc(caplog): + caplog.set_level(logging.DEBUG) + + manager = mock.create_autospec( + streaming_pull_manager.StreamingPullManager, instance=True + ) + manager.is_active = False + manager.heartbeat.return_value = True # because of active rpc + + heartbeater_ = heartbeater.Heartbeater(manager) + make_sleep_mark_event_as_done(heartbeater_) + + heartbeater_.heartbeat() + + assert "Sent heartbeat" in caplog.text + assert "exiting" in caplog.text + + +def test_heartbeat_inactive_manager_inactive_rpc(caplog): + caplog.set_level(logging.DEBUG) + manager = mock.create_autospec( streaming_pull_manager.StreamingPullManager, instance=True ) manager.is_active = False + manager.heartbeat.return_value = False # because of inactive rpc heartbeater_ = heartbeater.Heartbeater(manager) + make_sleep_mark_event_as_done(heartbeater_) heartbeater_.heartbeat() + assert "Sent heartbeat" not in caplog.text assert "exiting" in caplog.text def test_heartbeat_stopped(caplog): - caplog.set_level(logging.INFO) + caplog.set_level(logging.DEBUG) manager = mock.create_autospec( streaming_pull_manager.StreamingPullManager, instance=True ) @@ -47,17 +69,18 @@ def test_heartbeat_stopped(caplog): heartbeater_.heartbeat() + assert "Sent heartbeat" not in caplog.text assert "exiting" in caplog.text -def make_sleep_mark_manager_as_inactive(heartbeater): - # Make sleep mark the manager as inactive so that heartbeat() +def make_sleep_mark_event_as_done(heartbeater): + # Make sleep actually trigger the done event so that heartbeat() # exits at the end of the first run. - def trigger_inactive(timeout): + def trigger_done(timeout): assert timeout - heartbeater._manager.is_active = False + heartbeater._stop_event.set() - heartbeater._stop_event.wait = trigger_inactive + heartbeater._stop_event.wait = trigger_done def test_heartbeat_once(): @@ -65,7 +88,7 @@ def test_heartbeat_once(): streaming_pull_manager.StreamingPullManager, instance=True ) heartbeater_ = heartbeater.Heartbeater(manager) - make_sleep_mark_manager_as_inactive(heartbeater_) + make_sleep_mark_event_as_done(heartbeater_) heartbeater_.heartbeat() diff --git a/tests/unit/pubsub_v1/subscriber/test_leaser.py b/tests/unit/pubsub_v1/subscriber/test_leaser.py index 17409cb3f..2ecc0b9f3 100644 --- a/tests/unit/pubsub_v1/subscriber/test_leaser.py +++ b/tests/unit/pubsub_v1/subscriber/test_leaser.py @@ -88,15 +88,21 @@ def create_manager(flow_control=types.FlowControl()): return manager -def test_maintain_leases_inactive(caplog): +def test_maintain_leases_inactive_manager(caplog): caplog.set_level(logging.INFO) manager = create_manager() manager.is_active = False leaser_ = leaser.Leaser(manager) + make_sleep_mark_event_as_done(leaser_) + leaser_.add( + [requests.LeaseRequest(ack_id="my_ack_ID", byte_size=42, ordering_key="")] + ) leaser_.maintain_leases() + # Leases should still be maintained even if the manager is inactive. + manager.dispatcher.modify_ack_deadline.assert_called() assert "exiting" in caplog.text @@ -112,20 +118,20 @@ def test_maintain_leases_stopped(caplog): assert "exiting" in caplog.text -def make_sleep_mark_manager_as_inactive(leaser): - # Make sleep mark the manager as inactive so that maintain_leases +def make_sleep_mark_event_as_done(leaser): + # Make sleep actually trigger the done event so that heartbeat() # exits at the end of the first run. - def trigger_inactive(timeout): + def trigger_done(timeout): assert 0 < timeout < 10 - leaser._manager.is_active = False + leaser._stop_event.set() - leaser._stop_event.wait = trigger_inactive + leaser._stop_event.wait = trigger_done def test_maintain_leases_ack_ids(): manager = create_manager() leaser_ = leaser.Leaser(manager) - make_sleep_mark_manager_as_inactive(leaser_) + make_sleep_mark_event_as_done(leaser_) leaser_.add( [requests.LeaseRequest(ack_id="my ack id", byte_size=50, ordering_key="")] ) @@ -140,7 +146,7 @@ def test_maintain_leases_ack_ids(): def test_maintain_leases_no_ack_ids(): manager = create_manager() leaser_ = leaser.Leaser(manager) - make_sleep_mark_manager_as_inactive(leaser_) + make_sleep_mark_event_as_done(leaser_) leaser_.maintain_leases() @@ -151,7 +157,7 @@ def test_maintain_leases_no_ack_ids(): def test_maintain_leases_outdated_items(time): manager = create_manager() leaser_ = leaser.Leaser(manager) - make_sleep_mark_manager_as_inactive(leaser_) + make_sleep_mark_event_as_done(leaser_) # Add and start expiry timer at the beginning of the timeline. time.return_value = 0 diff --git a/tests/unit/pubsub_v1/subscriber/test_scheduler.py b/tests/unit/pubsub_v1/subscriber/test_scheduler.py index 774d0d63e..ede7c6b2d 100644 --- a/tests/unit/pubsub_v1/subscriber/test_scheduler.py +++ b/tests/unit/pubsub_v1/subscriber/test_scheduler.py @@ -14,6 +14,7 @@ import concurrent.futures import threading +import time import mock from six.moves import queue @@ -38,19 +39,89 @@ def test_constructor_options(): assert scheduler_._executor == mock.sentinel.executor -def test_schedule(): +def test_schedule_executes_submitted_items(): called_with = [] - called = threading.Event() + callback_done_twice = threading.Barrier(3) # 3 == 2x callback + 1x main thread def callback(*args, **kwargs): - called_with.append((args, kwargs)) - called.set() + called_with.append((args, kwargs)) # appends are thread-safe + callback_done_twice.wait() scheduler_ = scheduler.ThreadScheduler() scheduler_.schedule(callback, "arg1", kwarg1="meep") + scheduler_.schedule(callback, "arg2", kwarg2="boop") - called.wait() - scheduler_.shutdown() + callback_done_twice.wait(timeout=3.0) + result = scheduler_.shutdown() - assert called_with == [(("arg1",), {"kwarg1": "meep"})] + assert result == [] # no scheduled items dropped + + expected_calls = [(("arg1",), {"kwarg1": "meep"}), (("arg2",), {"kwarg2": "boop"})] + assert sorted(called_with) == expected_calls + + +def test_shutdown_nonblocking_by_default(): + called_with = [] + at_least_one_called = threading.Event() + at_least_one_completed = threading.Event() + + def callback(message): + called_with.append(message) # appends are thread-safe + at_least_one_called.set() + time.sleep(1.0) + at_least_one_completed.set() + + executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + scheduler_ = scheduler.ThreadScheduler(executor=executor) + + scheduler_.schedule(callback, "message_1") + scheduler_.schedule(callback, "message_2") + + at_least_one_called.wait() + dropped = scheduler_.shutdown() + + assert len(called_with) == 1 + assert called_with[0] in {"message_1", "message_2"} + + assert len(dropped) == 1 + assert dropped[0] in {"message_1", "message_2"} + assert dropped[0] != called_with[0] # the dropped message was not the processed one + + err_msg = ( + "Shutdown should not have waited " + "for the already running callbacks to complete." + ) + assert not at_least_one_completed.is_set(), err_msg + + +def test_shutdown_blocking_awaits_running_callbacks(): + called_with = [] + at_least_one_called = threading.Event() + at_least_one_completed = threading.Event() + + def callback(message): + called_with.append(message) # appends are thread-safe + at_least_one_called.set() + time.sleep(1.0) + at_least_one_completed.set() + + executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + scheduler_ = scheduler.ThreadScheduler(executor=executor) + + scheduler_.schedule(callback, "message_1") + scheduler_.schedule(callback, "message_2") + + at_least_one_called.wait() + dropped = scheduler_.shutdown(await_msg_callbacks=True) + + assert len(called_with) == 1 + assert called_with[0] in {"message_1", "message_2"} + + # The work items that have not been started yet should still be dropped. + assert len(dropped) == 1 + assert dropped[0] in {"message_1", "message_2"} + assert dropped[0] != called_with[0] # the dropped message was not the processed one + + err_msg = "Shutdown did not wait for the already running callbacks to complete." + assert at_least_one_completed.is_set(), err_msg diff --git a/tests/unit/pubsub_v1/subscriber/test_streaming_pull_manager.py b/tests/unit/pubsub_v1/subscriber/test_streaming_pull_manager.py index 242c0804a..a6454f853 100644 --- a/tests/unit/pubsub_v1/subscriber/test_streaming_pull_manager.py +++ b/tests/unit/pubsub_v1/subscriber/test_streaming_pull_manager.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import functools import logging import threading import time @@ -372,7 +373,6 @@ def test__maybe_release_messages_negative_on_hold_bytes_warning(caplog): def test_send_unary(): manager = make_manager() - manager._UNARY_REQUESTS = True manager.send( gapic_types.StreamingPullRequest( @@ -405,7 +405,6 @@ def test_send_unary(): def test_send_unary_empty(): manager = make_manager() - manager._UNARY_REQUESTS = True manager.send(gapic_types.StreamingPullRequest()) @@ -417,7 +416,6 @@ def test_send_unary_api_call_error(caplog): caplog.set_level(logging.DEBUG) manager = make_manager() - manager._UNARY_REQUESTS = True error = exceptions.GoogleAPICallError("The front fell off") manager._client.acknowledge.side_effect = error @@ -431,7 +429,6 @@ def test_send_unary_retry_error(caplog): caplog.set_level(logging.DEBUG) manager, _, _, _, _, _ = make_running_manager() - manager._UNARY_REQUESTS = True error = exceptions.RetryError( "Too long a transient error", cause=Exception("Out of time!") @@ -445,24 +442,15 @@ def test_send_unary_retry_error(caplog): assert "signaled streaming pull manager shutdown" in caplog.text -def test_send_streaming(): - manager = make_manager() - manager._UNARY_REQUESTS = False - manager._rpc = mock.create_autospec(bidi.BidiRpc, instance=True) - - manager.send(mock.sentinel.request) - - manager._rpc.send.assert_called_once_with(mock.sentinel.request) - - def test_heartbeat(): manager = make_manager() manager._rpc = mock.create_autospec(bidi.BidiRpc, instance=True) manager._rpc.is_active = True - manager.heartbeat() + result = manager.heartbeat() manager._rpc.send.assert_called_once_with(gapic_types.StreamingPullRequest()) + assert result def test_heartbeat_inactive(): @@ -472,7 +460,8 @@ def test_heartbeat_inactive(): manager.heartbeat() - manager._rpc.send.assert_not_called() + result = manager._rpc.send.assert_not_called() + assert not result @mock.patch("google.api_core.bidi.ResumableBidiRpc", autospec=True) @@ -632,14 +621,14 @@ def _do_work(self): while not self._stop: try: self._manager.leaser.add([mock.Mock()]) - except Exception as exc: + except Exception as exc: # pragma: NO COVER self._error_callback(exc) time.sleep(0.1) # also try to interact with the leaser after the stop flag has been set try: self._manager.leaser.remove([mock.Mock()]) - except Exception as exc: + except Exception as exc: # pragma: NO COVER self._error_callback(exc) @@ -666,6 +655,27 @@ def test_close_callbacks(): callback.assert_called_once_with(manager, "meep") +def test_close_nacks_internally_queued_messages(): + nacked_messages = [] + + def fake_nack(self): + nacked_messages.append(self.data) + + MockMsg = functools.partial(mock.create_autospec, message.Message, instance=True) + messages = [MockMsg(data=b"msg1"), MockMsg(data=b"msg2"), MockMsg(data=b"msg3")] + for msg in messages: + msg.nack = stdlib_types.MethodType(fake_nack, msg) + + manager, _, _, _, _, _ = make_running_manager() + dropped_by_scheduler = messages[:2] + manager._scheduler.shutdown.return_value = dropped_by_scheduler + manager._messages_on_hold._messages_on_hold.append(messages[2]) + + manager.close() + + assert sorted(nacked_messages) == [b"msg1", b"msg2", b"msg3"] + + def test__get_initial_request(): manager = make_manager() manager._leaser = mock.create_autospec(leaser.Leaser, instance=True) @@ -979,3 +989,15 @@ def test_activate_ordering_keys(): manager._messages_on_hold.activate_ordering_keys.assert_called_once_with( ["key1", "key2"], mock.ANY ) + + +def test_activate_ordering_keys_stopped_scheduler(): + manager = make_manager() + manager._messages_on_hold = mock.create_autospec( + messages_on_hold.MessagesOnHold, instance=True + ) + manager._scheduler = None + + manager.activate_ordering_keys(["key1", "key2"]) + + manager._messages_on_hold.activate_ordering_keys.assert_not_called()