5757 'google-cloud-python-systest' )
5858DATABASE_ID = 'test_database'
5959EXISTING_INSTANCES = []
60+ COUNTERS_TABLE = 'counters'
61+ COUNTERS_COLUMNS = ('name' , 'value' )
6062
6163
6264class Config (object ):
@@ -360,11 +362,6 @@ class TestSessionAPI(unittest.TestCase, _TestData):
360362 'description' ,
361363 'exactly_hwhen' ,
362364 )
363- COUNTERS_TABLE = 'counters'
364- COUNTERS_COLUMNS = (
365- 'name' ,
366- 'value' ,
367- )
368365 SOME_DATE = datetime .date (2011 , 1 , 17 )
369366 SOME_TIME = datetime .datetime (1989 , 1 , 17 , 17 , 59 , 12 , 345612 )
370367 NANO_TIME = TimestampWithNanoseconds (1995 , 8 , 31 , nanosecond = 987654321 )
@@ -553,9 +550,7 @@ def _transaction_concurrency_helper(self, unit_of_work, pkey):
553550
554551 with session .batch () as batch :
555552 batch .insert_or_update (
556- self .COUNTERS_TABLE ,
557- self .COUNTERS_COLUMNS ,
558- [[pkey , INITIAL_VALUE ]])
553+ COUNTERS_TABLE , COUNTERS_COLUMNS , [[pkey , INITIAL_VALUE ]])
559554
560555 # We don't want to run the threads' transactions in the current
561556 # session, which would fail.
@@ -581,21 +576,19 @@ def _transaction_concurrency_helper(self, unit_of_work, pkey):
581576
582577 keyset = KeySet (keys = [(pkey ,)])
583578 rows = list (session .read (
584- self . COUNTERS_TABLE , self . COUNTERS_COLUMNS , keyset ))
579+ COUNTERS_TABLE , COUNTERS_COLUMNS , keyset ))
585580 self .assertEqual (len (rows ), 1 )
586581 _ , value = rows [0 ]
587582 self .assertEqual (value , INITIAL_VALUE + len (threads ))
588583
589584 def _read_w_concurrent_update (self , transaction , pkey ):
590585 keyset = KeySet (keys = [(pkey ,)])
591586 rows = list (transaction .read (
592- self . COUNTERS_TABLE , self . COUNTERS_COLUMNS , keyset ))
587+ COUNTERS_TABLE , COUNTERS_COLUMNS , keyset ))
593588 self .assertEqual (len (rows ), 1 )
594589 pkey , value = rows [0 ]
595590 transaction .update (
596- self .COUNTERS_TABLE ,
597- self .COUNTERS_COLUMNS ,
598- [[pkey , value + 1 ]])
591+ COUNTERS_TABLE , COUNTERS_COLUMNS , [[pkey , value + 1 ]])
599592
600593 def test_transaction_read_w_concurrent_updates (self ):
601594 PKEY = 'read_w_concurrent_updates'
@@ -612,15 +605,43 @@ def _query_w_concurrent_update(self, transaction, pkey):
612605 self .assertEqual (len (rows ), 1 )
613606 pkey , value = rows [0 ]
614607 transaction .update (
615- self .COUNTERS_TABLE ,
616- self .COUNTERS_COLUMNS ,
617- [[pkey , value + 1 ]])
608+ COUNTERS_TABLE , COUNTERS_COLUMNS , [[pkey , value + 1 ]])
618609
619610 def test_transaction_query_w_concurrent_updates (self ):
620611 PKEY = 'query_w_concurrent_updates'
621612 self ._transaction_concurrency_helper (
622613 self ._query_w_concurrent_update , PKEY )
623614
615+ def test_transaction_read_w_abort (self ):
616+
617+ retry = RetryInstanceState (_has_all_ddl )
618+ retry (self ._db .reload )()
619+
620+ session = self ._db .session ()
621+ session .create ()
622+
623+ trigger = _ReadAbortTrigger ()
624+
625+ with session .batch () as batch :
626+ batch .insert_or_update (
627+ COUNTERS_TABLE ,
628+ COUNTERS_COLUMNS ,
629+ [[trigger .KEY1 , 0 ], [trigger .KEY2 , 0 ]])
630+
631+ provoker = threading .Thread (target = trigger .provoke_abort )
632+ handler = threading .Thread (target = trigger .handle_abort )
633+
634+ provoker .start ()
635+ with trigger .provoker_started :
636+ trigger .provoker_started .wait ()
637+
638+ handler .start ()
639+ with trigger .handler_done :
640+ trigger .handler_done .wait ()
641+
642+ provoker .join ()
643+ handler .join ()
644+
624645 @staticmethod
625646 def _row_data (max_index ):
626647 for index in range (max_index ):
@@ -1102,3 +1123,103 @@ def __init__(self, db):
11021123
11031124 def delete (self ):
11041125 self ._db .drop ()
1126+
1127+
1128+ class _ReadAbortTrigger (object ):
1129+ """Helper for tests provoking abort-during-read."""
1130+
1131+ KEY1 = 'key1'
1132+ KEY2 = 'key2'
1133+
1134+ def __init__ (self ):
1135+ self .provoker_started = threading .Condition ()
1136+ self .provoker_done = threading .Condition ()
1137+ self .handler_running = threading .Condition ()
1138+ self .handler_done = threading .Condition ()
1139+
1140+ self ._xxx_provoker = open ('/tmp/xxx_provoker.log' , 'w' )
1141+ self ._xxx_handler = open ('/tmp/xxx_handler.log' , 'w' )
1142+
1143+ def _log_provoker (self , msg ):
1144+ self ._xxx_provoker .writelines ([msg ])
1145+ self ._xxx_provoker .flush ()
1146+
1147+ def _log_handler (self , msg ):
1148+ self ._xxx_handler .writelines ([msg ])
1149+ self ._xxx_handler .flush ()
1150+
1151+ def _provoke_abort_unit_of_work (self , transaction ):
1152+ log = self ._log_provoker
1153+ log ('UoW: initial read starting' )
1154+ keyset = KeySet (keys = [(self .KEY1 ,)])
1155+ rows = list (
1156+ transaction .read (COUNTERS_TABLE , COUNTERS_COLUMNS , keyset ))
1157+ log ('UoW: initial read complete' )
1158+
1159+ assert len (rows ) == 1
1160+ row = rows [0 ]
1161+ value = row [1 ]
1162+
1163+ log ('UoW: notifying' )
1164+ with self .provoker_started :
1165+ self .provoker_started .notify ()
1166+
1167+ log ('UoW: waiting for handler' )
1168+ with self .handler_running :
1169+ self .handler_running .wait ()
1170+
1171+ log ('UoW: updating' )
1172+ transaction .update (
1173+ COUNTERS_TABLE , COUNTERS_COLUMNS , [[self .KEY1 , value + 1 ]])
1174+ log ('UoW: committing' )
1175+
1176+ def provoke_abort (self , database ):
1177+ log = self ._log_provoker
1178+ log ('Thread: starting' )
1179+ database .run_in_transaction (self ._provoke_abort_unit_of_work )
1180+ log ('Thread: notifying' )
1181+ self .provoker_done .notify ()
1182+ log ('Thread: exiting' )
1183+
1184+ def _handle_abort_unit_of_work (self , transaction ):
1185+ log = self ._log_handler
1186+ log ('UoW: initial read starting' )
1187+ keyset_1 = KeySet (keys = [(self .KEY1 ,)])
1188+ rows_1 = list (
1189+ transaction .read (COUNTERS_TABLE , COUNTERS_COLUMNS , keyset_1 ))
1190+ log ('UoW: initial read complete' )
1191+
1192+ assert len (rows_1 ) == 1
1193+ row_1 = rows_1 [0 ]
1194+ value_1 = row_1 [1 ]
1195+
1196+ log ('UoW: notifying' )
1197+ with self .handler_running :
1198+ self .handler_running .notify ()
1199+
1200+ log ('UoW: waiting for provider' )
1201+ with self .provoker_done :
1202+ self .provoker_done .wait ()
1203+
1204+ log ('UoW: second read starting' )
1205+ keyset_2 = KeySet (keys = [(self .KEY2 ,)])
1206+ rows_2 = list (
1207+ transaction .read (COUNTERS_TABLE , COUNTERS_COLUMNS , keyset_2 ))
1208+ log ('UoW: second read complete' )
1209+
1210+ assert len (rows_2 ) == 1
1211+ row_2 = rows_2 [0 ]
1212+ value_2 = row_2 [1 ]
1213+
1214+ log ('UoW: updating' )
1215+ transaction .update (
1216+ COUNTERS_TABLE , COUNTERS_COLUMNS , [[self .KEY2 , value_1 + value_2 ]])
1217+ log ('UoW: committing' )
1218+
1219+ def handle_abort (self , database ):
1220+ log = self ._log_handler
1221+ log ('Thread: starting' )
1222+ database .run_in_transaction (self ._handle_abort_unit_of_work )
1223+ log ('Thread: notifying' )
1224+ self .handler_done .notify ()
1225+ log ('Thread: exiting' )
0 commit comments