@@ -183,6 +183,10 @@ def test_close(self):
183183 mock_transaction .rollback = mock_rollback = mock .MagicMock ()
184184 connection .close ()
185185 mock_rollback .assert_called_once_with ()
186+ connection ._transaction = mock .MagicMock ()
187+ connection ._own_pool = False
188+ connection .close ()
189+ self .assertTrue (connection .is_closed )
186190
187191 @mock .patch .object (warnings , "warn" )
188192 def test_commit (self , mock_warn ):
@@ -379,6 +383,25 @@ def test_run_statement_dont_remember_retried_statements(self):
379383
380384 self .assertEqual (len (connection ._statements ), 0 )
381385
386+ def test_run_statement_w_heterogenous_insert_statements (self ):
387+ """Check that Connection executed heterogenous insert statements."""
388+ from google .cloud .spanner_dbapi .checksum import ResultsChecksum
389+ from google .cloud .spanner_dbapi .cursor import Statement
390+
391+ sql = "INSERT INTO T (f1, f2) VALUES (1, 2)"
392+ params = None
393+ param_types = None
394+
395+ connection = self ._make_connection ()
396+
397+ statement = Statement (sql , params , param_types , ResultsChecksum (), True )
398+ with mock .patch (
399+ "google.cloud.spanner_dbapi.connection.Connection.transaction_checkout"
400+ ):
401+ connection .run_statement (statement , retried = True )
402+
403+ self .assertEqual (len (connection ._statements ), 0 )
404+
382405 def test_run_statement_w_homogeneous_insert_statements (self ):
383406 """Check that Connection executed homogeneous insert statements."""
384407 from google .cloud .spanner_dbapi .checksum import ResultsChecksum
@@ -582,3 +605,132 @@ def test_retry_aborted_retry(self):
582605 mock .call (statement , retried = True ),
583606 )
584607 )
608+
609+ def test_retry_transaction_raise_max_internal_retries (self ):
610+ """Check retrying raise an error of max internal retries."""
611+ from google .cloud .spanner_dbapi import connection as conn
612+ from google .cloud .spanner_dbapi .checksum import ResultsChecksum
613+ from google .cloud .spanner_dbapi .cursor import Statement
614+
615+ conn .MAX_INTERNAL_RETRIES = 0
616+ row = ["field1" , "field2" ]
617+ connection = self ._make_connection ()
618+
619+ checksum = ResultsChecksum ()
620+ checksum .consume_result (row )
621+
622+ statement = Statement ("SELECT 1" , [], {}, checksum , False )
623+ connection ._statements .append (statement )
624+
625+ with self .assertRaises (Exception ):
626+ connection .retry_transaction ()
627+
628+ conn .MAX_INTERNAL_RETRIES = 50
629+
630+ def test_retry_aborted_retry_without_delay (self ):
631+ """
632+ Check that in case of a retried transaction failed,
633+ the connection will retry it once again.
634+ """
635+ from google .api_core .exceptions import Aborted
636+ from google .cloud .spanner_dbapi .checksum import ResultsChecksum
637+ from google .cloud .spanner_dbapi .connection import connect
638+ from google .cloud .spanner_dbapi .cursor import Statement
639+
640+ row = ["field1" , "field2" ]
641+
642+ with mock .patch (
643+ "google.cloud.spanner_v1.instance.Instance.exists" , return_value = True ,
644+ ):
645+ with mock .patch (
646+ "google.cloud.spanner_v1.database.Database.exists" , return_value = True ,
647+ ):
648+ connection = connect ("test-instance" , "test-database" )
649+
650+ cursor = connection .cursor ()
651+ cursor ._checksum = ResultsChecksum ()
652+ cursor ._checksum .consume_result (row )
653+
654+ statement = Statement ("SELECT 1" , [], {}, cursor ._checksum , False )
655+ connection ._statements .append (statement )
656+
657+ metadata_mock = mock .Mock ()
658+ metadata_mock .trailing_metadata .return_value = {}
659+
660+ with mock .patch (
661+ "google.cloud.spanner_dbapi.connection.Connection.run_statement" ,
662+ side_effect = (
663+ Aborted ("Aborted" , errors = [metadata_mock ]),
664+ ([row ], ResultsChecksum ()),
665+ ),
666+ ) as retry_mock :
667+ with mock .patch (
668+ "google.cloud.spanner_dbapi.connection._get_retry_delay" ,
669+ return_value = False ,
670+ ):
671+ connection .retry_transaction ()
672+
673+ retry_mock .assert_has_calls (
674+ (
675+ mock .call (statement , retried = True ),
676+ mock .call (statement , retried = True ),
677+ )
678+ )
679+
680+ def test_retry_transaction_w_multiple_statement (self ):
681+ """Check retrying an aborted transaction."""
682+ from google .cloud .spanner_dbapi .checksum import ResultsChecksum
683+ from google .cloud .spanner_dbapi .cursor import Statement
684+
685+ row = ["field1" , "field2" ]
686+ connection = self ._make_connection ()
687+
688+ checksum = ResultsChecksum ()
689+ checksum .consume_result (row )
690+ retried_checkum = ResultsChecksum ()
691+
692+ statement = Statement ("SELECT 1" , [], {}, checksum , False )
693+ statement1 = Statement ("SELECT 2" , [], {}, checksum , False )
694+ connection ._statements .append (statement )
695+ connection ._statements .append (statement1 )
696+
697+ with mock .patch (
698+ "google.cloud.spanner_dbapi.connection.Connection.run_statement" ,
699+ return_value = ([row ], retried_checkum ),
700+ ) as run_mock :
701+ with mock .patch (
702+ "google.cloud.spanner_dbapi.connection._compare_checksums"
703+ ) as compare_mock :
704+ connection .retry_transaction ()
705+
706+ compare_mock .assert_called_with (checksum , retried_checkum )
707+
708+ run_mock .assert_called_with (statement1 , retried = True )
709+
710+ def test_retry_transaction_w_empty_response (self ):
711+ """Check retrying an aborted transaction."""
712+ from google .cloud .spanner_dbapi .checksum import ResultsChecksum
713+ from google .cloud .spanner_dbapi .cursor import Statement
714+
715+ row = []
716+ connection = self ._make_connection ()
717+
718+ checksum = ResultsChecksum ()
719+ checksum .count = 1
720+ retried_checkum = ResultsChecksum ()
721+
722+ statement = Statement ("SELECT 1" , [], {}, checksum , False )
723+ connection ._statements .append (statement )
724+
725+ with mock .patch (
726+ "google.cloud.spanner_dbapi.connection.Connection.run_statement" ,
727+ return_value = (row , retried_checkum ),
728+ ) as run_mock :
729+ with mock .patch (
730+ "google.cloud.spanner_dbapi.connection._compare_checksums"
731+ ) as compare_mock :
732+ connection .retry_transaction ()
733+
734+ compare_mock .assert_called_with (checksum , retried_checkum )
735+
736+ run_mock .assert_called_with (statement , retried = True )
0 commit comments