@@ -46,7 +46,15 @@ def _get_target_class():
4646 def _make_one (self , * args , ** kw ):
4747 return self ._get_target_class ()(* args , ** kw )
4848
49- def _mock_client (self , rows = None , schema = None , num_dml_affected_rows = None ):
49+ def _mock_client (
50+ self ,
51+ rows = None ,
52+ schema = None ,
53+ num_dml_affected_rows = None ,
54+ default_query_job_config = None ,
55+ dry_run_job = False ,
56+ total_bytes_processed = 0 ,
57+ ):
5058 from google .cloud .bigquery import client
5159
5260 if rows is None :
@@ -59,8 +67,11 @@ def _mock_client(self, rows=None, schema=None, num_dml_affected_rows=None):
5967 total_rows = total_rows ,
6068 schema = schema ,
6169 num_dml_affected_rows = num_dml_affected_rows ,
70+ dry_run = dry_run_job ,
71+ total_bytes_processed = total_bytes_processed ,
6272 )
6373 mock_client .list_rows .return_value = rows
74+ mock_client ._default_query_job_config = default_query_job_config
6475
6576 # Assure that the REST client gets used, not the BQ Storage client.
6677 mock_client ._create_bqstorage_client .return_value = None
@@ -95,27 +106,41 @@ def _mock_bqstorage_client(self, rows=None, stream_count=0, v1beta1=False):
95106 )
96107
97108 mock_client .create_read_session .return_value = mock_read_session
109+
98110 mock_rows_stream = mock .MagicMock ()
99111 mock_rows_stream .rows .return_value = iter (rows )
100112 mock_client .read_rows .return_value = mock_rows_stream
101113
102114 return mock_client
103115
104- def _mock_job (self , total_rows = 0 , schema = None , num_dml_affected_rows = None ):
116+ def _mock_job (
117+ self ,
118+ total_rows = 0 ,
119+ schema = None ,
120+ num_dml_affected_rows = None ,
121+ dry_run = False ,
122+ total_bytes_processed = 0 ,
123+ ):
105124 from google .cloud .bigquery import job
106125
107126 mock_job = mock .create_autospec (job .QueryJob )
108127 mock_job .error_result = None
109128 mock_job .state = "DONE"
110- mock_job .result .return_value = mock_job
111- mock_job ._query_results = self ._mock_results (
112- total_rows = total_rows ,
113- schema = schema ,
114- num_dml_affected_rows = num_dml_affected_rows ,
115- )
116- mock_job .destination .to_bqstorage .return_value = (
117- "projects/P/datasets/DS/tables/T"
118- )
129+ mock_job .dry_run = dry_run
130+
131+ if dry_run :
132+ mock_job .result .side_effect = exceptions .NotFound
133+ mock_job .total_bytes_processed = total_bytes_processed
134+ else :
135+ mock_job .result .return_value = mock_job
136+ mock_job ._query_results = self ._mock_results (
137+ total_rows = total_rows ,
138+ schema = schema ,
139+ num_dml_affected_rows = num_dml_affected_rows ,
140+ )
141+ mock_job .destination .to_bqstorage .return_value = (
142+ "projects/P/datasets/DS/tables/T"
143+ )
119144
120145 if num_dml_affected_rows is None :
121146 mock_job .statement_type = None # API sends back None for SELECT
@@ -445,7 +470,27 @@ def test_execute_custom_job_id(self):
445470 self .assertEqual (args [0 ], "SELECT 1;" )
446471 self .assertEqual (kwargs ["job_id" ], "foo" )
447472
448- def test_execute_custom_job_config (self ):
473+ def test_execute_w_default_config (self ):
474+ from google .cloud .bigquery .dbapi import connect
475+ from google .cloud .bigquery import job
476+
477+ default_config = job .QueryJobConfig (use_legacy_sql = False , flatten_results = True )
478+ client = self ._mock_client (
479+ rows = [], num_dml_affected_rows = 0 , default_query_job_config = default_config
480+ )
481+ connection = connect (client )
482+ cursor = connection .cursor ()
483+
484+ cursor .execute ("SELECT 1;" , job_id = "foo" )
485+
486+ _ , kwargs = client .query .call_args
487+ used_config = kwargs ["job_config" ]
488+ expected_config = job .QueryJobConfig (
489+ use_legacy_sql = False , flatten_results = True , query_parameters = []
490+ )
491+ self .assertEqual (used_config ._properties , expected_config ._properties )
492+
493+ def test_execute_custom_job_config_wo_default_config (self ):
449494 from google .cloud .bigquery .dbapi import connect
450495 from google .cloud .bigquery import job
451496
@@ -459,6 +504,29 @@ def test_execute_custom_job_config(self):
459504 self .assertEqual (kwargs ["job_id" ], "foo" )
460505 self .assertEqual (kwargs ["job_config" ], config )
461506
507+ def test_execute_custom_job_config_w_default_config (self ):
508+ from google .cloud .bigquery .dbapi import connect
509+ from google .cloud .bigquery import job
510+
511+ default_config = job .QueryJobConfig (use_legacy_sql = False , flatten_results = True )
512+ client = self ._mock_client (
513+ rows = [], num_dml_affected_rows = 0 , default_query_job_config = default_config
514+ )
515+ connection = connect (client )
516+ cursor = connection .cursor ()
517+ config = job .QueryJobConfig (use_legacy_sql = True )
518+
519+ cursor .execute ("SELECT 1;" , job_id = "foo" , job_config = config )
520+
521+ _ , kwargs = client .query .call_args
522+ used_config = kwargs ["job_config" ]
523+ expected_config = job .QueryJobConfig (
524+ use_legacy_sql = True , # the config passed to execute() prevails
525+ flatten_results = True , # from the default
526+ query_parameters = [],
527+ )
528+ self .assertEqual (used_config ._properties , expected_config ._properties )
529+
462530 def test_execute_w_dml (self ):
463531 from google .cloud .bigquery .dbapi import connect
464532
@@ -514,6 +582,35 @@ def test_execute_w_query(self):
514582 row = cursor .fetchone ()
515583 self .assertIsNone (row )
516584
585+ def test_execute_w_query_dry_run (self ):
586+ from google .cloud .bigquery .job import QueryJobConfig
587+ from google .cloud .bigquery .schema import SchemaField
588+ from google .cloud .bigquery import dbapi
589+
590+ connection = dbapi .connect (
591+ self ._mock_client (
592+ rows = [("hello" , "world" , 1 ), ("howdy" , "y'all" , 2 )],
593+ schema = [
594+ SchemaField ("a" , "STRING" , mode = "NULLABLE" ),
595+ SchemaField ("b" , "STRING" , mode = "REQUIRED" ),
596+ SchemaField ("c" , "INTEGER" , mode = "NULLABLE" ),
597+ ],
598+ dry_run_job = True ,
599+ total_bytes_processed = 12345 ,
600+ )
601+ )
602+ cursor = connection .cursor ()
603+
604+ cursor .execute (
605+ "SELECT a, b, c FROM hello_world WHERE d > 3;" ,
606+ job_config = QueryJobConfig (dry_run = True ),
607+ )
608+
609+ self .assertEqual (cursor .rowcount , 0 )
610+ self .assertIsNone (cursor .description )
611+ rows = cursor .fetchall ()
612+ self .assertEqual (list (rows ), [])
613+
517614 def test_execute_raises_if_result_raises (self ):
518615 import google .cloud .exceptions
519616
@@ -523,8 +620,10 @@ def test_execute_raises_if_result_raises(self):
523620 from google .cloud .bigquery .dbapi import exceptions
524621
525622 job = mock .create_autospec (job .QueryJob )
623+ job .dry_run = None
526624 job .result .side_effect = google .cloud .exceptions .GoogleCloudError ("" )
527625 client = mock .create_autospec (client .Client )
626+ client ._default_query_job_config = None
528627 client .query .return_value = job
529628 connection = connect (client )
530629 cursor = connection .cursor ()
0 commit comments