1717import operator
1818import os
1919import struct
20+ import threading
2021import unittest
2122
2223from google .cloud .proto .spanner .v1 .type_pb2 import ARRAY
@@ -358,6 +359,11 @@ class TestSessionAPI(unittest.TestCase, _TestData):
358359 'description' ,
359360 'exactly_hwhen' ,
360361 )
362+ COUNTERS_TABLE = 'counters'
363+ COUNTERS_COLUMNS = (
364+ 'name' ,
365+ 'value' ,
366+ )
361367 SOME_DATE = datetime .date (2011 , 1 , 17 )
362368 SOME_TIME = datetime .datetime (1989 , 1 , 17 , 17 , 59 , 12 , 345612 )
363369 NANO_TIME = TimestampWithNanoseconds (1995 , 8 , 31 , nanosecond = 987654321 )
@@ -482,6 +488,31 @@ def test_transaction_read_and_insert_then_rollback(self):
482488 rows = list (session .read (self .TABLE , self .COLUMNS , self .ALL ))
483489 self .assertEqual (rows , [])
484490
491+ def _transaction_read_then_raise (self , transaction ):
492+ rows = list (transaction .read (self .TABLE , self .COLUMNS , self .ALL ))
493+ self .assertEqual (len (rows ), 0 )
494+ transaction .insert (self .TABLE , self .COLUMNS , self .ROW_DATA )
495+ raise CustomException ()
496+
497+ @RetryErrors (exception = GrpcRendezvous )
498+ def test_transaction_read_and_insert_then_execption (self ):
499+ retry = RetryInstanceState (_has_all_ddl )
500+ retry (self ._db .reload )()
501+
502+ session = self ._db .session ()
503+ session .create ()
504+ self .to_delete .append (session )
505+
506+ with session .batch () as batch :
507+ batch .delete (self .TABLE , self .ALL )
508+
509+ with self .assertRaises (CustomException ):
510+ session .run_in_transaction (self ._transaction_read_then_raise )
511+
512+ # Transaction was rolled back.
513+ rows = list (session .read (self .TABLE , self .COLUMNS , self .ALL ))
514+ self .assertEqual (rows , [])
515+
485516 @RetryErrors (exception = GrpcRendezvous )
486517 def test_transaction_read_and_insert_or_update_then_commit (self ):
487518 retry = RetryInstanceState (_has_all_ddl )
@@ -508,6 +539,87 @@ def test_transaction_read_and_insert_or_update_then_commit(self):
508539 rows = list (session .read (self .TABLE , self .COLUMNS , self .ALL ))
509540 self ._check_row_data (rows )
510541
542+ def _transaction_concurrency_helper (self , unit_of_work , pkey ):
543+ INITIAL_VALUE = 123
544+ NUM_THREADS = 3 # conforms to equivalent Java systest.
545+
546+ retry = RetryInstanceState (_has_all_ddl )
547+ retry (self ._db .reload )()
548+
549+ session = self ._db .session ()
550+ session .create ()
551+ self .to_delete .append (session )
552+
553+ with session .batch () as batch :
554+ batch .insert_or_update (
555+ self .COUNTERS_TABLE ,
556+ self .COUNTERS_COLUMNS ,
557+ [[pkey , INITIAL_VALUE ]])
558+
559+ # We don't want to run the threads' transactions in the current
560+ # session, which would fail.
561+ txn_sessions = []
562+
563+ for _ in range (NUM_THREADS ):
564+ txn_session = self ._db .session ()
565+ txn_sessions .append (txn_session )
566+ txn_session .create ()
567+ self .to_delete .append (txn_session )
568+
569+ threads = [
570+ threading .Thread (
571+ target = txn_session .run_in_transaction ,
572+ args = (unit_of_work , pkey ))
573+ for txn_session in txn_sessions ]
574+
575+ for thread in threads :
576+ thread .start ()
577+
578+ for thread in threads :
579+ thread .join ()
580+
581+ keyset = KeySet (keys = [(pkey ,)])
582+ rows = list (session .read (
583+ self .COUNTERS_TABLE , self .COUNTERS_COLUMNS , keyset ))
584+ self .assertEqual (len (rows ), 1 )
585+ _ , value = rows [0 ]
586+ self .assertEqual (value , INITIAL_VALUE + len (threads ))
587+
588+ def _read_w_concurrent_update (self , transaction , pkey ):
589+ keyset = KeySet (keys = [(pkey ,)])
590+ rows = list (transaction .read (
591+ self .COUNTERS_TABLE , self .COUNTERS_COLUMNS , keyset ))
592+ self .assertEqual (len (rows ), 1 )
593+ pkey , value = rows [0 ]
594+ transaction .update (
595+ self .COUNTERS_TABLE ,
596+ self .COUNTERS_COLUMNS ,
597+ [[pkey , value + 1 ]])
598+
599+ def test_transaction_read_w_concurrent_updates (self ):
600+ PKEY = 'read_w_concurrent_updates'
601+ self ._transaction_concurrency_helper (
602+ self ._read_w_concurrent_update , PKEY )
603+
604+ def _query_w_concurrent_update (self , transaction , pkey ):
605+ SQL = 'SELECT * FROM counters WHERE name = @name'
606+ rows = list (transaction .execute_sql (
607+ SQL ,
608+ params = {'name' : pkey },
609+ param_types = {'name' : Type (code = STRING )},
610+ ))
611+ self .assertEqual (len (rows ), 1 )
612+ pkey , value = rows [0 ]
613+ transaction .update (
614+ self .COUNTERS_TABLE ,
615+ self .COUNTERS_COLUMNS ,
616+ [[pkey , value + 1 ]])
617+
618+ def test_transaction_query_w_concurrent_updates (self ):
619+ PKEY = 'query_w_concurrent_updates'
620+ self ._transaction_concurrency_helper (
621+ self ._query_w_concurrent_update , PKEY )
622+
511623 @staticmethod
512624 def _row_data (max_index ):
513625 for index in range (max_index ):
@@ -910,6 +1022,10 @@ def test_four_meg(self):
9101022 self ._verify_two_columns (FOUR_MEG )
9111023
9121024
1025+ class CustomException (Exception ):
1026+ """Placeholder for any user-defined exception."""
1027+
1028+
9131029class _DatabaseDropper (object ):
9141030 """Helper for cleaning up databases created on-the-fly."""
9151031
0 commit comments