1919from google .cloud .proto .spanner .v1 .type_pb2 import STRING
2020from google .cloud .proto .spanner .v1 .type_pb2 import Type
2121from google .cloud .spanner .client import Client
22+ from google .cloud .spanner .keyset import KeySet
2223from google .cloud .spanner .pool import BurstyPool
2324from google .cloud .spanner ._fixtures import DDL_STATEMENTS
2425
@@ -167,12 +168,32 @@ def test_update_instance(self):
167168 Config .INSTANCE .update ()
168169
169170
170- class TestDatabaseAdminAPI (unittest .TestCase ):
171+ class _TestData (object ):
172+ TABLE = 'contacts'
173+ COLUMNS = ('contact_id' , 'first_name' , 'last_name' , 'email' )
174+ ROW_DATA = (
175+ (1 , u'Phred' , u'Phlyntstone' , u'phred@example.com' ),
176+ (2 , u'Bharney' , u'Rhubble' , u'bharney@example.com' ),
177+ (3 , u'Wylma' , u'Phlyntstone' , u'wylma@example.com' ),
178+ )
179+ ALL = KeySet (all_ = True )
180+ SQL = 'SELECT * FROM contacts ORDER BY contact_id'
181+
182+ def _check_row_data (self , row_data ):
183+ self .assertEqual (len (row_data ), len (self .ROW_DATA ))
184+ for found , expected in zip (row_data , self .ROW_DATA ):
185+ self .assertEqual (len (found ), len (expected ))
186+ for f_cell , e_cell in zip (found , expected ):
187+ self .assertEqual (f_cell , e_cell )
188+
189+
190+ class TestDatabaseAPI (unittest .TestCase , _TestData ):
171191
172192 @classmethod
173193 def setUpClass (cls ):
174194 pool = BurstyPool ()
175- cls ._db = Config .INSTANCE .database (DATABASE_ID , pool = pool )
195+ cls ._db = Config .INSTANCE .database (
196+ DATABASE_ID , ddl_statements = DDL_STATEMENTS , pool = pool )
176197 cls ._db .create ()
177198
178199 @classmethod
@@ -228,16 +249,43 @@ def test_update_database_ddl(self):
228249
229250 self .assertEqual (len (temp_db .ddl_statements ), len (DDL_STATEMENTS ))
230251
252+ def test_db_batch_insert_then_db_snapshot_read_and_db_read (self ):
253+ retry = RetryInstanceState (_has_all_ddl )
254+ retry (self ._db .reload )()
231255
232- class TestSessionAPI (unittest .TestCase ):
233- TABLE = 'contacts'
234- COLUMNS = ('contact_id' , 'first_name' , 'last_name' , 'email' )
235- ROW_DATA = (
236- (1 , u'Phred' , u'Phlyntstone' , u'phred@example.com' ),
237- (2 , u'Bharney' , u'Rhubble' , u'bharney@example.com' ),
238- (3 , u'Wylma' , u'Phlyntstone' , u'wylma@example.com' ),
239- )
240- SQL = 'SELECT * FROM contacts ORDER BY contact_id'
256+ with self ._db .batch () as batch :
257+ batch .delete (self .TABLE , self .ALL )
258+ batch .insert (self .TABLE , self .COLUMNS , self .ROW_DATA )
259+
260+ with self ._db .snapshot (read_timestamp = batch .committed ) as snapshot :
261+ from_snap = list (snapshot .read (self .TABLE , self .COLUMNS , self .ALL ))
262+
263+ self ._check_row_data (from_snap )
264+
265+ from_db = list (self ._db .read (self .TABLE , self .COLUMNS , self .ALL ))
266+ self ._check_row_data (from_db )
267+
268+ def test_db_run_in_transaction_then_db_execute_sql (self ):
269+ retry = RetryInstanceState (_has_all_ddl )
270+ retry (self ._db .reload )()
271+
272+ with self ._db .batch () as batch :
273+ batch .delete (self .TABLE , self .ALL )
274+
275+ def _unit_of_work (transaction , test ):
276+ rows = list (transaction .read (test .TABLE , test .COLUMNS , self .ALL ))
277+ test .assertEqual (rows , [])
278+
279+ transaction .insert_or_update (
280+ test .TABLE , test .COLUMNS , test .ROW_DATA )
281+
282+ self ._db .run_in_transaction (_unit_of_work , test = self )
283+
284+ rows = list (self ._db .execute_sql (self .SQL ))
285+ self ._check_row_data (rows )
286+
287+
288+ class TestSessionAPI (unittest .TestCase , _TestData ):
241289
242290 @classmethod
243291 def setUpClass (cls ):
@@ -258,13 +306,6 @@ def tearDown(self):
258306 for doomed in self .to_delete :
259307 doomed .delete ()
260308
261- def _check_row_data (self , row_data ):
262- self .assertEqual (len (row_data ), len (self .ROW_DATA ))
263- for found , expected in zip (row_data , self .ROW_DATA ):
264- self .assertEqual (len (found ), len (expected ))
265- for f_cell , e_cell in zip (found , expected ):
266- self .assertEqual (f_cell , e_cell )
267-
268309 def test_session_crud (self ):
269310 retry_true = RetryResult (operator .truth )
270311 retry_false = RetryResult (operator .not_ )
@@ -276,9 +317,6 @@ def test_session_crud(self):
276317 retry_false (session .exists )()
277318
278319 def test_batch_insert_then_read (self ):
279- from google .cloud .spanner import KeySet
280- keyset = KeySet (all_ = True )
281-
282320 retry = RetryInstanceState (_has_all_ddl )
283321 retry (self ._db .reload )()
284322
@@ -287,12 +325,12 @@ def test_batch_insert_then_read(self):
287325 self .to_delete .append (session )
288326
289327 batch = session .batch ()
290- batch .delete (self .TABLE , keyset )
328+ batch .delete (self .TABLE , self . ALL )
291329 batch .insert (self .TABLE , self .COLUMNS , self .ROW_DATA )
292330 batch .commit ()
293331
294332 snapshot = session .snapshot (read_timestamp = batch .committed )
295- rows = list (snapshot .read (self .TABLE , self .COLUMNS , keyset ))
333+ rows = list (snapshot .read (self .TABLE , self .COLUMNS , self . ALL ))
296334 self ._check_row_data (rows )
297335
298336 def test_batch_insert_or_update_then_query (self ):
@@ -313,9 +351,6 @@ def test_batch_insert_or_update_then_query(self):
313351
314352 @RetryErrors (exception = _Rendezvous )
315353 def test_transaction_read_and_insert_then_rollback (self ):
316- from google .cloud .spanner import KeySet
317- keyset = KeySet (all_ = True )
318-
319354 retry = RetryInstanceState (_has_all_ddl )
320355 retry (self ._db .reload )()
321356
@@ -324,29 +359,26 @@ def test_transaction_read_and_insert_then_rollback(self):
324359 self .to_delete .append (session )
325360
326361 with session .batch () as batch :
327- batch .delete (self .TABLE , keyset )
362+ batch .delete (self .TABLE , self . ALL )
328363
329364 transaction = session .transaction ()
330365 transaction .begin ()
331366
332- rows = list (transaction .read (self .TABLE , self .COLUMNS , keyset ))
367+ rows = list (transaction .read (self .TABLE , self .COLUMNS , self . ALL ))
333368 self .assertEqual (rows , [])
334369
335370 transaction .insert (self .TABLE , self .COLUMNS , self .ROW_DATA )
336371
337372 # Inserted rows can't be read until after commit.
338- rows = list (transaction .read (self .TABLE , self .COLUMNS , keyset ))
373+ rows = list (transaction .read (self .TABLE , self .COLUMNS , self . ALL ))
339374 self .assertEqual (rows , [])
340375 transaction .rollback ()
341376
342- rows = list (session .read (self .TABLE , self .COLUMNS , keyset ))
377+ rows = list (session .read (self .TABLE , self .COLUMNS , self . ALL ))
343378 self .assertEqual (rows , [])
344379
345380 @RetryErrors (exception = _Rendezvous )
346381 def test_transaction_read_and_insert_or_update_then_commit (self ):
347- from google .cloud .spanner import KeySet
348- keyset = KeySet (all_ = True )
349-
350382 retry = RetryInstanceState (_has_all_ddl )
351383 retry (self ._db .reload )()
352384
@@ -355,32 +387,28 @@ def test_transaction_read_and_insert_or_update_then_commit(self):
355387 self .to_delete .append (session )
356388
357389 with session .batch () as batch :
358- batch .delete (self .TABLE , keyset )
390+ batch .delete (self .TABLE , self . ALL )
359391
360392 with session .transaction () as transaction :
361- rows = list (transaction .read (self .TABLE , self .COLUMNS , keyset ))
393+ rows = list (transaction .read (self .TABLE , self .COLUMNS , self . ALL ))
362394 self .assertEqual (rows , [])
363395
364396 transaction .insert_or_update (
365397 self .TABLE , self .COLUMNS , self .ROW_DATA )
366398
367399 # Inserted rows can't be read until after commit.
368- rows = list (transaction .read (self .TABLE , self .COLUMNS , keyset ))
400+ rows = list (transaction .read (self .TABLE , self .COLUMNS , self . ALL ))
369401 self .assertEqual (rows , [])
370402
371- rows = list (session .read (self .TABLE , self .COLUMNS , keyset ))
403+ rows = list (session .read (self .TABLE , self .COLUMNS , self . ALL ))
372404 self ._check_row_data (rows )
373405
374406 def _set_up_table (self , row_count ):
375- from google .cloud .spanner import KeySet
376-
377407 def _row_data (max_index ):
378408 for index in range (max_index ):
379409 yield [index , 'First%09d' % (index ,), 'Last09%d' % (index ),
380410 'test-%09d@example.com' % (index ,)]
381411
382- keyset = KeySet (all_ = True )
383-
384412 retry = RetryInstanceState (_has_all_ddl )
385413 retry (self ._db .reload )()
386414
@@ -389,17 +417,17 @@ def _row_data(max_index):
389417 self .to_delete .append (session )
390418
391419 with session .transaction () as transaction :
392- transaction .delete (self .TABLE , keyset )
420+ transaction .delete (self .TABLE , self . ALL )
393421 transaction .insert (self .TABLE , self .COLUMNS , _row_data (row_count ))
394422
395- return session , keyset , transaction .committed
423+ return session , transaction .committed
396424
397425 def test_read_w_manual_consume (self ):
398426 ROW_COUNT = 4000
399- session , keyset , committed = self ._set_up_table (ROW_COUNT )
427+ session , committed = self ._set_up_table (ROW_COUNT )
400428
401429 snapshot = session .snapshot (read_timestamp = committed )
402- streamed = snapshot .read (self .TABLE , self .COLUMNS , keyset )
430+ streamed = snapshot .read (self .TABLE , self .COLUMNS , self . ALL )
403431
404432 retrieved = 0
405433 while True :
@@ -416,7 +444,7 @@ def test_read_w_manual_consume(self):
416444
417445 def test_execute_sql_w_manual_consume (self ):
418446 ROW_COUNT = 4000
419- session , _ , committed = self ._set_up_table (ROW_COUNT )
447+ session , committed = self ._set_up_table (ROW_COUNT )
420448
421449 snapshot = session .snapshot (read_timestamp = committed )
422450 streamed = snapshot .execute_sql (self .SQL )
@@ -437,7 +465,7 @@ def test_execute_sql_w_manual_consume(self):
437465 def test_execute_sql_w_query_param (self ):
438466 SQL = 'SELECT * FROM contacts WHERE first_name = @first_name'
439467 ROW_COUNT = 10
440- session , _ , committed = self ._set_up_table (ROW_COUNT )
468+ session , committed = self ._set_up_table (ROW_COUNT )
441469
442470 snapshot = session .snapshot (read_timestamp = committed )
443471 rows = list (snapshot .execute_sql (
0 commit comments