From 777c374fbe2fe5343c41c580781ead8cd653a267 Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Wed, 30 May 2018 13:29:49 -0400 Subject: [PATCH 01/12] Remove check that txn has mutations before committing. With DML statements, there may be no explicit mutations tracked on the client side. --- .../google/cloud/spanner_v1/transaction.py | 3 --- spanner/tests/unit/test_transaction.py | 19 ++++++++++--------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/spanner/google/cloud/spanner_v1/transaction.py b/spanner/google/cloud/spanner_v1/transaction.py index 9f2f6d99895e..b1097ac1d721 100644 --- a/spanner/google/cloud/spanner_v1/transaction.py +++ b/spanner/google/cloud/spanner_v1/transaction.py @@ -114,9 +114,6 @@ def commit(self): """ self._check_state() - if not self._mutations: - raise ValueError("No mutations to commit") - database = self._session._database api = database.spanner_api metadata = _metadata_with_prefix(database.name) diff --git a/spanner/tests/unit/test_transaction.py b/spanner/tests/unit/test_transaction.py index 29c1e765888e..22af4310d923 100644 --- a/spanner/tests/unit/test_transaction.py +++ b/spanner/tests/unit/test_transaction.py @@ -238,13 +238,6 @@ def test_commit_already_rolled_back(self): with self.assertRaises(ValueError): transaction.commit() - def test_commit_no_mutations(self): - session = _Session() - transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID - with self.assertRaises(ValueError): - transaction.commit() - def test_commit_w_other_error(self): database = _Database() database.spanner_api = self._make_spanner_api() @@ -259,7 +252,7 @@ def test_commit_w_other_error(self): self.assertIsNone(transaction.committed) - def test_commit_ok(self): + def _commit_helper(self, mutate=True): import datetime from google.cloud.spanner_v1.proto.spanner_pb2 import CommitResponse from google.cloud.spanner_v1.keyset import KeySet @@ -277,7 +270,9 @@ def test_commit_ok(self): session = _Session(database) transaction = self._make_one(session) transaction._transaction_id = self.TRANSACTION_ID - transaction.delete(TABLE_NAME, keyset) + + if mutate: + transaction.delete(TABLE_NAME, keyset) transaction.commit() @@ -291,6 +286,12 @@ def test_commit_ok(self): self.assertEqual( metadata, [('google-cloud-resource-prefix', database.name)]) + def test_commit_no_mutations(self): + self._commit_helper(mutate=False) + + def test_commit_w_mutations(self): + self._commit_helper(mutate=True) + def test_context_mgr_success(self): import datetime from google.cloud.spanner_v1.proto.spanner_pb2 import CommitResponse From 974b598f05b0dc36e8e46afdf4df7666c14a2291 Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Wed, 30 May 2018 14:16:57 -0400 Subject: [PATCH 02/12] Add (failing) systest for 'execute_sql' w/ DML. --- spanner/tests/system/test_system.py | 40 +++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/spanner/tests/system/test_system.py b/spanner/tests/system/test_system.py index 2d85a99531b6..c506cbf5d6db 100644 --- a/spanner/tests/system/test_system.py +++ b/spanner/tests/system/test_system.py @@ -627,6 +627,46 @@ def test_transaction_read_and_insert_or_update_then_commit(self): rows = list(session.read(self.TABLE, self.COLUMNS, self.ALL)) self._check_rows_data(rows) + def _generate_insert_statements(self): + insert_template = ( + 'INSERT INTO {table} ({column_list}) ' + 'VALUES ({row_data})' + ) + for row in self.ROW_DATA: + yield insert_template.format( + table=self.TABLE, + column_list=', '.join(self.COLUMNS), + row_data='{}, "{}", "{}", "{}"'.format(*row) + ) + + @RetryErrors(exception=exceptions.ServerError) + @RetryErrors(exception=exceptions.Conflict) + def test_transaction_execute_dml_read_commit(self): + retry = RetryInstanceState(_has_all_ddl) + retry(self._db.reload)() + + session = self._db.session() + session.create() + self.to_delete.append(session) + + with session.batch() as batch: + batch.delete(self.TABLE, self.ALL) + + with session.transaction() as transaction: + rows = list(transaction.read(self.TABLE, self.COLUMNS, self.ALL)) + self.assertEqual(rows, []) + + for insert_statement in self._generate_insert_statements(): + transaction.execute_sql(insert_statement) + + # Rows inserted via DML *can* be read before commit. + during_rows = list( + transaction.read(self.TABLE, self.COLUMNS, self.ALL)) + self._check_rows_data(during_rows) + + rows = list(session.read(self.TABLE, self.COLUMNS, self.ALL)) + self._check_rows_data(rows) + def _transaction_concurrency_helper(self, unit_of_work, pkey): INITIAL_VALUE = 123 NUM_THREADS = 3 # conforms to equivalent Java systest. From 359cb833c077d6dd9d56a8ca38c9af9fdd391800 Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Fri, 1 Jun 2018 16:45:09 -0400 Subject: [PATCH 03/12] Add 'Transaction.execute_update' method. Switch the still-failing system test to use it. --- .../google/cloud/spanner_v1/transaction.py | 64 ++++++++++++++- spanner/tests/system/test_system.py | 5 +- spanner/tests/unit/test_transaction.py | 79 +++++++++++++++++++ 3 files changed, 144 insertions(+), 4 deletions(-) diff --git a/spanner/google/cloud/spanner_v1/transaction.py b/spanner/google/cloud/spanner_v1/transaction.py index b1097ac1d721..33883786a479 100644 --- a/spanner/google/cloud/spanner_v1/transaction.py +++ b/spanner/google/cloud/spanner_v1/transaction.py @@ -14,11 +14,13 @@ """Spanner read-write transaction support.""" -from google.cloud.spanner_v1.proto.transaction_pb2 import TransactionSelector -from google.cloud.spanner_v1.proto.transaction_pb2 import TransactionOptions +from google.protobuf.struct_pb2 import Struct from google.cloud._helpers import _pb_timestamp_to_datetime +from google.cloud.spanner_v1._helpers import _make_value_pb from google.cloud.spanner_v1._helpers import _metadata_with_prefix +from google.cloud.spanner_v1.proto.transaction_pb2 import TransactionSelector +from google.cloud.spanner_v1.proto.transaction_pb2 import TransactionOptions from google.cloud.spanner_v1.snapshot import _SnapshotBase from google.cloud.spanner_v1.batch import _BatchBase @@ -125,6 +127,64 @@ def commit(self): del self._session._transaction return self.committed + def execute_update(self, dml, params=None, param_types=None, + query_mode=None, partition=None): + """Perform an ``ExecuteSql`` API request with DML. + + :type dml: str + :param dml: SQL DML statement + + :type params: dict, {str -> column value} + :param params: values for parameter replacement. Keys must match + the names used in ``dml``. + + :type param_types: dict[str -> Union[dict, .types.Type]] + :param param_types: + (Optional) maps explicit types for one or more param values; + required if parameters are passed. + + :type query_mode: + :class:`google.cloud.spanner_v1.proto.ExecuteSqlRequest.QueryMode` + :param query_mode: Mode governing return of results / query plan. See + https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.ExecuteSqlRequest.QueryMode1 + + :type partition: bytes + :param partition: (Optional) one of the partition tokens returned + from :meth:`partition_query`. + + :rtype: + :class:`google.cloud.spanner_v1.proto.ExecuteSqlRequest.ResultSetStats` + :returns: + stats object, including count of rows affected by the DML + statement. + """ + if params is not None: + if param_types is None: + raise ValueError( + "Specify 'param_types' when passing 'params'.") + params_pb = Struct(fields={ + key: _make_value_pb(value) for key, value in params.items()}) + else: + params_pb = None + + database = self._session._database + metadata = _metadata_with_prefix(database.name) + transaction = self._make_txn_selector() + api = database.spanner_api + + response = api.execute_sql( + self._session.name, + dml, + transaction=transaction, + params=params_pb, + param_types=param_types, + query_mode=query_mode, + partition_token=partition, + metadata=metadata, + ) + + return response.stats + def __enter__(self): """Begin ``with`` block.""" self.begin() diff --git a/spanner/tests/system/test_system.py b/spanner/tests/system/test_system.py index c506cbf5d6db..7821896b06f0 100644 --- a/spanner/tests/system/test_system.py +++ b/spanner/tests/system/test_system.py @@ -641,7 +641,7 @@ def _generate_insert_statements(self): @RetryErrors(exception=exceptions.ServerError) @RetryErrors(exception=exceptions.Conflict) - def test_transaction_execute_dml_read_commit(self): + def test_transaction_execute_update_read_commit(self): retry = RetryInstanceState(_has_all_ddl) retry(self._db.reload)() @@ -657,7 +657,8 @@ def test_transaction_execute_dml_read_commit(self): self.assertEqual(rows, []) for insert_statement in self._generate_insert_statements(): - transaction.execute_sql(insert_statement) + result = transaction.execute_update(insert_statement) + print("DML: {}, stats: {}".format(insert_statement, result)) # Rows inserted via DML *can* be read before commit. during_rows = list( diff --git a/spanner/tests/unit/test_transaction.py b/spanner/tests/unit/test_transaction.py index 22af4310d923..8a717ddf28d8 100644 --- a/spanner/tests/unit/test_transaction.py +++ b/spanner/tests/unit/test_transaction.py @@ -24,6 +24,16 @@ ['phred@exammple.com', 'Phred', 'Phlyntstone', 32], ['bharney@example.com', 'Bharney', 'Rhubble', 31], ] +DML_QUERY = """\ +INSERT INTO citizens(first_name, last_name, age) +VALUES ("Phred", "Phlyntstone", 32) +""" +DML_QUERY_WITH_PARAM = """ +INSERT INTO citizens(first_name, last_name, age) +VALUES ("Phred", "Phlyntstone", @age) +""" +PARAMS = {'age': 30} +PARAM_TYPES = {'age': 'INT64'} class TestTransaction(unittest.TestCase): @@ -292,6 +302,75 @@ def test_commit_no_mutations(self): def test_commit_w_mutations(self): self._commit_helper(mutate=True) + def test_execute_update_other_error(self): + database = _Database() + database.spanner_api = self._make_spanner_api() + database.spanner_api.execute_sql.side_effect = RuntimeError() + session = _Session(database) + transaction = self._make_one(session) + transaction._transaction_id = self.TRANSACTION_ID + + with self.assertRaises(RuntimeError): + transaction.execute_update(DML_QUERY) + + def test_execute_update_w_params_wo_param_types(self): + database = _Database() + database.spanner_api = self._make_spanner_api() + session = _Session(database) + session = _Session() + transaction = self._make_one(session) + transaction._transaction_id = self.TRANSACTION_ID + + with self.assertRaises(ValueError): + transaction.execute_update(DML_QUERY_WITH_PARAM, PARAMS) + + def _execute_update_helper(self, partition=None): + from google.protobuf.struct_pb2 import Struct + from google.cloud.spanner_v1.proto.result_set_pb2 import ( + ResultSet, ResultSetStats) + from google.cloud.spanner_v1.proto.transaction_pb2 import ( + TransactionSelector) + from google.cloud.spanner_v1._helpers import _make_value_pb + + MODE = 2 # PROFILE + stats_pb = ResultSetStats( + query_stats=Struct(fields={ + 'rows_affected': _make_value_pb(1), + })) + database = _Database() + api = database.spanner_api = self._make_spanner_api() + api.execute_sql.return_value = ResultSet(stats=stats_pb) + session = _Session(database) + transaction = self._make_one(session) + transaction._transaction_id = self.TRANSACTION_ID + + result = transaction.execute_update( + DML_QUERY_WITH_PARAM, PARAMS, PARAM_TYPES, + query_mode=MODE, partition=partition) + + self.assertEqual(result, stats_pb) + + expected_transaction = TransactionSelector(id=self.TRANSACTION_ID) + expected_params = Struct(fields={ + key: _make_value_pb(value) for (key, value) in PARAMS.items()}) + + api.execute_sql.assert_called_once_with( + self.SESSION_NAME, + DML_QUERY_WITH_PARAM, + transaction=expected_transaction, + params=expected_params, + param_types=PARAM_TYPES, + query_mode=MODE, + partition_token=partition, + metadata=[('google-cloud-resource-prefix', database.name)], + ) + + def test_execute_update_wo_partition(self): + self._execute_update_helper() + + def test_execute_update_w_partition(self): + self._execute_update_helper(partition=b'FACEDACE') + def test_context_mgr_success(self): import datetime from google.cloud.spanner_v1.proto.spanner_pb2 import CommitResponse From 68967176b622ec8f2ae27f8e41b8942ce26fdd03 Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Fri, 1 Jun 2018 16:54:12 -0400 Subject: [PATCH 04/12] Get new systest to 'pass' by forcing only a single insert. --- spanner/tests/system/test_system.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/spanner/tests/system/test_system.py b/spanner/tests/system/test_system.py index 7821896b06f0..d5dced63148e 100644 --- a/spanner/tests/system/test_system.py +++ b/spanner/tests/system/test_system.py @@ -642,6 +642,8 @@ def _generate_insert_statements(self): @RetryErrors(exception=exceptions.ServerError) @RetryErrors(exception=exceptions.Conflict) def test_transaction_execute_update_read_commit(self): + insert_statements = list(self._generate_insert_statements()) + retry = RetryInstanceState(_has_all_ddl) retry(self._db.reload)() @@ -656,17 +658,17 @@ def test_transaction_execute_update_read_commit(self): rows = list(transaction.read(self.TABLE, self.COLUMNS, self.ALL)) self.assertEqual(rows, []) - for insert_statement in self._generate_insert_statements(): + for insert_statement in insert_statements[:1]: result = transaction.execute_update(insert_statement) print("DML: {}, stats: {}".format(insert_statement, result)) # Rows inserted via DML *can* be read before commit. during_rows = list( transaction.read(self.TABLE, self.COLUMNS, self.ALL)) - self._check_rows_data(during_rows) + self._check_rows_data(during_rows, self.ROW_DATA[:1]) rows = list(session.read(self.TABLE, self.COLUMNS, self.ALL)) - self._check_rows_data(rows) + self._check_rows_data(rows, self.ROW_DATA[:1]) def _transaction_concurrency_helper(self, unit_of_work, pkey): INITIAL_VALUE = 123 From 250564ba73eedb16075bfefe6db3aadc09a44b31 Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Tue, 5 Jun 2018 08:07:08 -0400 Subject: [PATCH 05/12] Add test for 'Transaction.execute_sql' w/ DML. Assert that stats have appropriate row count. --- spanner/tests/system/test_system.py | 34 ++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/spanner/tests/system/test_system.py b/spanner/tests/system/test_system.py index d5dced63148e..445f2a1a1410 100644 --- a/spanner/tests/system/test_system.py +++ b/spanner/tests/system/test_system.py @@ -639,6 +639,38 @@ def _generate_insert_statements(self): row_data='{}, "{}", "{}", "{}"'.format(*row) ) + @RetryErrors(exception=exceptions.ServerError) + @RetryErrors(exception=exceptions.Conflict) + def test_transaction_execute_sql_w_dml_read_commit(self): + insert_statements = list(self._generate_insert_statements()) + + retry = RetryInstanceState(_has_all_ddl) + retry(self._db.reload)() + + session = self._db.session() + session.create() + self.to_delete.append(session) + + with session.batch() as batch: + batch.delete(self.TABLE, self.ALL) + + with session.transaction() as transaction: + rows = list(transaction.read(self.TABLE, self.COLUMNS, self.ALL)) + self.assertEqual(rows, []) + + for insert_statement in insert_statements[:1]: + result = transaction.execute_sql(insert_statement) + list(result) # iterate to get stats + self.assertEqual(result.stats.row_count_exact, 1) + + # Rows inserted via DML *can* be read before commit. + during_rows = list( + transaction.read(self.TABLE, self.COLUMNS, self.ALL)) + self._check_rows_data(during_rows, self.ROW_DATA[:1]) + + rows = list(session.read(self.TABLE, self.COLUMNS, self.ALL)) + self._check_rows_data(rows, self.ROW_DATA[:1]) + @RetryErrors(exception=exceptions.ServerError) @RetryErrors(exception=exceptions.Conflict) def test_transaction_execute_update_read_commit(self): @@ -660,7 +692,7 @@ def test_transaction_execute_update_read_commit(self): for insert_statement in insert_statements[:1]: result = transaction.execute_update(insert_statement) - print("DML: {}, stats: {}".format(insert_statement, result)) + self.assertEqual(result.row_count_exact, 1) # Rows inserted via DML *can* be read before commit. during_rows = list( From 7cb8fc1509f7f87815a486aca1ea1ed148bf2001 Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Tue, 5 Jun 2018 08:24:53 -0400 Subject: [PATCH 06/12] Pass 'seqno' when invoking 'ExecuteSql'/'ExecuteStreamingSql' APIs. Required for executing multiple DML statements w/in context of one transaction. Update system tests for DML to exercise this (they previously passed only one statement). --- spanner/google/cloud/spanner_v1/snapshot.py | 13 ++++++++++--- spanner/google/cloud/spanner_v1/transaction.py | 3 +++ spanner/tests/system/test_system.py | 14 ++++++-------- spanner/tests/unit/test_snapshot.py | 15 ++++++++++----- spanner/tests/unit/test_transaction.py | 13 +++++++++---- 5 files changed, 38 insertions(+), 20 deletions(-) diff --git a/spanner/google/cloud/spanner_v1/snapshot.py b/spanner/google/cloud/spanner_v1/snapshot.py index 827da34ee7c4..00d45410f499 100644 --- a/spanner/google/cloud/spanner_v1/snapshot.py +++ b/spanner/google/cloud/spanner_v1/snapshot.py @@ -71,6 +71,7 @@ class _SnapshotBase(_SessionWrapper): _multi_use = False _transaction_id = None _read_request_count = 0 + _execute_sql_count = 0 def _make_txn_selector(self): # pylint: disable=redundant-returns-doc """Helper for :meth:`read` / :meth:`execute_sql`. @@ -195,14 +196,20 @@ def execute_sql(self, sql, params=None, param_types=None, restart = functools.partial( api.execute_streaming_sql, - self._session.name, sql, - transaction=transaction, params=params_pb, param_types=param_types, - query_mode=query_mode, partition_token=partition, + self._session.name, + sql, + transaction=transaction, + params=params_pb, + param_types=param_types, + query_mode=query_mode, + partition_token=partition, + seqno=self._execute_sql_count, metadata=metadata) iterator = _restart_on_unavailable(restart) self._read_request_count += 1 + self._execute_sql_count += 1 if self._multi_use: return StreamedResultSet(iterator, source=self) diff --git a/spanner/google/cloud/spanner_v1/transaction.py b/spanner/google/cloud/spanner_v1/transaction.py index 33883786a479..beb06d5fa866 100644 --- a/spanner/google/cloud/spanner_v1/transaction.py +++ b/spanner/google/cloud/spanner_v1/transaction.py @@ -37,6 +37,7 @@ class Transaction(_SnapshotBase, _BatchBase): """Timestamp at which the transaction was successfully committed.""" _rolled_back = False _multi_use = True + _execute_sql_count = 0 def __init__(self, session): if session._transaction is not None: @@ -180,9 +181,11 @@ def execute_update(self, dml, params=None, param_types=None, param_types=param_types, query_mode=query_mode, partition_token=partition, + seqno=self._execute_sql_count, metadata=metadata, ) + self._execute_sql_count += 1 return response.stats def __enter__(self): diff --git a/spanner/tests/system/test_system.py b/spanner/tests/system/test_system.py index 445f2a1a1410..0c325e34ecb7 100644 --- a/spanner/tests/system/test_system.py +++ b/spanner/tests/system/test_system.py @@ -642,8 +642,6 @@ def _generate_insert_statements(self): @RetryErrors(exception=exceptions.ServerError) @RetryErrors(exception=exceptions.Conflict) def test_transaction_execute_sql_w_dml_read_commit(self): - insert_statements = list(self._generate_insert_statements()) - retry = RetryInstanceState(_has_all_ddl) retry(self._db.reload)() @@ -658,7 +656,7 @@ def test_transaction_execute_sql_w_dml_read_commit(self): rows = list(transaction.read(self.TABLE, self.COLUMNS, self.ALL)) self.assertEqual(rows, []) - for insert_statement in insert_statements[:1]: + for insert_statement in self._generate_insert_statements(): result = transaction.execute_sql(insert_statement) list(result) # iterate to get stats self.assertEqual(result.stats.row_count_exact, 1) @@ -666,10 +664,10 @@ def test_transaction_execute_sql_w_dml_read_commit(self): # Rows inserted via DML *can* be read before commit. during_rows = list( transaction.read(self.TABLE, self.COLUMNS, self.ALL)) - self._check_rows_data(during_rows, self.ROW_DATA[:1]) + self._check_rows_data(during_rows) rows = list(session.read(self.TABLE, self.COLUMNS, self.ALL)) - self._check_rows_data(rows, self.ROW_DATA[:1]) + self._check_rows_data(rows) @RetryErrors(exception=exceptions.ServerError) @RetryErrors(exception=exceptions.Conflict) @@ -690,17 +688,17 @@ def test_transaction_execute_update_read_commit(self): rows = list(transaction.read(self.TABLE, self.COLUMNS, self.ALL)) self.assertEqual(rows, []) - for insert_statement in insert_statements[:1]: + for insert_statement in self._generate_insert_statements(): result = transaction.execute_update(insert_statement) self.assertEqual(result.row_count_exact, 1) # Rows inserted via DML *can* be read before commit. during_rows = list( transaction.read(self.TABLE, self.COLUMNS, self.ALL)) - self._check_rows_data(during_rows, self.ROW_DATA[:1]) + self._check_rows_data(during_rows) rows = list(session.read(self.TABLE, self.COLUMNS, self.ALL)) - self._check_rows_data(rows, self.ROW_DATA[:1]) + self._check_rows_data(rows) def _transaction_concurrency_helper(self, unit_of_work, pkey): INITIAL_VALUE = 123 diff --git a/spanner/tests/unit/test_snapshot.py b/spanner/tests/unit/test_snapshot.py index 2b5961b75f74..85cf5d2febc5 100644 --- a/spanner/tests/unit/test_snapshot.py +++ b/spanner/tests/unit/test_snapshot.py @@ -176,6 +176,7 @@ def test_ctor(self): session = _Session() base = self._make_one(session) self.assertIs(base._session, session) + self.assertEqual(base._execute_sql_count, 0) def test__make_txn_selector_virtual(self): session = _Session() @@ -328,7 +329,7 @@ def test_execute_sql_w_params_wo_param_types(self): derived.execute_sql(SQL_QUERY_WITH_PARAM, PARAMS) def _execute_sql_helper( - self, multi_use, first=True, count=0, partition=None): + self, multi_use, first=True, count=0, partition=None, sql_count=0): from google.protobuf.struct_pb2 import Struct from google.cloud.spanner_v1.proto.result_set_pb2 import ( PartialResultSet, ResultSetMetadata, ResultSetStats) @@ -369,6 +370,7 @@ def _execute_sql_helper( derived = self._makeDerived(session) derived._multi_use = multi_use derived._read_request_count = count + derived._execute_sql_count = sql_count if not first: derived._transaction_id = TXN_ID @@ -388,7 +390,7 @@ def _execute_sql_helper( self.assertEqual(result_set.stats, stats_pb) (r_session, sql, transaction, params, param_types, - resume_token, query_mode, partition_token, + resume_token, query_mode, partition_token, seqno, metadata) = api._executed_streaming_sql_with self.assertEqual(r_session, self.SESSION_NAME) @@ -408,9 +410,12 @@ def _execute_sql_helper( self.assertEqual(query_mode, MODE) self.assertEqual(resume_token, b'') self.assertEqual(partition_token, partition) + self.assertEqual(seqno, sql_count) self.assertEqual( metadata, [('google-cloud-resource-prefix', database.name)]) + self.assertEqual(derived._execute_sql_count, sql_count + 1) + def test_execute_sql_wo_multi_use(self): self._execute_sql_helper(multi_use=False) @@ -419,7 +424,7 @@ def test_execute_sql_wo_multi_use_w_read_request_count_gt_0(self): self._execute_sql_helper(multi_use=False, count=1) def test_execute_sql_w_multi_use_wo_first(self): - self._execute_sql_helper(multi_use=True, first=False) + self._execute_sql_helper(multi_use=True, first=False, sql_count=1) def test_execute_sql_w_multi_use_wo_first_w_count_gt_0(self): self._execute_sql_helper(multi_use=True, first=False, count=1) @@ -977,10 +982,10 @@ def streaming_read(self, session, table, columns, key_set, def execute_streaming_sql(self, session, sql, transaction=None, params=None, param_types=None, resume_token=b'', query_mode=None, - partition_token=None, metadata=None): + partition_token=None, seqno=0, metadata=None): self._executed_streaming_sql_with = ( session, sql, transaction, params, param_types, resume_token, - query_mode, partition_token, metadata) + query_mode, partition_token, seqno, metadata) return self._execute_streaming_sql_response # pylint: disable=too-many-arguments diff --git a/spanner/tests/unit/test_transaction.py b/spanner/tests/unit/test_transaction.py index 8a717ddf28d8..8dc1b49bf7e1 100644 --- a/spanner/tests/unit/test_transaction.py +++ b/spanner/tests/unit/test_transaction.py @@ -78,6 +78,7 @@ def test_ctor_defaults(self): self.assertIsNone(transaction.committed) self.assertFalse(transaction._rolled_back) self.assertTrue(transaction._multi_use) + self.assertEqual(transaction._execute_sql_count, 0) def test__check_state_not_begun(self): session = _Session() @@ -324,7 +325,7 @@ def test_execute_update_w_params_wo_param_types(self): with self.assertRaises(ValueError): transaction.execute_update(DML_QUERY_WITH_PARAM, PARAMS) - def _execute_update_helper(self, partition=None): + def _execute_update_helper(self, partition=None, count=0): from google.protobuf.struct_pb2 import Struct from google.cloud.spanner_v1.proto.result_set_pb2 import ( ResultSet, ResultSetStats) @@ -343,6 +344,7 @@ def _execute_update_helper(self, partition=None): session = _Session(database) transaction = self._make_one(session) transaction._transaction_id = self.TRANSACTION_ID + transaction._execute_sql_count = count result = transaction.execute_update( DML_QUERY_WITH_PARAM, PARAMS, PARAM_TYPES, @@ -362,13 +364,16 @@ def _execute_update_helper(self, partition=None): param_types=PARAM_TYPES, query_mode=MODE, partition_token=partition, + seqno=count, metadata=[('google-cloud-resource-prefix', database.name)], ) - def test_execute_update_wo_partition(self): - self._execute_update_helper() + self.assertEqual(transaction._execute_sql_count, count + 1) - def test_execute_update_w_partition(self): + def test_execute_update_w_count_wo_partition(self): + self._execute_update_helper(count=1) + + def test_execute_update_wo_count_w_partition(self): self._execute_update_helper(partition=b'FACEDACE') def test_context_mgr_success(self): From 90d9d43e43874ce872311287f202674944d491c6 Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Tue, 5 Jun 2018 09:33:19 -0400 Subject: [PATCH 07/12] Kill off custom _FauxSpannerAPI mock. --- spanner/tests/system/test_system.py | 2 - spanner/tests/unit/test_snapshot.py | 261 +++++++++++----------------- 2 files changed, 106 insertions(+), 157 deletions(-) diff --git a/spanner/tests/system/test_system.py b/spanner/tests/system/test_system.py index 0c325e34ecb7..65ee553806ff 100644 --- a/spanner/tests/system/test_system.py +++ b/spanner/tests/system/test_system.py @@ -672,8 +672,6 @@ def test_transaction_execute_sql_w_dml_read_commit(self): @RetryErrors(exception=exceptions.ServerError) @RetryErrors(exception=exceptions.Conflict) def test_transaction_execute_update_read_commit(self): - insert_statements = list(self._generate_insert_statements()) - retry = RetryInstanceState(_has_all_ddl) retry(self._db.reload)() diff --git a/spanner/tests/unit/test_snapshot.py b/spanner/tests/unit/test_snapshot.py index 85cf5d2febc5..21cb6cbe35df 100644 --- a/spanner/tests/unit/test_snapshot.py +++ b/spanner/tests/unit/test_snapshot.py @@ -31,6 +31,8 @@ PARAMS_WITH_BYTES = {'bytes': b'FACEDACE'} RESUME_TOKEN = b'DEADBEEF' TXN_ID = b'DEAFBEAD' +SECONDS = 3 +MICROS = 123456 class Test_restart_on_unavailable(unittest.TestCase): @@ -202,7 +204,7 @@ def _read_helper(self, multi_use, first=True, count=0, partition=None): from google.cloud.spanner_v1.proto.result_set_pb2 import ( PartialResultSet, ResultSetMetadata, ResultSetStats) from google.cloud.spanner_v1.proto.transaction_pb2 import ( - TransactionSelector) + TransactionSelector, TransactionOptions) from google.cloud.spanner_v1.proto.type_pb2 import Type, StructType from google.cloud.spanner_v1.proto.type_pb2 import STRING, INT64 from google.cloud.spanner_v1.keyset import KeySet @@ -229,13 +231,13 @@ def _read_helper(self, multi_use, first=True, count=0, partition=None): PartialResultSet(values=VALUE_PBS[0], metadata=metadata_pb), PartialResultSet(values=VALUE_PBS[1], stats=stats_pb), ] - KEYS = ['bharney@example.com', 'phred@example.com'] + KEYS = [['bharney@example.com'], ['phred@example.com']] keyset = KeySet(keys=KEYS) INDEX = 'email-address-index' LIMIT = 20 database = _Database() - api = database.spanner_api = _FauxSpannerAPI( - _streaming_read_response=_MockIterator(*result_sets)) + api = database.spanner_api = self._make_spanner_api() + api.streaming_read.return_value = _MockIterator(*result_sets) session = _Session(database) derived = self._makeDerived(session) derived._multi_use = multi_use @@ -263,31 +265,33 @@ def _read_helper(self, multi_use, first=True, count=0, partition=None): self.assertEqual(result_set.metadata, metadata_pb) self.assertEqual(result_set.stats, stats_pb) - (r_session, table, columns, key_set, transaction, index, limit, - resume_token, r_partition, metadata) = api._streaming_read_with + txn_options = TransactionOptions( + read_only=TransactionOptions.ReadOnly(strong=True)) - self.assertEqual(r_session, self.SESSION_NAME) - self.assertEqual(table, TABLE_NAME) - self.assertEqual(columns, COLUMNS) - self.assertEqual(key_set, keyset._to_pb()) - self.assertIsInstance(transaction, TransactionSelector) if multi_use: if first: - self.assertTrue(transaction.begin.read_only.strong) + expected_transaction = TransactionSelector(begin=txn_options) else: - self.assertEqual(transaction.id, TXN_ID) + expected_transaction = TransactionSelector(id=TXN_ID) else: - self.assertTrue(transaction.single_use.read_only.strong) - self.assertEqual(index, INDEX) + expected_transaction = TransactionSelector(single_use=txn_options) + if partition is not None: - self.assertEqual(limit, 0) - self.assertEqual(r_partition, partition) + expected_limit = 0 else: - self.assertEqual(limit, LIMIT) - self.assertIsNone(r_partition) - self.assertEqual(resume_token, b'') - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + expected_limit = LIMIT + + api.streaming_read.assert_called_once_with( + self.SESSION_NAME, + TABLE_NAME, + COLUMNS, + keyset._to_pb(), + transaction=expected_transaction, + index=INDEX, + limit=expected_limit, + partition_token=partition, + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_read_wo_multi_use(self): self._read_helper(multi_use=False) @@ -334,7 +338,7 @@ def _execute_sql_helper( from google.cloud.spanner_v1.proto.result_set_pb2 import ( PartialResultSet, ResultSetMetadata, ResultSetStats) from google.cloud.spanner_v1.proto.transaction_pb2 import ( - TransactionSelector) + TransactionSelector, TransactionOptions) from google.cloud.spanner_v1.proto.type_pb2 import Type, StructType from google.cloud.spanner_v1.proto.type_pb2 import STRING, INT64 from google.cloud.spanner_v1._helpers import _make_value_pb @@ -364,8 +368,8 @@ def _execute_sql_helper( ] iterator = _MockIterator(*result_sets) database = _Database() - api = database.spanner_api = _FauxSpannerAPI( - _execute_streaming_sql_response=iterator) + api = database.spanner_api = self._make_spanner_api() + api.execute_streaming_sql.return_value = iterator session = _Session(database) derived = self._makeDerived(session) derived._multi_use = multi_use @@ -389,30 +393,31 @@ def _execute_sql_helper( self.assertEqual(result_set.metadata, metadata_pb) self.assertEqual(result_set.stats, stats_pb) - (r_session, sql, transaction, params, param_types, - resume_token, query_mode, partition_token, seqno, - metadata) = api._executed_streaming_sql_with + txn_options = TransactionOptions( + read_only=TransactionOptions.ReadOnly(strong=True)) - self.assertEqual(r_session, self.SESSION_NAME) - self.assertEqual(sql, SQL_QUERY_WITH_PARAM) - self.assertIsInstance(transaction, TransactionSelector) if multi_use: if first: - self.assertTrue(transaction.begin.read_only.strong) + expected_transaction = TransactionSelector(begin=txn_options) else: - self.assertEqual(transaction.id, TXN_ID) + expected_transaction = TransactionSelector(id=TXN_ID) else: - self.assertTrue(transaction.single_use.read_only.strong) + expected_transaction = TransactionSelector(single_use=txn_options) + expected_params = Struct(fields={ key: _make_value_pb(value) for (key, value) in PARAMS.items()}) - self.assertEqual(params, expected_params) - self.assertEqual(param_types, PARAM_TYPES) - self.assertEqual(query_mode, MODE) - self.assertEqual(resume_token, b'') - self.assertEqual(partition_token, partition) - self.assertEqual(seqno, sql_count) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + + api.execute_streaming_sql.assert_called_once_with( + self.SESSION_NAME, + SQL_QUERY_WITH_PARAM, + transaction=expected_transaction, + params=expected_params, + param_types=PARAM_TYPES, + query_mode=MODE, + partition_token=partition, + seqno=sql_count, + metadata=[('google-cloud-resource-prefix', database.name)], + ) self.assertEqual(derived._execute_sql_count, sql_count + 1) @@ -459,8 +464,8 @@ def _partition_read_helper( transaction=Transaction(id=new_txn_id), ) database = _Database() - api = database.spanner_api = _FauxSpannerAPI( - _partition_read_response=response) + api = database.spanner_api = self._make_spanner_api() + api.partition_read.return_value = response session = _Session(database) derived = self._makeDerived(session) derived._multi_use = multi_use @@ -476,23 +481,21 @@ def _partition_read_helper( self.assertEqual(tokens, [token_1, token_2]) - (r_session, table, key_set, transaction, r_index, columns, - partition_options, metadata) = api._partition_read_with - - self.assertEqual(r_session, self.SESSION_NAME) - self.assertEqual(table, TABLE_NAME) - self.assertEqual(key_set, keyset._to_pb()) - self.assertIsInstance(transaction, TransactionSelector) - self.assertEqual(transaction.id, TXN_ID) - self.assertFalse(transaction.HasField('begin')) - self.assertEqual(r_index, index) - self.assertEqual(columns, COLUMNS) - self.assertEqual( - partition_options, - PartitionOptions( - partition_size_bytes=size, max_partitions=max_partitions)) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + expected_txn_selector = TransactionSelector(id=TXN_ID) + + expected_partition_options = PartitionOptions( + partition_size_bytes=size, max_partitions=max_partitions) + + api.partition_read.assert_called_once_with( + session=self.SESSION_NAME, + table=TABLE_NAME, + columns=COLUMNS, + key_set=keyset._to_pb(), + transaction=expected_txn_selector, + index=index, + partition_options=expected_partition_options, + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_partition_read_single_use_raises(self): with self.assertRaises(ValueError): @@ -549,8 +552,8 @@ def _partition_query_helper( transaction=Transaction(id=new_txn_id), ) database = _Database() - api = database.spanner_api = _FauxSpannerAPI( - _partition_query_response=response) + api = database.spanner_api = self._make_spanner_api() + api.partition_query.return_value = response session = _Session(database) derived = self._makeDerived(session) derived._multi_use = multi_use @@ -565,24 +568,23 @@ def _partition_query_helper( self.assertEqual(tokens, [token_1, token_2]) - (r_session, sql, transaction, params, param_types, - partition_options, metadata) = api._partition_query_with - - self.assertEqual(r_session, self.SESSION_NAME) - self.assertEqual(sql, SQL_QUERY_WITH_PARAM) - self.assertIsInstance(transaction, TransactionSelector) - self.assertEqual(transaction.id, TXN_ID) - self.assertFalse(transaction.HasField('begin')) expected_params = Struct(fields={ key: _make_value_pb(value) for (key, value) in PARAMS.items()}) - self.assertEqual(params, expected_params) - self.assertEqual(param_types, PARAM_TYPES) - self.assertEqual( - partition_options, - PartitionOptions( - partition_size_bytes=size, max_partitions=max_partitions)) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + + expected_txn_selector = TransactionSelector(id=TXN_ID) + + expected_partition_options = PartitionOptions( + partition_size_bytes=size, max_partitions=max_partitions) + + api.partition_query.assert_called_once_with( + session=self.SESSION_NAME, + sql=SQL_QUERY_WITH_PARAM, + transaction=expected_txn_selector, + params=expected_params, + param_types=PARAM_TYPES, + partition_options=expected_partition_options, + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_partition_query_other_error(self): database = _Database() @@ -899,14 +901,15 @@ def test_begin_w_other_error(self): snapshot.begin() def test_begin_ok_exact_staleness(self): + from google.protobuf.duration_pb2 import Duration from google.cloud.spanner_v1.proto.transaction_pb2 import ( - Transaction as TransactionPB) + Transaction as TransactionPB, TransactionOptions) transaction_pb = TransactionPB(id=TXN_ID) database = _Database() - api = database.spanner_api = _FauxSpannerAPI( - _begin_transaction_response=transaction_pb) - duration = self._makeDuration(seconds=3, microseconds=123456) + api = database.spanner_api = self._make_spanner_api() + api.begin_transaction.return_value = transaction_pb + duration = self._makeDuration(seconds=SECONDS, microseconds=MICROS) session = _Session(database) snapshot = self._make_one( session, exact_staleness=duration, multi_use=True) @@ -916,22 +919,25 @@ def test_begin_ok_exact_staleness(self): self.assertEqual(txn_id, TXN_ID) self.assertEqual(snapshot._transaction_id, TXN_ID) - session_id, txn_options, metadata = api._begun - self.assertEqual(session_id, session.name) - read_only = txn_options.read_only - self.assertEqual(read_only.exact_staleness.seconds, 3) - self.assertEqual(read_only.exact_staleness.nanos, 123456000) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + expected_duration = Duration( + seconds=SECONDS, nanos=MICROS * 1000) + expected_txn_options = TransactionOptions( + read_only=TransactionOptions.ReadOnly( + exact_staleness=expected_duration)) + + api.begin_transaction.assert_called_once_with( + session.name, + expected_txn_options, + metadata=[('google-cloud-resource-prefix', database.name)]) def test_begin_ok_exact_strong(self): from google.cloud.spanner_v1.proto.transaction_pb2 import ( - Transaction as TransactionPB) + Transaction as TransactionPB, TransactionOptions) transaction_pb = TransactionPB(id=TXN_ID) database = _Database() - api = database.spanner_api = _FauxSpannerAPI( - _begin_transaction_response=transaction_pb) + api = database.spanner_api = self._make_spanner_api() + api.begin_transaction.return_value = transaction_pb session = _Session(database) snapshot = self._make_one(session, multi_use=True) @@ -940,11 +946,13 @@ def test_begin_ok_exact_strong(self): self.assertEqual(txn_id, TXN_ID) self.assertEqual(snapshot._transaction_id, TXN_ID) - session_id, txn_options, metadata = api._begun - self.assertEqual(session_id, session.name) - self.assertTrue(txn_options.read_only.strong) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + expected_txn_options = TransactionOptions( + read_only=TransactionOptions.ReadOnly(strong=True)) + + api.begin_transaction.assert_called_once_with( + session.name, + expected_txn_options, + metadata=[('google-cloud-resource-prefix', database.name)]) class _Session(object): @@ -958,63 +966,6 @@ class _Database(object): name = 'testing' -class _FauxSpannerAPI(object): - - _read_with = _begin = None - - def __init__(self, **kwargs): - self.__dict__.update(**kwargs) - - def begin_transaction(self, session, options_, metadata=None): - self._begun = (session, options_, metadata) - return self._begin_transaction_response - - # pylint: disable=too-many-arguments - def streaming_read(self, session, table, columns, key_set, - transaction=None, index='', limit=0, - resume_token=b'', partition_token=None, metadata=None): - self._streaming_read_with = ( - session, table, columns, key_set, transaction, index, - limit, resume_token, partition_token, metadata) - return self._streaming_read_response - # pylint: enable=too-many-arguments - - def execute_streaming_sql(self, session, sql, transaction=None, - params=None, param_types=None, - resume_token=b'', query_mode=None, - partition_token=None, seqno=0, metadata=None): - self._executed_streaming_sql_with = ( - session, sql, transaction, params, param_types, resume_token, - query_mode, partition_token, seqno, metadata) - return self._execute_streaming_sql_response - - # pylint: disable=too-many-arguments - def partition_read(self, session, table, key_set, - transaction=None, - index='', - columns=None, - partition_options=None, - metadata=None): - self._partition_read_with = ( - session, table, key_set, transaction, index, columns, - partition_options, metadata) - return self._partition_read_response - # pylint: enable=too-many-arguments - - # pylint: disable=too-many-arguments - def partition_query(self, session, sql, - transaction=None, - params=None, - param_types=None, - partition_options=None, - metadata=None): - self._partition_query_with = ( - session, sql, transaction, params, param_types, - partition_options, metadata) - return self._partition_query_response - # pylint: enable=too-many-arguments - - class _MockIterator(object): def __init__(self, *values, **kw): From d8f6c81cabbb60636eefa3633d27453e99338e57 Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Tue, 5 Jun 2018 12:25:52 -0400 Subject: [PATCH 08/12] Required integration test cases for non-partitioned DML: - Rollback transaction after performing DML. - Mix DML and batch-style mutations in a single commit. --- spanner/tests/system/test_system.py | 56 ++++++++++++++++++++++------- 1 file changed, 44 insertions(+), 12 deletions(-) diff --git a/spanner/tests/system/test_system.py b/spanner/tests/system/test_system.py index 65ee553806ff..858887e78679 100644 --- a/spanner/tests/system/test_system.py +++ b/spanner/tests/system/test_system.py @@ -641,7 +641,42 @@ def _generate_insert_statements(self): @RetryErrors(exception=exceptions.ServerError) @RetryErrors(exception=exceptions.Conflict) - def test_transaction_execute_sql_w_dml_read_commit(self): + def test_transaction_execute_sql_w_dml_read_rollback(self): + retry = RetryInstanceState(_has_all_ddl) + retry(self._db.reload)() + + session = self._db.session() + session.create() + self.to_delete.append(session) + + with session.batch() as batch: + batch.delete(self.TABLE, self.ALL) + + transaction = session.transaction() + transaction.begin() + + rows = list( + transaction.read(self.TABLE, self.COLUMNS, self.ALL)) + self.assertEqual(rows, []) + + for insert_statement in self._generate_insert_statements(): + result = transaction.execute_sql(insert_statement) + list(result) # iterate to get stats + self.assertEqual(result.stats.row_count_exact, 1) + + # Rows inserted via DML *can* be read before commit. + during_rows = list( + transaction.read(self.TABLE, self.COLUMNS, self.ALL)) + self._check_rows_data(during_rows) + + transaction.rollback() + + rows = list(session.read(self.TABLE, self.COLUMNS, self.ALL)) + self._check_rows_data(rows, []) + + @RetryErrors(exception=exceptions.ServerError) + @RetryErrors(exception=exceptions.Conflict) + def test_transaction_execute_update_read_commit(self): retry = RetryInstanceState(_has_all_ddl) retry(self._db.reload)() @@ -657,9 +692,8 @@ def test_transaction_execute_sql_w_dml_read_commit(self): self.assertEqual(rows, []) for insert_statement in self._generate_insert_statements(): - result = transaction.execute_sql(insert_statement) - list(result) # iterate to get stats - self.assertEqual(result.stats.row_count_exact, 1) + result = transaction.execute_update(insert_statement) + self.assertEqual(result.row_count_exact, 1) # Rows inserted via DML *can* be read before commit. during_rows = list( @@ -671,7 +705,7 @@ def test_transaction_execute_sql_w_dml_read_commit(self): @RetryErrors(exception=exceptions.ServerError) @RetryErrors(exception=exceptions.Conflict) - def test_transaction_execute_update_read_commit(self): + def test_transaction_execute_update_then_insert_commit(self): retry = RetryInstanceState(_has_all_ddl) retry(self._db.reload)() @@ -682,18 +716,16 @@ def test_transaction_execute_update_read_commit(self): with session.batch() as batch: batch.delete(self.TABLE, self.ALL) + insert_statement = list(self._generate_insert_statements())[0] + with session.transaction() as transaction: rows = list(transaction.read(self.TABLE, self.COLUMNS, self.ALL)) self.assertEqual(rows, []) - for insert_statement in self._generate_insert_statements(): - result = transaction.execute_update(insert_statement) - self.assertEqual(result.row_count_exact, 1) + result = transaction.execute_update(insert_statement) + self.assertEqual(result.row_count_exact, 1) - # Rows inserted via DML *can* be read before commit. - during_rows = list( - transaction.read(self.TABLE, self.COLUMNS, self.ALL)) - self._check_rows_data(during_rows) + transaction.insert(self.TABLE, self.COLUMNS, self.ROW_DATA[1:]) rows = list(session.read(self.TABLE, self.COLUMNS, self.ALL)) self._check_rows_data(rows) From 34d5a45796b06439d0d96de6e65d31b89d3dfcaf Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Wed, 6 Jun 2018 12:05:13 -0400 Subject: [PATCH 09/12] Return just row count from 'Transaction.execute_update'. Also, drop the 'partition' argument to it: not appropriate to the usecase. --- .../google/cloud/spanner_v1/transaction.py | 9 ++------ spanner/tests/system/test_system.py | 8 +++---- spanner/tests/unit/test_transaction.py | 23 ++++++++----------- 3 files changed, 15 insertions(+), 25 deletions(-) diff --git a/spanner/google/cloud/spanner_v1/transaction.py b/spanner/google/cloud/spanner_v1/transaction.py index beb06d5fa866..18b87f5cc383 100644 --- a/spanner/google/cloud/spanner_v1/transaction.py +++ b/spanner/google/cloud/spanner_v1/transaction.py @@ -129,7 +129,7 @@ def commit(self): return self.committed def execute_update(self, dml, params=None, param_types=None, - query_mode=None, partition=None): + query_mode=None): """Perform an ``ExecuteSql`` API request with DML. :type dml: str @@ -149,10 +149,6 @@ def execute_update(self, dml, params=None, param_types=None, :param query_mode: Mode governing return of results / query plan. See https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.ExecuteSqlRequest.QueryMode1 - :type partition: bytes - :param partition: (Optional) one of the partition tokens returned - from :meth:`partition_query`. - :rtype: :class:`google.cloud.spanner_v1.proto.ExecuteSqlRequest.ResultSetStats` :returns: @@ -180,13 +176,12 @@ def execute_update(self, dml, params=None, param_types=None, params=params_pb, param_types=param_types, query_mode=query_mode, - partition_token=partition, seqno=self._execute_sql_count, metadata=metadata, ) self._execute_sql_count += 1 - return response.stats + return response.stats.row_count_exact def __enter__(self): """Begin ``with`` block.""" diff --git a/spanner/tests/system/test_system.py b/spanner/tests/system/test_system.py index 858887e78679..b2a99cf45c31 100644 --- a/spanner/tests/system/test_system.py +++ b/spanner/tests/system/test_system.py @@ -692,8 +692,8 @@ def test_transaction_execute_update_read_commit(self): self.assertEqual(rows, []) for insert_statement in self._generate_insert_statements(): - result = transaction.execute_update(insert_statement) - self.assertEqual(result.row_count_exact, 1) + row_count = transaction.execute_update(insert_statement) + self.assertEqual(row_count, 1) # Rows inserted via DML *can* be read before commit. during_rows = list( @@ -722,8 +722,8 @@ def test_transaction_execute_update_then_insert_commit(self): rows = list(transaction.read(self.TABLE, self.COLUMNS, self.ALL)) self.assertEqual(rows, []) - result = transaction.execute_update(insert_statement) - self.assertEqual(result.row_count_exact, 1) + row_count = transaction.execute_update(insert_statement) + self.assertEqual(row_count, 1) transaction.insert(self.TABLE, self.COLUMNS, self.ROW_DATA[1:]) diff --git a/spanner/tests/unit/test_transaction.py b/spanner/tests/unit/test_transaction.py index 8dc1b49bf7e1..99c401cc7e10 100644 --- a/spanner/tests/unit/test_transaction.py +++ b/spanner/tests/unit/test_transaction.py @@ -325,7 +325,7 @@ def test_execute_update_w_params_wo_param_types(self): with self.assertRaises(ValueError): transaction.execute_update(DML_QUERY_WITH_PARAM, PARAMS) - def _execute_update_helper(self, partition=None, count=0): + def _execute_update_helper(self, count=0): from google.protobuf.struct_pb2 import Struct from google.cloud.spanner_v1.proto.result_set_pb2 import ( ResultSet, ResultSetStats) @@ -334,10 +334,7 @@ def _execute_update_helper(self, partition=None, count=0): from google.cloud.spanner_v1._helpers import _make_value_pb MODE = 2 # PROFILE - stats_pb = ResultSetStats( - query_stats=Struct(fields={ - 'rows_affected': _make_value_pb(1), - })) + stats_pb = ResultSetStats(row_count_exact=1) database = _Database() api = database.spanner_api = self._make_spanner_api() api.execute_sql.return_value = ResultSet(stats=stats_pb) @@ -346,11 +343,10 @@ def _execute_update_helper(self, partition=None, count=0): transaction._transaction_id = self.TRANSACTION_ID transaction._execute_sql_count = count - result = transaction.execute_update( - DML_QUERY_WITH_PARAM, PARAMS, PARAM_TYPES, - query_mode=MODE, partition=partition) + row_count = transaction.execute_update( + DML_QUERY_WITH_PARAM, PARAMS, PARAM_TYPES, query_mode=MODE) - self.assertEqual(result, stats_pb) + self.assertEqual(row_count, 1) expected_transaction = TransactionSelector(id=self.TRANSACTION_ID) expected_params = Struct(fields={ @@ -363,18 +359,17 @@ def _execute_update_helper(self, partition=None, count=0): params=expected_params, param_types=PARAM_TYPES, query_mode=MODE, - partition_token=partition, seqno=count, metadata=[('google-cloud-resource-prefix', database.name)], ) self.assertEqual(transaction._execute_sql_count, count + 1) - def test_execute_update_w_count_wo_partition(self): - self._execute_update_helper(count=1) + def test_execute_update_new_transaction(self): + self._execute_update_helper() - def test_execute_update_wo_count_w_partition(self): - self._execute_update_helper(partition=b'FACEDACE') + def test_execute_update_w_count(self): + self._execute_update_helper(count=1) def test_context_mgr_success(self): import datetime From a37f74a4013ab3082898913b67bf0092ec54ed77 Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Fri, 15 Jun 2018 13:39:21 -0400 Subject: [PATCH 10/12] Add partitioned DML support (#459) * Add 'Datatbase.execute_partitioned_dml' method. * Add system test which exercises PDML. both for UPDATE (with parameter) and DELETE. --- spanner/google/cloud/spanner_v1/database.py | 73 ++- .../google/cloud/spanner_v1/transaction.py | 7 +- spanner/tests/system/test_system.py | 56 +++ spanner/tests/unit/test_database.py | 438 ++++++++++-------- 4 files changed, 380 insertions(+), 194 deletions(-) diff --git a/spanner/google/cloud/spanner_v1/database.py b/spanner/google/cloud/spanner_v1/database.py index d3494eb63902..93960947de20 100644 --- a/spanner/google/cloud/spanner_v1/database.py +++ b/spanner/google/cloud/spanner_v1/database.py @@ -14,17 +14,20 @@ """User friendly container for Cloud Spanner Database.""" +import copy +import functools import re import threading -import copy from google.api_core.gapic_v1 import client_info import google.auth.credentials +from google.protobuf.struct_pb2 import Struct from google.cloud.exceptions import NotFound import six # pylint: disable=ungrouped-imports from google.cloud.spanner_v1 import __version__ +from google.cloud.spanner_v1._helpers import _make_value_pb from google.cloud.spanner_v1._helpers import _metadata_with_prefix from google.cloud.spanner_v1.batch import Batch from google.cloud.spanner_v1.gapic.spanner_client import SpannerClient @@ -32,7 +35,11 @@ from google.cloud.spanner_v1.pool import BurstyPool from google.cloud.spanner_v1.pool import SessionCheckout from google.cloud.spanner_v1.session import Session +from google.cloud.spanner_v1.snapshot import _restart_on_unavailable from google.cloud.spanner_v1.snapshot import Snapshot +from google.cloud.spanner_v1.streamed import StreamedResultSet +from google.cloud.spanner_v1.proto.transaction_pb2 import ( + TransactionSelector, TransactionOptions) # pylint: enable=ungrouped-imports @@ -272,6 +279,70 @@ def drop(self): metadata = _metadata_with_prefix(self.name) api.drop_database(self.name, metadata=metadata) + def execute_partitioned_dml( + self, dml, params=None, param_types=None, query_mode=None): + """Execute a partitionable DML statement. + + :type dml: str + :param dml: SQL DML statement + + :type params: dict, {str -> column value} + :param params: values for parameter replacement. Keys must match + the names used in ``dml``. + + :type param_types: dict[str -> Union[dict, .types.Type]] + :param param_types: + (Optional) maps explicit types for one or more param values; + required if parameters are passed. + + :type query_mode: + :class:`google.cloud.spanner_v1.proto.ExecuteSqlRequest.QueryMode` + :param query_mode: Mode governing return of results / query plan. See + https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.ExecuteSqlRequest.QueryMode1 + + :rtype: int + :returns: Count of rows affected by the DML statement. + """ + if params is not None: + if param_types is None: + raise ValueError( + "Specify 'param_types' when passing 'params'.") + params_pb = Struct(fields={ + key: _make_value_pb(value) for key, value in params.items()}) + else: + params_pb = None + + api = self.spanner_api + + txn_options = TransactionOptions( + partitioned_dml=TransactionOptions.PartitionedDml()) + + metadata = _metadata_with_prefix(self.name) + + with SessionCheckout(self._pool) as session: + + txn = api.begin_transaction( + session.name, txn_options, metadata=metadata) + + txn_selector = TransactionSelector(id=txn.id) + + restart = functools.partial( + api.execute_streaming_sql, + session.name, + dml, + transaction=txn_selector, + params=params_pb, + param_types=param_types, + query_mode=query_mode, + metadata=metadata) + + iterator = _restart_on_unavailable(restart) + + result_set = StreamedResultSet(iterator) + list(result_set) # consume all partials + + return result_set.stats.row_count_lower_bound + def session(self, labels=None): """Factory to create a session for this database. diff --git a/spanner/google/cloud/spanner_v1/transaction.py b/spanner/google/cloud/spanner_v1/transaction.py index 18b87f5cc383..cc2f06cee54d 100644 --- a/spanner/google/cloud/spanner_v1/transaction.py +++ b/spanner/google/cloud/spanner_v1/transaction.py @@ -149,11 +149,8 @@ def execute_update(self, dml, params=None, param_types=None, :param query_mode: Mode governing return of results / query plan. See https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.ExecuteSqlRequest.QueryMode1 - :rtype: - :class:`google.cloud.spanner_v1.proto.ExecuteSqlRequest.ResultSetStats` - :returns: - stats object, including count of rows affected by the DML - statement. + :rtype: int + :returns: Count of rows affected by the DML statement. """ if params is not None: if param_types is None: diff --git a/spanner/tests/system/test_system.py b/spanner/tests/system/test_system.py index b2a99cf45c31..228cd7849fa0 100644 --- a/spanner/tests/system/test_system.py +++ b/spanner/tests/system/test_system.py @@ -730,6 +730,62 @@ def test_transaction_execute_update_then_insert_commit(self): rows = list(session.read(self.TABLE, self.COLUMNS, self.ALL)) self._check_rows_data(rows) + def test_execute_partitioned_dml(self): + retry = RetryInstanceState(_has_all_ddl) + retry(self._db.reload)() + + delete_statement = 'DELETE FROM {} WHERE true'.format(self.TABLE) + + def _setup_table(txn): + txn.execute_update(delete_statement) + for insert_statement in self._generate_insert_statements(): + txn.execute_update(insert_statement) + + committed = self._db.run_in_transaction(_setup_table) + + with self._db.snapshot(read_timestamp=committed) as snapshot: + before_pdml = list(snapshot.read( + self.TABLE, self.COLUMNS, self.ALL)) + + self._check_rows_data(before_pdml) + + nonesuch = 'nonesuch@example.com' + target = 'phred@example.com' + update_statement = ( + 'UPDATE {table} SET {table}.email = @email ' + 'WHERE {table}.email = @target').format( + table=self.TABLE) + + row_count = self._db.execute_partitioned_dml( + update_statement, + params={ + 'email': nonesuch, + 'target': target, + }, + param_types={ + 'email': Type(code=STRING), + 'target': Type(code=STRING), + }, + ) + self.assertEqual(row_count, 1) + + row = self.ROW_DATA[0] + updated = [row[:3] + (nonesuch,)] + list(self.ROW_DATA[1:]) + + with self._db.snapshot(read_timestamp=committed) as snapshot: + after_update = list(snapshot.read( + self.TABLE, self.COLUMNS, self.ALL)) + self._check_rows_data(after_update, updated) + + row_count = self._db.execute_partitioned_dml(delete_statement) + self.assertEqual(row_count, len(self.ROW_DATA)) + + with self._db.snapshot(read_timestamp=committed) as snapshot: + after_delete = list(snapshot.read( + self.TABLE, self.COLUMNS, self.ALL)) + + self._check_rows_data(after_delete, []) + def _transaction_concurrency_helper(self, unit_of_work, pkey): INITIAL_VALUE = 123 NUM_THREADS = 3 # conforms to equivalent Java systest. diff --git a/spanner/tests/unit/test_database.py b/spanner/tests/unit/test_database.py index 34b30deb2022..c17251647511 100644 --- a/spanner/tests/unit/test_database.py +++ b/spanner/tests/unit/test_database.py @@ -18,6 +18,19 @@ import mock +DML_WO_PARAM = """ +DELETE FROM citizens +""" + +DML_W_PARAM = """ +INSERT INTO citizens(first_name, last_name, age) +VALUES ("Phred", "Phlyntstone", @age) +""" +PARAMS = {'age': 30} +PARAM_TYPES = {'age': 'INT64'} +MODE = 2 # PROFILE + + def _make_credentials(): # pragma: NO COVER import google.auth.credentials @@ -39,7 +52,7 @@ class _BaseTest(unittest.TestCase): DATABASE_NAME = INSTANCE_NAME + '/databases/' + DATABASE_ID SESSION_ID = 'session_id' SESSION_NAME = DATABASE_NAME + '/sessions/' + SESSION_ID - TRANSACTION_ID = 'transaction_id' + TRANSACTION_ID = b'transaction_id' def _make_one(self, *args, **kwargs): return self._get_target_class()(*args, **kwargs) @@ -65,6 +78,20 @@ def _get_target_class(self): return Database + @staticmethod + def _make_database_admin_api(): + from google.cloud.spanner_v1.client import DatabaseAdminClient + + return mock.create_autospec(DatabaseAdminClient, instance=True) + + @staticmethod + def _make_spanner_api(): + import google.cloud.spanner_v1.gapic.spanner_client + + return mock.create_autospec( + google.cloud.spanner_v1.gapic.spanner_client.SpannerClient, + instance=True) + def test_ctor_defaults(self): from google.cloud.spanner_v1.pool import BurstyPool @@ -296,10 +323,12 @@ def test___ne__(self): def test_create_grpc_error(self): from google.api_core.exceptions import GoogleAPICallError + from google.api_core.exceptions import Unknown client = _Client() - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _rpc_error=True) + api = client.database_admin_api = self._make_database_admin_api() + api.create_database.side_effect = Unknown('testing') + instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) @@ -307,22 +336,20 @@ def test_create_grpc_error(self): with self.assertRaises(GoogleAPICallError): database.create() - (parent, create_statement, extra_statements, - metadata) = api._created_database - self.assertEqual(parent, self.INSTANCE_NAME) - self.assertEqual(create_statement, - 'CREATE DATABASE %s' % self.DATABASE_ID) - self.assertEqual(extra_statements, []) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.create_database.assert_called_once_with( + parent=self.INSTANCE_NAME, + create_statement='CREATE DATABASE {}'.format(self.DATABASE_ID), + extra_statements=[], + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_create_already_exists(self): from google.cloud.exceptions import Conflict DATABASE_ID_HYPHEN = 'database-id' client = _Client() - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _create_database_conflict=True) + api = client.database_admin_api = self._make_database_admin_api() + api.create_database.side_effect = Conflict('testing') instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(DATABASE_ID_HYPHEN, instance, pool=pool) @@ -330,45 +357,40 @@ def test_create_already_exists(self): with self.assertRaises(Conflict): database.create() - (parent, create_statement, extra_statements, - metadata) = api._created_database - self.assertEqual(parent, self.INSTANCE_NAME) - self.assertEqual(create_statement, - 'CREATE DATABASE `%s`' % DATABASE_ID_HYPHEN) - self.assertEqual(extra_statements, []) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.create_database.assert_called_once_with( + parent=self.INSTANCE_NAME, + create_statement='CREATE DATABASE `{}`'.format(DATABASE_ID_HYPHEN), + extra_statements=[], + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_create_instance_not_found(self): from google.cloud.exceptions import NotFound - DATABASE_ID_HYPHEN = 'database-id' client = _Client() - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _database_not_found=True) + api = client.database_admin_api = self._make_database_admin_api() + api.create_database.side_effect = NotFound('testing') instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() - database = self._make_one(DATABASE_ID_HYPHEN, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance, pool=pool) with self.assertRaises(NotFound): database.create() - (parent, create_statement, extra_statements, - metadata) = api._created_database - self.assertEqual(parent, self.INSTANCE_NAME) - self.assertEqual(create_statement, - 'CREATE DATABASE `%s`' % DATABASE_ID_HYPHEN) - self.assertEqual(extra_statements, []) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.create_database.assert_called_once_with( + parent=self.INSTANCE_NAME, + create_statement='CREATE DATABASE {}'.format(self.DATABASE_ID), + extra_statements=[], + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_create_success(self): from tests._fixtures import DDL_STATEMENTS - op_future = _FauxOperationFuture() + op_future = object() client = _Client() - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _create_database_response=op_future) + api = client.database_admin_api = self._make_database_admin_api() + api.create_database.return_value = op_future instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one( @@ -379,21 +401,19 @@ def test_create_success(self): self.assertIs(future, op_future) - (parent, create_statement, extra_statements, - metadata) = api._created_database - self.assertEqual(parent, self.INSTANCE_NAME) - self.assertEqual(create_statement, - 'CREATE DATABASE %s' % self.DATABASE_ID) - self.assertEqual(extra_statements, DDL_STATEMENTS) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.create_database.assert_called_once_with( + parent=self.INSTANCE_NAME, + create_statement='CREATE DATABASE {}'.format(self.DATABASE_ID), + extra_statements=DDL_STATEMENTS, + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_exists_grpc_error(self): from google.api_core.exceptions import Unknown client = _Client() - client.database_admin_api = _FauxDatabaseAdminAPI( - _rpc_error=True) + api = client.database_admin_api = self._make_database_admin_api() + api.get_database_ddl.side_effect = Unknown('testing') instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) @@ -401,20 +421,27 @@ def test_exists_grpc_error(self): with self.assertRaises(Unknown): database.exists() + api.get_database_ddl.assert_called_once_with( + self.DATABASE_NAME, + metadata=[('google-cloud-resource-prefix', database.name)], + ) + def test_exists_not_found(self): + from google.cloud.exceptions import NotFound + client = _Client() - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _database_not_found=True) + api = client.database_admin_api = self._make_database_admin_api() + api.get_database_ddl.side_effect = NotFound('testing') instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) self.assertFalse(database.exists()) - name, metadata = api._got_database_ddl - self.assertEqual(name, self.DATABASE_NAME) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.get_database_ddl.assert_called_once_with( + self.DATABASE_NAME, + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_exists_success(self): from google.cloud.spanner_admin_database_v1.proto import ( @@ -424,25 +451,25 @@ def test_exists_success(self): client = _Client() ddl_pb = admin_v1_pb2.GetDatabaseDdlResponse( statements=DDL_STATEMENTS) - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _get_database_ddl_response=ddl_pb) + api = client.database_admin_api = self._make_database_admin_api() + api.get_database_ddl.return_value = ddl_pb instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) self.assertTrue(database.exists()) - name, metadata = api._got_database_ddl - self.assertEqual(name, self.DATABASE_NAME) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.get_database_ddl.assert_called_once_with( + self.DATABASE_NAME, + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_reload_grpc_error(self): from google.api_core.exceptions import Unknown client = _Client() - client.database_admin_api = _FauxDatabaseAdminAPI( - _rpc_error=True) + api = client.database_admin_api = self._make_database_admin_api() + api.get_database_ddl.side_effect = Unknown('testing') instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) @@ -450,12 +477,17 @@ def test_reload_grpc_error(self): with self.assertRaises(Unknown): database.reload() + api.get_database_ddl.assert_called_once_with( + self.DATABASE_NAME, + metadata=[('google-cloud-resource-prefix', database.name)], + ) + def test_reload_not_found(self): from google.cloud.exceptions import NotFound client = _Client() - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _database_not_found=True) + api = client.database_admin_api = self._make_database_admin_api() + api.get_database_ddl.side_effect = NotFound('testing') instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) @@ -463,10 +495,10 @@ def test_reload_not_found(self): with self.assertRaises(NotFound): database.reload() - name, metadata = api._got_database_ddl - self.assertEqual(name, self.DATABASE_NAME) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.get_database_ddl.assert_called_once_with( + self.DATABASE_NAME, + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_reload_success(self): from google.cloud.spanner_admin_database_v1.proto import ( @@ -476,8 +508,8 @@ def test_reload_success(self): client = _Client() ddl_pb = admin_v1_pb2.GetDatabaseDdlResponse( statements=DDL_STATEMENTS) - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _get_database_ddl_response=ddl_pb) + api = client.database_admin_api = self._make_database_admin_api() + api.get_database_ddl.return_value = ddl_pb instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) @@ -486,18 +518,18 @@ def test_reload_success(self): self.assertEqual(database._ddl_statements, tuple(DDL_STATEMENTS)) - name, metadata = api._got_database_ddl - self.assertEqual(name, self.DATABASE_NAME) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.get_database_ddl.assert_called_once_with( + self.DATABASE_NAME, + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_update_ddl_grpc_error(self): from google.api_core.exceptions import Unknown from tests._fixtures import DDL_STATEMENTS client = _Client() - client.database_admin_api = _FauxDatabaseAdminAPI( - _rpc_error=True) + api = client.database_admin_api = self._make_database_admin_api() + api.update_database_ddl.side_effect = Unknown('testing') instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) @@ -505,13 +537,20 @@ def test_update_ddl_grpc_error(self): with self.assertRaises(Unknown): database.update_ddl(DDL_STATEMENTS) + api.update_database_ddl.assert_called_once_with( + self.DATABASE_NAME, + DDL_STATEMENTS, + '', + metadata=[('google-cloud-resource-prefix', database.name)], + ) + def test_update_ddl_not_found(self): from google.cloud.exceptions import NotFound from tests._fixtures import DDL_STATEMENTS client = _Client() - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _database_not_found=True) + api = client.database_admin_api = self._make_database_admin_api() + api.update_database_ddl.side_effect = NotFound('testing') instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) @@ -519,20 +558,20 @@ def test_update_ddl_not_found(self): with self.assertRaises(NotFound): database.update_ddl(DDL_STATEMENTS) - name, statements, op_id, metadata = api._updated_database_ddl - self.assertEqual(name, self.DATABASE_NAME) - self.assertEqual(statements, DDL_STATEMENTS) - self.assertEqual(op_id, '') - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.update_database_ddl.assert_called_once_with( + self.DATABASE_NAME, + DDL_STATEMENTS, + '', + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_update_ddl(self): from tests._fixtures import DDL_STATEMENTS - op_future = _FauxOperationFuture() + op_future = object() client = _Client() - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _update_database_ddl_response=op_future) + api = client.database_admin_api = self._make_database_admin_api() + api.update_database_ddl.return_value = op_future instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) @@ -541,19 +580,19 @@ def test_update_ddl(self): self.assertIs(future, op_future) - name, statements, op_id, metadata = api._updated_database_ddl - self.assertEqual(name, self.DATABASE_NAME) - self.assertEqual(statements, DDL_STATEMENTS) - self.assertEqual(op_id, '') - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.update_database_ddl.assert_called_once_with( + self.DATABASE_NAME, + DDL_STATEMENTS, + '', + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_drop_grpc_error(self): from google.api_core.exceptions import Unknown client = _Client() - client.database_admin_api = _FauxDatabaseAdminAPI( - _rpc_error=True) + api = client.database_admin_api = self._make_database_admin_api() + api.drop_database.side_effect = Unknown('testing') instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) @@ -561,12 +600,17 @@ def test_drop_grpc_error(self): with self.assertRaises(Unknown): database.drop() + api.drop_database.assert_called_once_with( + self.DATABASE_NAME, + metadata=[('google-cloud-resource-prefix', database.name)], + ) + def test_drop_not_found(self): from google.cloud.exceptions import NotFound client = _Client() - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _database_not_found=True) + api = client.database_admin_api = self._make_database_admin_api() + api.drop_database.side_effect = NotFound('testing') instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) @@ -574,27 +618,99 @@ def test_drop_not_found(self): with self.assertRaises(NotFound): database.drop() - name, metadata = api._dropped_database - self.assertEqual(name, self.DATABASE_NAME) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.drop_database.assert_called_once_with( + self.DATABASE_NAME, + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_drop_success(self): from google.protobuf.empty_pb2 import Empty client = _Client() - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _drop_database_response=Empty()) + api = client.database_admin_api = self._make_database_admin_api() + api.drop_database.return_value = Empty() instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) database.drop() - name, metadata = api._dropped_database - self.assertEqual(name, self.DATABASE_NAME) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.drop_database.assert_called_once_with( + self.DATABASE_NAME, + metadata=[('google-cloud-resource-prefix', database.name)], + ) + + def _execute_partitioned_dml_helper( + self, dml, params=None, param_types=None): + from google.protobuf.struct_pb2 import Struct + from google.cloud.spanner_v1.proto.result_set_pb2 import ( + PartialResultSet, ResultSetStats) + from google.cloud.spanner_v1.proto.transaction_pb2 import ( + Transaction as TransactionPB, + TransactionSelector, TransactionOptions) + from google.cloud.spanner_v1._helpers import _make_value_pb + + transaction_pb = TransactionPB(id=self.TRANSACTION_ID) + + stats_pb = ResultSetStats(row_count_lower_bound=2) + result_sets = [ + PartialResultSet(stats=stats_pb), + ] + iterator = _MockIterator(*result_sets) + + client = _Client() + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + session = _Session() + pool.put(session) + database = self._make_one(self.DATABASE_ID, instance, pool=pool) + api = database._spanner_api = self._make_spanner_api() + api.begin_transaction.return_value = transaction_pb + api.execute_streaming_sql.return_value = iterator + + row_count = database.execute_partitioned_dml( + dml, params, param_types, query_mode=MODE) + + self.assertEqual(row_count, 2) + + txn_options = TransactionOptions( + partitioned_dml=TransactionOptions.PartitionedDml()) + + api.begin_transaction.assert_called_once_with( + session.name, + txn_options, + metadata=[('google-cloud-resource-prefix', database.name)], + ) + + if params: + expected_params = Struct(fields={ + key: _make_value_pb(value) for (key, value) in params.items()}) + else: + expected_params = None + + expected_transaction = TransactionSelector(id=self.TRANSACTION_ID) + + api.execute_streaming_sql.assert_called_once_with( + self.SESSION_NAME, + dml, + transaction=expected_transaction, + params=expected_params, + param_types=param_types, + query_mode=MODE, + metadata=[('google-cloud-resource-prefix', database.name)], + ) + + def test_execute_partitioned_dml_wo_params(self): + self._execute_partitioned_dml_helper(dml=DML_WO_PARAM) + + def test_execute_partitioned_dml_w_params_wo_param_types(self): + with self.assertRaises(ValueError): + self._execute_partitioned_dml_helper( + dml=DML_W_PARAM, params=PARAMS) + + def test_execute_partitioned_dml_w_params_and_param_types(self): + self._execute_partitioned_dml_helper( + dml=DML_W_PARAM, params=PARAMS, param_types=PARAM_TYPES) def test_session_factory_defaults(self): from google.cloud.spanner_v1.session import Session @@ -787,6 +903,12 @@ def _get_target_class(self): return BatchCheckout + @staticmethod + def _make_spanner_client(): + from google.cloud.spanner_v1.gapic.spanner_client import SpannerClient + + return mock.create_autospec(SpannerClient) + def test_ctor(self): database = _Database(self.DATABASE_NAME) checkout = self._make_one(database) @@ -805,8 +927,8 @@ def test_context_mgr_success(self): now_pb = _datetime_to_pb_timestamp(now) response = CommitResponse(commit_timestamp=now_pb) database = _Database(self.DATABASE_NAME) - api = database.spanner_api = _FauxSpannerClient() - api._commit_response = response + api = database.spanner_api = self._make_spanner_client() + api.commit.return_value = response pool = database._pool = _Pool() session = _Session(database) pool.put(session) @@ -819,14 +941,15 @@ def test_context_mgr_success(self): self.assertIs(pool._session, session) self.assertEqual(batch.committed, now) - (session_name, mutations, single_use_txn, - metadata) = api._committed - self.assertIs(session_name, self.SESSION_NAME) - self.assertEqual(mutations, []) - self.assertIsInstance(single_use_txn, TransactionOptions) - self.assertTrue(single_use_txn.HasField('read_write')) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + + expected_txn_options = TransactionOptions(read_write={}) + + api.commit.assert_called_once_with( + self.SESSION_NAME, + [], + single_use_transaction=expected_txn_options, + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_context_mgr_failure(self): from google.cloud.spanner_v1.batch import Batch @@ -1433,80 +1556,19 @@ def run_in_transaction(self, func, *args, **kw): return self._committed -class _SessionPB(object): - name = TestDatabase.SESSION_NAME - - -class _FauxOperationFuture(object): - pass - - -class _FauxSpannerClient(object): - - _committed = None - - def __init__(self, **kwargs): - self.__dict__.update(**kwargs) - - def commit(self, session, mutations, - transaction_id='', single_use_transaction=None, metadata=None): - assert transaction_id == '' - self._committed = ( - session, mutations, single_use_transaction, metadata) - return self._commit_response - - -class _FauxDatabaseAdminAPI(object): - - _create_database_conflict = False - _database_not_found = False - _rpc_error = False - - def __init__(self, **kwargs): - self.__dict__.update(**kwargs) - - def create_database(self, parent, create_statement, extra_statements=None, - metadata=None): - from google.api_core.exceptions import AlreadyExists, NotFound, Unknown - - self._created_database = ( - parent, create_statement, extra_statements, metadata) - if self._rpc_error: - raise Unknown('error') - if self._create_database_conflict: - raise AlreadyExists('conflict') - if self._database_not_found: - raise NotFound('not found') - return self._create_database_response - - def get_database_ddl(self, database, metadata=None): - from google.api_core.exceptions import NotFound, Unknown - - self._got_database_ddl = database, metadata - if self._rpc_error: - raise Unknown('error') - if self._database_not_found: - raise NotFound('not found') - return self._get_database_ddl_response +class _MockIterator(object): - def drop_database(self, database, metadata=None): - from google.api_core.exceptions import NotFound, Unknown + def __init__(self, *values, **kw): + self._iter_values = iter(values) + self._fail_after = kw.pop('fail_after', False) - self._dropped_database = database, metadata - if self._rpc_error: - raise Unknown('error') - if self._database_not_found: - raise NotFound('not found') - return self._drop_database_response + def __iter__(self): + return self - def update_database_ddl(self, database, statements, operation_id, - metadata=None): - from google.api_core.exceptions import NotFound, Unknown + def __next__(self): + try: + return next(self._iter_values) + except StopIteration: + raise - self._updated_database_ddl = ( - database, statements, operation_id, metadata) - if self._rpc_error: - raise Unknown('error') - if self._database_not_found: - raise NotFound('not found') - return self._update_database_ddl_response + next = __next__ From 7660f40c84fb513b41afb59b7f9b8b5fb5459344 Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Wed, 5 Sep 2018 11:04:32 -0400 Subject: [PATCH 11/12] Review comment: s/SQL DML/DML/. --- spanner/google/cloud/spanner_v1/database.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spanner/google/cloud/spanner_v1/database.py b/spanner/google/cloud/spanner_v1/database.py index 93960947de20..d3411cd0ba84 100644 --- a/spanner/google/cloud/spanner_v1/database.py +++ b/spanner/google/cloud/spanner_v1/database.py @@ -284,7 +284,7 @@ def execute_partitioned_dml( """Execute a partitionable DML statement. :type dml: str - :param dml: SQL DML statement + :param dml: DML statement :type params: dict, {str -> column value} :param params: values for parameter replacement. Keys must match From d8a714dab965984928f57f58270c9d4b5ee7cfd4 Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Wed, 5 Sep 2018 11:07:11 -0400 Subject: [PATCH 12/12] Review comment: PDML does not support query mode. --- spanner/google/cloud/spanner_v1/database.py | 8 +------- spanner/tests/unit/test_database.py | 3 +-- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/spanner/google/cloud/spanner_v1/database.py b/spanner/google/cloud/spanner_v1/database.py index d3411cd0ba84..6fb367d3ab87 100644 --- a/spanner/google/cloud/spanner_v1/database.py +++ b/spanner/google/cloud/spanner_v1/database.py @@ -280,7 +280,7 @@ def drop(self): api.drop_database(self.name, metadata=metadata) def execute_partitioned_dml( - self, dml, params=None, param_types=None, query_mode=None): + self, dml, params=None, param_types=None): """Execute a partitionable DML statement. :type dml: str @@ -295,11 +295,6 @@ def execute_partitioned_dml( (Optional) maps explicit types for one or more param values; required if parameters are passed. - :type query_mode: - :class:`google.cloud.spanner_v1.proto.ExecuteSqlRequest.QueryMode` - :param query_mode: Mode governing return of results / query plan. See - https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.ExecuteSqlRequest.QueryMode1 - :rtype: int :returns: Count of rows affected by the DML statement. """ @@ -333,7 +328,6 @@ def execute_partitioned_dml( transaction=txn_selector, params=params_pb, param_types=param_types, - query_mode=query_mode, metadata=metadata) iterator = _restart_on_unavailable(restart) diff --git a/spanner/tests/unit/test_database.py b/spanner/tests/unit/test_database.py index c17251647511..afc358ffc509 100644 --- a/spanner/tests/unit/test_database.py +++ b/spanner/tests/unit/test_database.py @@ -669,7 +669,7 @@ def _execute_partitioned_dml_helper( api.execute_streaming_sql.return_value = iterator row_count = database.execute_partitioned_dml( - dml, params, param_types, query_mode=MODE) + dml, params, param_types) self.assertEqual(row_count, 2) @@ -696,7 +696,6 @@ def _execute_partitioned_dml_helper( transaction=expected_transaction, params=expected_params, param_types=param_types, - query_mode=MODE, metadata=[('google-cloud-resource-prefix', database.name)], )