1414
1515from __future__ import absolute_import
1616
17+ import concurrent .futures
1718import datetime
1819import itertools
1920import operator as op
@@ -609,6 +610,78 @@ def test_streaming_pull_max_messages(
609610 finally :
610611 subscription_future .cancel () # trigger clean shutdown
611612
613+ def test_streaming_pull_blocking_shutdown (
614+ self , publisher , topic_path , subscriber , subscription_path , cleanup
615+ ):
616+ # Make sure the topic and subscription get deleted.
617+ cleanup .append ((publisher .delete_topic , (), {"topic" : topic_path }))
618+ cleanup .append (
619+ (subscriber .delete_subscription , (), {"subscription" : subscription_path })
620+ )
621+
622+ # The ACK-s are only persisted if *all* messages published in the same batch
623+ # are ACK-ed. We thus publish each message in its own batch so that the backend
624+ # treats all messages' ACKs independently of each other.
625+ publisher .create_topic (name = topic_path )
626+ subscriber .create_subscription (name = subscription_path , topic = topic_path )
627+ _publish_messages (publisher , topic_path , batch_sizes = [1 ] * 10 )
628+
629+ # Artificially delay message processing, gracefully shutdown the streaming pull
630+ # in the meantime, then verify that those messages were nevertheless processed.
631+ processed_messages = []
632+
633+ def callback (message ):
634+ time .sleep (15 )
635+ processed_messages .append (message .data )
636+ message .ack ()
637+
638+ # Flow control limits should exceed the number of worker threads, so that some
639+ # of the messages will be blocked on waiting for free scheduler threads.
640+ flow_control = pubsub_v1 .types .FlowControl (max_messages = 5 )
641+ executor = concurrent .futures .ThreadPoolExecutor (max_workers = 3 )
642+ scheduler = pubsub_v1 .subscriber .scheduler .ThreadScheduler (executor = executor )
643+ subscription_future = subscriber .subscribe (
644+ subscription_path ,
645+ callback = callback ,
646+ flow_control = flow_control ,
647+ scheduler = scheduler ,
648+ )
649+
650+ try :
651+ subscription_future .result (timeout = 10 ) # less than the sleep in callback
652+ except exceptions .TimeoutError :
653+ subscription_future .cancel (await_msg_callbacks = True )
654+
655+ # The shutdown should have waited for the already executing callbacks to finish.
656+ assert len (processed_messages ) == 3
657+
658+ # The messages that were not processed should have been NACK-ed and we should
659+ # receive them again quite soon.
660+ all_done = threading .Barrier (7 + 1 , timeout = 5 ) # +1 because of the main thread
661+ remaining = []
662+
663+ def callback2 (message ):
664+ remaining .append (message .data )
665+ message .ack ()
666+ all_done .wait ()
667+
668+ subscription_future = subscriber .subscribe (
669+ subscription_path , callback = callback2
670+ )
671+
672+ try :
673+ all_done .wait ()
674+ except threading .BrokenBarrierError : # PRAGMA: no cover
675+ pytest .fail ("The remaining messages have not been re-delivered in time." )
676+ finally :
677+ subscription_future .cancel (await_msg_callbacks = False )
678+
679+ # There should be 7 messages left that were not yet processed and none of them
680+ # should be a message that should have already been sucessfully processed in the
681+ # first streaming pull.
682+ assert len (remaining ) == 7
683+ assert not (set (processed_messages ) & set (remaining )) # no re-delivery
684+
612685
613686@pytest .mark .skipif (
614687 "KOKORO_GFILE_DIR" not in os .environ ,
@@ -790,8 +863,8 @@ def _publish_messages(publisher, topic_path, batch_sizes):
790863 publish_futures = []
791864 msg_counter = itertools .count (start = 1 )
792865
793- for batch_size in batch_sizes :
794- msg_batch = _make_messages (count = batch_size )
866+ for batch_num , batch_size in enumerate ( batch_sizes , start = 1 ) :
867+ msg_batch = _make_messages (count = batch_size , batch_num = batch_num )
795868 for msg in msg_batch :
796869 future = publisher .publish (topic_path , msg , seq_num = str (next (msg_counter )))
797870 publish_futures .append (future )
@@ -802,9 +875,10 @@ def _publish_messages(publisher, topic_path, batch_sizes):
802875 future .result (timeout = 30 )
803876
804877
805- def _make_messages (count ):
878+ def _make_messages (count , batch_num ):
806879 messages = [
807- "message {}/{}" .format (i , count ).encode ("utf-8" ) for i in range (1 , count + 1 )
880+ f"message { i } /{ count } of batch { batch_num } " .encode ("utf-8" )
881+ for i in range (1 , count + 1 )
808882 ]
809883 return messages
810884
0 commit comments