diff --git a/spanner/google/cloud/spanner_v1/database.py b/spanner/google/cloud/spanner_v1/database.py index d3494eb63902..6fb367d3ab87 100644 --- a/spanner/google/cloud/spanner_v1/database.py +++ b/spanner/google/cloud/spanner_v1/database.py @@ -14,17 +14,20 @@ """User friendly container for Cloud Spanner Database.""" +import copy +import functools import re import threading -import copy from google.api_core.gapic_v1 import client_info import google.auth.credentials +from google.protobuf.struct_pb2 import Struct from google.cloud.exceptions import NotFound import six # pylint: disable=ungrouped-imports from google.cloud.spanner_v1 import __version__ +from google.cloud.spanner_v1._helpers import _make_value_pb from google.cloud.spanner_v1._helpers import _metadata_with_prefix from google.cloud.spanner_v1.batch import Batch from google.cloud.spanner_v1.gapic.spanner_client import SpannerClient @@ -32,7 +35,11 @@ from google.cloud.spanner_v1.pool import BurstyPool from google.cloud.spanner_v1.pool import SessionCheckout from google.cloud.spanner_v1.session import Session +from google.cloud.spanner_v1.snapshot import _restart_on_unavailable from google.cloud.spanner_v1.snapshot import Snapshot +from google.cloud.spanner_v1.streamed import StreamedResultSet +from google.cloud.spanner_v1.proto.transaction_pb2 import ( + TransactionSelector, TransactionOptions) # pylint: enable=ungrouped-imports @@ -272,6 +279,64 @@ def drop(self): metadata = _metadata_with_prefix(self.name) api.drop_database(self.name, metadata=metadata) + def execute_partitioned_dml( + self, dml, params=None, param_types=None): + """Execute a partitionable DML statement. + + :type dml: str + :param dml: DML statement + + :type params: dict, {str -> column value} + :param params: values for parameter replacement. Keys must match + the names used in ``dml``. + + :type param_types: dict[str -> Union[dict, .types.Type]] + :param param_types: + (Optional) maps explicit types for one or more param values; + required if parameters are passed. + + :rtype: int + :returns: Count of rows affected by the DML statement. + """ + if params is not None: + if param_types is None: + raise ValueError( + "Specify 'param_types' when passing 'params'.") + params_pb = Struct(fields={ + key: _make_value_pb(value) for key, value in params.items()}) + else: + params_pb = None + + api = self.spanner_api + + txn_options = TransactionOptions( + partitioned_dml=TransactionOptions.PartitionedDml()) + + metadata = _metadata_with_prefix(self.name) + + with SessionCheckout(self._pool) as session: + + txn = api.begin_transaction( + session.name, txn_options, metadata=metadata) + + txn_selector = TransactionSelector(id=txn.id) + + restart = functools.partial( + api.execute_streaming_sql, + session.name, + dml, + transaction=txn_selector, + params=params_pb, + param_types=param_types, + metadata=metadata) + + iterator = _restart_on_unavailable(restart) + + result_set = StreamedResultSet(iterator) + list(result_set) # consume all partials + + return result_set.stats.row_count_lower_bound + def session(self, labels=None): """Factory to create a session for this database. diff --git a/spanner/google/cloud/spanner_v1/snapshot.py b/spanner/google/cloud/spanner_v1/snapshot.py index 827da34ee7c4..00d45410f499 100644 --- a/spanner/google/cloud/spanner_v1/snapshot.py +++ b/spanner/google/cloud/spanner_v1/snapshot.py @@ -71,6 +71,7 @@ class _SnapshotBase(_SessionWrapper): _multi_use = False _transaction_id = None _read_request_count = 0 + _execute_sql_count = 0 def _make_txn_selector(self): # pylint: disable=redundant-returns-doc """Helper for :meth:`read` / :meth:`execute_sql`. @@ -195,14 +196,20 @@ def execute_sql(self, sql, params=None, param_types=None, restart = functools.partial( api.execute_streaming_sql, - self._session.name, sql, - transaction=transaction, params=params_pb, param_types=param_types, - query_mode=query_mode, partition_token=partition, + self._session.name, + sql, + transaction=transaction, + params=params_pb, + param_types=param_types, + query_mode=query_mode, + partition_token=partition, + seqno=self._execute_sql_count, metadata=metadata) iterator = _restart_on_unavailable(restart) self._read_request_count += 1 + self._execute_sql_count += 1 if self._multi_use: return StreamedResultSet(iterator, source=self) diff --git a/spanner/google/cloud/spanner_v1/transaction.py b/spanner/google/cloud/spanner_v1/transaction.py index 9f2f6d99895e..cc2f06cee54d 100644 --- a/spanner/google/cloud/spanner_v1/transaction.py +++ b/spanner/google/cloud/spanner_v1/transaction.py @@ -14,11 +14,13 @@ """Spanner read-write transaction support.""" -from google.cloud.spanner_v1.proto.transaction_pb2 import TransactionSelector -from google.cloud.spanner_v1.proto.transaction_pb2 import TransactionOptions +from google.protobuf.struct_pb2 import Struct from google.cloud._helpers import _pb_timestamp_to_datetime +from google.cloud.spanner_v1._helpers import _make_value_pb from google.cloud.spanner_v1._helpers import _metadata_with_prefix +from google.cloud.spanner_v1.proto.transaction_pb2 import TransactionSelector +from google.cloud.spanner_v1.proto.transaction_pb2 import TransactionOptions from google.cloud.spanner_v1.snapshot import _SnapshotBase from google.cloud.spanner_v1.batch import _BatchBase @@ -35,6 +37,7 @@ class Transaction(_SnapshotBase, _BatchBase): """Timestamp at which the transaction was successfully committed.""" _rolled_back = False _multi_use = True + _execute_sql_count = 0 def __init__(self, session): if session._transaction is not None: @@ -114,9 +117,6 @@ def commit(self): """ self._check_state() - if not self._mutations: - raise ValueError("No mutations to commit") - database = self._session._database api = database.spanner_api metadata = _metadata_with_prefix(database.name) @@ -128,6 +128,58 @@ def commit(self): del self._session._transaction return self.committed + def execute_update(self, dml, params=None, param_types=None, + query_mode=None): + """Perform an ``ExecuteSql`` API request with DML. + + :type dml: str + :param dml: SQL DML statement + + :type params: dict, {str -> column value} + :param params: values for parameter replacement. Keys must match + the names used in ``dml``. + + :type param_types: dict[str -> Union[dict, .types.Type]] + :param param_types: + (Optional) maps explicit types for one or more param values; + required if parameters are passed. + + :type query_mode: + :class:`google.cloud.spanner_v1.proto.ExecuteSqlRequest.QueryMode` + :param query_mode: Mode governing return of results / query plan. See + https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.ExecuteSqlRequest.QueryMode1 + + :rtype: int + :returns: Count of rows affected by the DML statement. + """ + if params is not None: + if param_types is None: + raise ValueError( + "Specify 'param_types' when passing 'params'.") + params_pb = Struct(fields={ + key: _make_value_pb(value) for key, value in params.items()}) + else: + params_pb = None + + database = self._session._database + metadata = _metadata_with_prefix(database.name) + transaction = self._make_txn_selector() + api = database.spanner_api + + response = api.execute_sql( + self._session.name, + dml, + transaction=transaction, + params=params_pb, + param_types=param_types, + query_mode=query_mode, + seqno=self._execute_sql_count, + metadata=metadata, + ) + + self._execute_sql_count += 1 + return response.stats.row_count_exact + def __enter__(self): """Begin ``with`` block.""" self.begin() diff --git a/spanner/tests/system/test_system.py b/spanner/tests/system/test_system.py index 2d85a99531b6..228cd7849fa0 100644 --- a/spanner/tests/system/test_system.py +++ b/spanner/tests/system/test_system.py @@ -627,6 +627,165 @@ def test_transaction_read_and_insert_or_update_then_commit(self): rows = list(session.read(self.TABLE, self.COLUMNS, self.ALL)) self._check_rows_data(rows) + def _generate_insert_statements(self): + insert_template = ( + 'INSERT INTO {table} ({column_list}) ' + 'VALUES ({row_data})' + ) + for row in self.ROW_DATA: + yield insert_template.format( + table=self.TABLE, + column_list=', '.join(self.COLUMNS), + row_data='{}, "{}", "{}", "{}"'.format(*row) + ) + + @RetryErrors(exception=exceptions.ServerError) + @RetryErrors(exception=exceptions.Conflict) + def test_transaction_execute_sql_w_dml_read_rollback(self): + retry = RetryInstanceState(_has_all_ddl) + retry(self._db.reload)() + + session = self._db.session() + session.create() + self.to_delete.append(session) + + with session.batch() as batch: + batch.delete(self.TABLE, self.ALL) + + transaction = session.transaction() + transaction.begin() + + rows = list( + transaction.read(self.TABLE, self.COLUMNS, self.ALL)) + self.assertEqual(rows, []) + + for insert_statement in self._generate_insert_statements(): + result = transaction.execute_sql(insert_statement) + list(result) # iterate to get stats + self.assertEqual(result.stats.row_count_exact, 1) + + # Rows inserted via DML *can* be read before commit. + during_rows = list( + transaction.read(self.TABLE, self.COLUMNS, self.ALL)) + self._check_rows_data(during_rows) + + transaction.rollback() + + rows = list(session.read(self.TABLE, self.COLUMNS, self.ALL)) + self._check_rows_data(rows, []) + + @RetryErrors(exception=exceptions.ServerError) + @RetryErrors(exception=exceptions.Conflict) + def test_transaction_execute_update_read_commit(self): + retry = RetryInstanceState(_has_all_ddl) + retry(self._db.reload)() + + session = self._db.session() + session.create() + self.to_delete.append(session) + + with session.batch() as batch: + batch.delete(self.TABLE, self.ALL) + + with session.transaction() as transaction: + rows = list(transaction.read(self.TABLE, self.COLUMNS, self.ALL)) + self.assertEqual(rows, []) + + for insert_statement in self._generate_insert_statements(): + row_count = transaction.execute_update(insert_statement) + self.assertEqual(row_count, 1) + + # Rows inserted via DML *can* be read before commit. + during_rows = list( + transaction.read(self.TABLE, self.COLUMNS, self.ALL)) + self._check_rows_data(during_rows) + + rows = list(session.read(self.TABLE, self.COLUMNS, self.ALL)) + self._check_rows_data(rows) + + @RetryErrors(exception=exceptions.ServerError) + @RetryErrors(exception=exceptions.Conflict) + def test_transaction_execute_update_then_insert_commit(self): + retry = RetryInstanceState(_has_all_ddl) + retry(self._db.reload)() + + session = self._db.session() + session.create() + self.to_delete.append(session) + + with session.batch() as batch: + batch.delete(self.TABLE, self.ALL) + + insert_statement = list(self._generate_insert_statements())[0] + + with session.transaction() as transaction: + rows = list(transaction.read(self.TABLE, self.COLUMNS, self.ALL)) + self.assertEqual(rows, []) + + row_count = transaction.execute_update(insert_statement) + self.assertEqual(row_count, 1) + + transaction.insert(self.TABLE, self.COLUMNS, self.ROW_DATA[1:]) + + rows = list(session.read(self.TABLE, self.COLUMNS, self.ALL)) + self._check_rows_data(rows) + + def test_execute_partitioned_dml(self): + retry = RetryInstanceState(_has_all_ddl) + retry(self._db.reload)() + + delete_statement = 'DELETE FROM {} WHERE true'.format(self.TABLE) + + def _setup_table(txn): + txn.execute_update(delete_statement) + for insert_statement in self._generate_insert_statements(): + txn.execute_update(insert_statement) + + committed = self._db.run_in_transaction(_setup_table) + + with self._db.snapshot(read_timestamp=committed) as snapshot: + before_pdml = list(snapshot.read( + self.TABLE, self.COLUMNS, self.ALL)) + + self._check_rows_data(before_pdml) + + nonesuch = 'nonesuch@example.com' + target = 'phred@example.com' + update_statement = ( + 'UPDATE {table} SET {table}.email = @email ' + 'WHERE {table}.email = @target').format( + table=self.TABLE) + + row_count = self._db.execute_partitioned_dml( + update_statement, + params={ + 'email': nonesuch, + 'target': target, + }, + param_types={ + 'email': Type(code=STRING), + 'target': Type(code=STRING), + }, + ) + self.assertEqual(row_count, 1) + + row = self.ROW_DATA[0] + updated = [row[:3] + (nonesuch,)] + list(self.ROW_DATA[1:]) + + with self._db.snapshot(read_timestamp=committed) as snapshot: + after_update = list(snapshot.read( + self.TABLE, self.COLUMNS, self.ALL)) + self._check_rows_data(after_update, updated) + + row_count = self._db.execute_partitioned_dml(delete_statement) + self.assertEqual(row_count, len(self.ROW_DATA)) + + with self._db.snapshot(read_timestamp=committed) as snapshot: + after_delete = list(snapshot.read( + self.TABLE, self.COLUMNS, self.ALL)) + + self._check_rows_data(after_delete, []) + def _transaction_concurrency_helper(self, unit_of_work, pkey): INITIAL_VALUE = 123 NUM_THREADS = 3 # conforms to equivalent Java systest. diff --git a/spanner/tests/unit/test_database.py b/spanner/tests/unit/test_database.py index 34b30deb2022..afc358ffc509 100644 --- a/spanner/tests/unit/test_database.py +++ b/spanner/tests/unit/test_database.py @@ -18,6 +18,19 @@ import mock +DML_WO_PARAM = """ +DELETE FROM citizens +""" + +DML_W_PARAM = """ +INSERT INTO citizens(first_name, last_name, age) +VALUES ("Phred", "Phlyntstone", @age) +""" +PARAMS = {'age': 30} +PARAM_TYPES = {'age': 'INT64'} +MODE = 2 # PROFILE + + def _make_credentials(): # pragma: NO COVER import google.auth.credentials @@ -39,7 +52,7 @@ class _BaseTest(unittest.TestCase): DATABASE_NAME = INSTANCE_NAME + '/databases/' + DATABASE_ID SESSION_ID = 'session_id' SESSION_NAME = DATABASE_NAME + '/sessions/' + SESSION_ID - TRANSACTION_ID = 'transaction_id' + TRANSACTION_ID = b'transaction_id' def _make_one(self, *args, **kwargs): return self._get_target_class()(*args, **kwargs) @@ -65,6 +78,20 @@ def _get_target_class(self): return Database + @staticmethod + def _make_database_admin_api(): + from google.cloud.spanner_v1.client import DatabaseAdminClient + + return mock.create_autospec(DatabaseAdminClient, instance=True) + + @staticmethod + def _make_spanner_api(): + import google.cloud.spanner_v1.gapic.spanner_client + + return mock.create_autospec( + google.cloud.spanner_v1.gapic.spanner_client.SpannerClient, + instance=True) + def test_ctor_defaults(self): from google.cloud.spanner_v1.pool import BurstyPool @@ -296,10 +323,12 @@ def test___ne__(self): def test_create_grpc_error(self): from google.api_core.exceptions import GoogleAPICallError + from google.api_core.exceptions import Unknown client = _Client() - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _rpc_error=True) + api = client.database_admin_api = self._make_database_admin_api() + api.create_database.side_effect = Unknown('testing') + instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) @@ -307,22 +336,20 @@ def test_create_grpc_error(self): with self.assertRaises(GoogleAPICallError): database.create() - (parent, create_statement, extra_statements, - metadata) = api._created_database - self.assertEqual(parent, self.INSTANCE_NAME) - self.assertEqual(create_statement, - 'CREATE DATABASE %s' % self.DATABASE_ID) - self.assertEqual(extra_statements, []) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.create_database.assert_called_once_with( + parent=self.INSTANCE_NAME, + create_statement='CREATE DATABASE {}'.format(self.DATABASE_ID), + extra_statements=[], + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_create_already_exists(self): from google.cloud.exceptions import Conflict DATABASE_ID_HYPHEN = 'database-id' client = _Client() - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _create_database_conflict=True) + api = client.database_admin_api = self._make_database_admin_api() + api.create_database.side_effect = Conflict('testing') instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(DATABASE_ID_HYPHEN, instance, pool=pool) @@ -330,45 +357,40 @@ def test_create_already_exists(self): with self.assertRaises(Conflict): database.create() - (parent, create_statement, extra_statements, - metadata) = api._created_database - self.assertEqual(parent, self.INSTANCE_NAME) - self.assertEqual(create_statement, - 'CREATE DATABASE `%s`' % DATABASE_ID_HYPHEN) - self.assertEqual(extra_statements, []) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.create_database.assert_called_once_with( + parent=self.INSTANCE_NAME, + create_statement='CREATE DATABASE `{}`'.format(DATABASE_ID_HYPHEN), + extra_statements=[], + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_create_instance_not_found(self): from google.cloud.exceptions import NotFound - DATABASE_ID_HYPHEN = 'database-id' client = _Client() - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _database_not_found=True) + api = client.database_admin_api = self._make_database_admin_api() + api.create_database.side_effect = NotFound('testing') instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() - database = self._make_one(DATABASE_ID_HYPHEN, instance, pool=pool) + database = self._make_one(self.DATABASE_ID, instance, pool=pool) with self.assertRaises(NotFound): database.create() - (parent, create_statement, extra_statements, - metadata) = api._created_database - self.assertEqual(parent, self.INSTANCE_NAME) - self.assertEqual(create_statement, - 'CREATE DATABASE `%s`' % DATABASE_ID_HYPHEN) - self.assertEqual(extra_statements, []) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.create_database.assert_called_once_with( + parent=self.INSTANCE_NAME, + create_statement='CREATE DATABASE {}'.format(self.DATABASE_ID), + extra_statements=[], + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_create_success(self): from tests._fixtures import DDL_STATEMENTS - op_future = _FauxOperationFuture() + op_future = object() client = _Client() - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _create_database_response=op_future) + api = client.database_admin_api = self._make_database_admin_api() + api.create_database.return_value = op_future instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one( @@ -379,21 +401,19 @@ def test_create_success(self): self.assertIs(future, op_future) - (parent, create_statement, extra_statements, - metadata) = api._created_database - self.assertEqual(parent, self.INSTANCE_NAME) - self.assertEqual(create_statement, - 'CREATE DATABASE %s' % self.DATABASE_ID) - self.assertEqual(extra_statements, DDL_STATEMENTS) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.create_database.assert_called_once_with( + parent=self.INSTANCE_NAME, + create_statement='CREATE DATABASE {}'.format(self.DATABASE_ID), + extra_statements=DDL_STATEMENTS, + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_exists_grpc_error(self): from google.api_core.exceptions import Unknown client = _Client() - client.database_admin_api = _FauxDatabaseAdminAPI( - _rpc_error=True) + api = client.database_admin_api = self._make_database_admin_api() + api.get_database_ddl.side_effect = Unknown('testing') instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) @@ -401,20 +421,27 @@ def test_exists_grpc_error(self): with self.assertRaises(Unknown): database.exists() + api.get_database_ddl.assert_called_once_with( + self.DATABASE_NAME, + metadata=[('google-cloud-resource-prefix', database.name)], + ) + def test_exists_not_found(self): + from google.cloud.exceptions import NotFound + client = _Client() - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _database_not_found=True) + api = client.database_admin_api = self._make_database_admin_api() + api.get_database_ddl.side_effect = NotFound('testing') instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) self.assertFalse(database.exists()) - name, metadata = api._got_database_ddl - self.assertEqual(name, self.DATABASE_NAME) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.get_database_ddl.assert_called_once_with( + self.DATABASE_NAME, + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_exists_success(self): from google.cloud.spanner_admin_database_v1.proto import ( @@ -424,25 +451,25 @@ def test_exists_success(self): client = _Client() ddl_pb = admin_v1_pb2.GetDatabaseDdlResponse( statements=DDL_STATEMENTS) - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _get_database_ddl_response=ddl_pb) + api = client.database_admin_api = self._make_database_admin_api() + api.get_database_ddl.return_value = ddl_pb instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) self.assertTrue(database.exists()) - name, metadata = api._got_database_ddl - self.assertEqual(name, self.DATABASE_NAME) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.get_database_ddl.assert_called_once_with( + self.DATABASE_NAME, + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_reload_grpc_error(self): from google.api_core.exceptions import Unknown client = _Client() - client.database_admin_api = _FauxDatabaseAdminAPI( - _rpc_error=True) + api = client.database_admin_api = self._make_database_admin_api() + api.get_database_ddl.side_effect = Unknown('testing') instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) @@ -450,12 +477,17 @@ def test_reload_grpc_error(self): with self.assertRaises(Unknown): database.reload() + api.get_database_ddl.assert_called_once_with( + self.DATABASE_NAME, + metadata=[('google-cloud-resource-prefix', database.name)], + ) + def test_reload_not_found(self): from google.cloud.exceptions import NotFound client = _Client() - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _database_not_found=True) + api = client.database_admin_api = self._make_database_admin_api() + api.get_database_ddl.side_effect = NotFound('testing') instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) @@ -463,10 +495,10 @@ def test_reload_not_found(self): with self.assertRaises(NotFound): database.reload() - name, metadata = api._got_database_ddl - self.assertEqual(name, self.DATABASE_NAME) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.get_database_ddl.assert_called_once_with( + self.DATABASE_NAME, + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_reload_success(self): from google.cloud.spanner_admin_database_v1.proto import ( @@ -476,8 +508,8 @@ def test_reload_success(self): client = _Client() ddl_pb = admin_v1_pb2.GetDatabaseDdlResponse( statements=DDL_STATEMENTS) - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _get_database_ddl_response=ddl_pb) + api = client.database_admin_api = self._make_database_admin_api() + api.get_database_ddl.return_value = ddl_pb instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) @@ -486,18 +518,18 @@ def test_reload_success(self): self.assertEqual(database._ddl_statements, tuple(DDL_STATEMENTS)) - name, metadata = api._got_database_ddl - self.assertEqual(name, self.DATABASE_NAME) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.get_database_ddl.assert_called_once_with( + self.DATABASE_NAME, + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_update_ddl_grpc_error(self): from google.api_core.exceptions import Unknown from tests._fixtures import DDL_STATEMENTS client = _Client() - client.database_admin_api = _FauxDatabaseAdminAPI( - _rpc_error=True) + api = client.database_admin_api = self._make_database_admin_api() + api.update_database_ddl.side_effect = Unknown('testing') instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) @@ -505,13 +537,20 @@ def test_update_ddl_grpc_error(self): with self.assertRaises(Unknown): database.update_ddl(DDL_STATEMENTS) + api.update_database_ddl.assert_called_once_with( + self.DATABASE_NAME, + DDL_STATEMENTS, + '', + metadata=[('google-cloud-resource-prefix', database.name)], + ) + def test_update_ddl_not_found(self): from google.cloud.exceptions import NotFound from tests._fixtures import DDL_STATEMENTS client = _Client() - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _database_not_found=True) + api = client.database_admin_api = self._make_database_admin_api() + api.update_database_ddl.side_effect = NotFound('testing') instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) @@ -519,20 +558,20 @@ def test_update_ddl_not_found(self): with self.assertRaises(NotFound): database.update_ddl(DDL_STATEMENTS) - name, statements, op_id, metadata = api._updated_database_ddl - self.assertEqual(name, self.DATABASE_NAME) - self.assertEqual(statements, DDL_STATEMENTS) - self.assertEqual(op_id, '') - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.update_database_ddl.assert_called_once_with( + self.DATABASE_NAME, + DDL_STATEMENTS, + '', + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_update_ddl(self): from tests._fixtures import DDL_STATEMENTS - op_future = _FauxOperationFuture() + op_future = object() client = _Client() - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _update_database_ddl_response=op_future) + api = client.database_admin_api = self._make_database_admin_api() + api.update_database_ddl.return_value = op_future instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) @@ -541,19 +580,19 @@ def test_update_ddl(self): self.assertIs(future, op_future) - name, statements, op_id, metadata = api._updated_database_ddl - self.assertEqual(name, self.DATABASE_NAME) - self.assertEqual(statements, DDL_STATEMENTS) - self.assertEqual(op_id, '') - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.update_database_ddl.assert_called_once_with( + self.DATABASE_NAME, + DDL_STATEMENTS, + '', + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_drop_grpc_error(self): from google.api_core.exceptions import Unknown client = _Client() - client.database_admin_api = _FauxDatabaseAdminAPI( - _rpc_error=True) + api = client.database_admin_api = self._make_database_admin_api() + api.drop_database.side_effect = Unknown('testing') instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) @@ -561,12 +600,17 @@ def test_drop_grpc_error(self): with self.assertRaises(Unknown): database.drop() + api.drop_database.assert_called_once_with( + self.DATABASE_NAME, + metadata=[('google-cloud-resource-prefix', database.name)], + ) + def test_drop_not_found(self): from google.cloud.exceptions import NotFound client = _Client() - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _database_not_found=True) + api = client.database_admin_api = self._make_database_admin_api() + api.drop_database.side_effect = NotFound('testing') instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) @@ -574,27 +618,98 @@ def test_drop_not_found(self): with self.assertRaises(NotFound): database.drop() - name, metadata = api._dropped_database - self.assertEqual(name, self.DATABASE_NAME) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.drop_database.assert_called_once_with( + self.DATABASE_NAME, + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_drop_success(self): from google.protobuf.empty_pb2 import Empty client = _Client() - api = client.database_admin_api = _FauxDatabaseAdminAPI( - _drop_database_response=Empty()) + api = client.database_admin_api = self._make_database_admin_api() + api.drop_database.return_value = Empty() instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) database.drop() - name, metadata = api._dropped_database - self.assertEqual(name, self.DATABASE_NAME) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + api.drop_database.assert_called_once_with( + self.DATABASE_NAME, + metadata=[('google-cloud-resource-prefix', database.name)], + ) + + def _execute_partitioned_dml_helper( + self, dml, params=None, param_types=None): + from google.protobuf.struct_pb2 import Struct + from google.cloud.spanner_v1.proto.result_set_pb2 import ( + PartialResultSet, ResultSetStats) + from google.cloud.spanner_v1.proto.transaction_pb2 import ( + Transaction as TransactionPB, + TransactionSelector, TransactionOptions) + from google.cloud.spanner_v1._helpers import _make_value_pb + + transaction_pb = TransactionPB(id=self.TRANSACTION_ID) + + stats_pb = ResultSetStats(row_count_lower_bound=2) + result_sets = [ + PartialResultSet(stats=stats_pb), + ] + iterator = _MockIterator(*result_sets) + + client = _Client() + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + session = _Session() + pool.put(session) + database = self._make_one(self.DATABASE_ID, instance, pool=pool) + api = database._spanner_api = self._make_spanner_api() + api.begin_transaction.return_value = transaction_pb + api.execute_streaming_sql.return_value = iterator + + row_count = database.execute_partitioned_dml( + dml, params, param_types) + + self.assertEqual(row_count, 2) + + txn_options = TransactionOptions( + partitioned_dml=TransactionOptions.PartitionedDml()) + + api.begin_transaction.assert_called_once_with( + session.name, + txn_options, + metadata=[('google-cloud-resource-prefix', database.name)], + ) + + if params: + expected_params = Struct(fields={ + key: _make_value_pb(value) for (key, value) in params.items()}) + else: + expected_params = None + + expected_transaction = TransactionSelector(id=self.TRANSACTION_ID) + + api.execute_streaming_sql.assert_called_once_with( + self.SESSION_NAME, + dml, + transaction=expected_transaction, + params=expected_params, + param_types=param_types, + metadata=[('google-cloud-resource-prefix', database.name)], + ) + + def test_execute_partitioned_dml_wo_params(self): + self._execute_partitioned_dml_helper(dml=DML_WO_PARAM) + + def test_execute_partitioned_dml_w_params_wo_param_types(self): + with self.assertRaises(ValueError): + self._execute_partitioned_dml_helper( + dml=DML_W_PARAM, params=PARAMS) + + def test_execute_partitioned_dml_w_params_and_param_types(self): + self._execute_partitioned_dml_helper( + dml=DML_W_PARAM, params=PARAMS, param_types=PARAM_TYPES) def test_session_factory_defaults(self): from google.cloud.spanner_v1.session import Session @@ -787,6 +902,12 @@ def _get_target_class(self): return BatchCheckout + @staticmethod + def _make_spanner_client(): + from google.cloud.spanner_v1.gapic.spanner_client import SpannerClient + + return mock.create_autospec(SpannerClient) + def test_ctor(self): database = _Database(self.DATABASE_NAME) checkout = self._make_one(database) @@ -805,8 +926,8 @@ def test_context_mgr_success(self): now_pb = _datetime_to_pb_timestamp(now) response = CommitResponse(commit_timestamp=now_pb) database = _Database(self.DATABASE_NAME) - api = database.spanner_api = _FauxSpannerClient() - api._commit_response = response + api = database.spanner_api = self._make_spanner_client() + api.commit.return_value = response pool = database._pool = _Pool() session = _Session(database) pool.put(session) @@ -819,14 +940,15 @@ def test_context_mgr_success(self): self.assertIs(pool._session, session) self.assertEqual(batch.committed, now) - (session_name, mutations, single_use_txn, - metadata) = api._committed - self.assertIs(session_name, self.SESSION_NAME) - self.assertEqual(mutations, []) - self.assertIsInstance(single_use_txn, TransactionOptions) - self.assertTrue(single_use_txn.HasField('read_write')) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + + expected_txn_options = TransactionOptions(read_write={}) + + api.commit.assert_called_once_with( + self.SESSION_NAME, + [], + single_use_transaction=expected_txn_options, + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_context_mgr_failure(self): from google.cloud.spanner_v1.batch import Batch @@ -1433,80 +1555,19 @@ def run_in_transaction(self, func, *args, **kw): return self._committed -class _SessionPB(object): - name = TestDatabase.SESSION_NAME - - -class _FauxOperationFuture(object): - pass - - -class _FauxSpannerClient(object): - - _committed = None - - def __init__(self, **kwargs): - self.__dict__.update(**kwargs) - - def commit(self, session, mutations, - transaction_id='', single_use_transaction=None, metadata=None): - assert transaction_id == '' - self._committed = ( - session, mutations, single_use_transaction, metadata) - return self._commit_response - - -class _FauxDatabaseAdminAPI(object): - - _create_database_conflict = False - _database_not_found = False - _rpc_error = False - - def __init__(self, **kwargs): - self.__dict__.update(**kwargs) - - def create_database(self, parent, create_statement, extra_statements=None, - metadata=None): - from google.api_core.exceptions import AlreadyExists, NotFound, Unknown - - self._created_database = ( - parent, create_statement, extra_statements, metadata) - if self._rpc_error: - raise Unknown('error') - if self._create_database_conflict: - raise AlreadyExists('conflict') - if self._database_not_found: - raise NotFound('not found') - return self._create_database_response - - def get_database_ddl(self, database, metadata=None): - from google.api_core.exceptions import NotFound, Unknown - - self._got_database_ddl = database, metadata - if self._rpc_error: - raise Unknown('error') - if self._database_not_found: - raise NotFound('not found') - return self._get_database_ddl_response +class _MockIterator(object): - def drop_database(self, database, metadata=None): - from google.api_core.exceptions import NotFound, Unknown + def __init__(self, *values, **kw): + self._iter_values = iter(values) + self._fail_after = kw.pop('fail_after', False) - self._dropped_database = database, metadata - if self._rpc_error: - raise Unknown('error') - if self._database_not_found: - raise NotFound('not found') - return self._drop_database_response + def __iter__(self): + return self - def update_database_ddl(self, database, statements, operation_id, - metadata=None): - from google.api_core.exceptions import NotFound, Unknown + def __next__(self): + try: + return next(self._iter_values) + except StopIteration: + raise - self._updated_database_ddl = ( - database, statements, operation_id, metadata) - if self._rpc_error: - raise Unknown('error') - if self._database_not_found: - raise NotFound('not found') - return self._update_database_ddl_response + next = __next__ diff --git a/spanner/tests/unit/test_snapshot.py b/spanner/tests/unit/test_snapshot.py index 2b5961b75f74..21cb6cbe35df 100644 --- a/spanner/tests/unit/test_snapshot.py +++ b/spanner/tests/unit/test_snapshot.py @@ -31,6 +31,8 @@ PARAMS_WITH_BYTES = {'bytes': b'FACEDACE'} RESUME_TOKEN = b'DEADBEEF' TXN_ID = b'DEAFBEAD' +SECONDS = 3 +MICROS = 123456 class Test_restart_on_unavailable(unittest.TestCase): @@ -176,6 +178,7 @@ def test_ctor(self): session = _Session() base = self._make_one(session) self.assertIs(base._session, session) + self.assertEqual(base._execute_sql_count, 0) def test__make_txn_selector_virtual(self): session = _Session() @@ -201,7 +204,7 @@ def _read_helper(self, multi_use, first=True, count=0, partition=None): from google.cloud.spanner_v1.proto.result_set_pb2 import ( PartialResultSet, ResultSetMetadata, ResultSetStats) from google.cloud.spanner_v1.proto.transaction_pb2 import ( - TransactionSelector) + TransactionSelector, TransactionOptions) from google.cloud.spanner_v1.proto.type_pb2 import Type, StructType from google.cloud.spanner_v1.proto.type_pb2 import STRING, INT64 from google.cloud.spanner_v1.keyset import KeySet @@ -228,13 +231,13 @@ def _read_helper(self, multi_use, first=True, count=0, partition=None): PartialResultSet(values=VALUE_PBS[0], metadata=metadata_pb), PartialResultSet(values=VALUE_PBS[1], stats=stats_pb), ] - KEYS = ['bharney@example.com', 'phred@example.com'] + KEYS = [['bharney@example.com'], ['phred@example.com']] keyset = KeySet(keys=KEYS) INDEX = 'email-address-index' LIMIT = 20 database = _Database() - api = database.spanner_api = _FauxSpannerAPI( - _streaming_read_response=_MockIterator(*result_sets)) + api = database.spanner_api = self._make_spanner_api() + api.streaming_read.return_value = _MockIterator(*result_sets) session = _Session(database) derived = self._makeDerived(session) derived._multi_use = multi_use @@ -262,31 +265,33 @@ def _read_helper(self, multi_use, first=True, count=0, partition=None): self.assertEqual(result_set.metadata, metadata_pb) self.assertEqual(result_set.stats, stats_pb) - (r_session, table, columns, key_set, transaction, index, limit, - resume_token, r_partition, metadata) = api._streaming_read_with + txn_options = TransactionOptions( + read_only=TransactionOptions.ReadOnly(strong=True)) - self.assertEqual(r_session, self.SESSION_NAME) - self.assertEqual(table, TABLE_NAME) - self.assertEqual(columns, COLUMNS) - self.assertEqual(key_set, keyset._to_pb()) - self.assertIsInstance(transaction, TransactionSelector) if multi_use: if first: - self.assertTrue(transaction.begin.read_only.strong) + expected_transaction = TransactionSelector(begin=txn_options) else: - self.assertEqual(transaction.id, TXN_ID) + expected_transaction = TransactionSelector(id=TXN_ID) else: - self.assertTrue(transaction.single_use.read_only.strong) - self.assertEqual(index, INDEX) + expected_transaction = TransactionSelector(single_use=txn_options) + if partition is not None: - self.assertEqual(limit, 0) - self.assertEqual(r_partition, partition) + expected_limit = 0 else: - self.assertEqual(limit, LIMIT) - self.assertIsNone(r_partition) - self.assertEqual(resume_token, b'') - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + expected_limit = LIMIT + + api.streaming_read.assert_called_once_with( + self.SESSION_NAME, + TABLE_NAME, + COLUMNS, + keyset._to_pb(), + transaction=expected_transaction, + index=INDEX, + limit=expected_limit, + partition_token=partition, + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_read_wo_multi_use(self): self._read_helper(multi_use=False) @@ -328,12 +333,12 @@ def test_execute_sql_w_params_wo_param_types(self): derived.execute_sql(SQL_QUERY_WITH_PARAM, PARAMS) def _execute_sql_helper( - self, multi_use, first=True, count=0, partition=None): + self, multi_use, first=True, count=0, partition=None, sql_count=0): from google.protobuf.struct_pb2 import Struct from google.cloud.spanner_v1.proto.result_set_pb2 import ( PartialResultSet, ResultSetMetadata, ResultSetStats) from google.cloud.spanner_v1.proto.transaction_pb2 import ( - TransactionSelector) + TransactionSelector, TransactionOptions) from google.cloud.spanner_v1.proto.type_pb2 import Type, StructType from google.cloud.spanner_v1.proto.type_pb2 import STRING, INT64 from google.cloud.spanner_v1._helpers import _make_value_pb @@ -363,12 +368,13 @@ def _execute_sql_helper( ] iterator = _MockIterator(*result_sets) database = _Database() - api = database.spanner_api = _FauxSpannerAPI( - _execute_streaming_sql_response=iterator) + api = database.spanner_api = self._make_spanner_api() + api.execute_streaming_sql.return_value = iterator session = _Session(database) derived = self._makeDerived(session) derived._multi_use = multi_use derived._read_request_count = count + derived._execute_sql_count = sql_count if not first: derived._transaction_id = TXN_ID @@ -387,29 +393,33 @@ def _execute_sql_helper( self.assertEqual(result_set.metadata, metadata_pb) self.assertEqual(result_set.stats, stats_pb) - (r_session, sql, transaction, params, param_types, - resume_token, query_mode, partition_token, - metadata) = api._executed_streaming_sql_with + txn_options = TransactionOptions( + read_only=TransactionOptions.ReadOnly(strong=True)) - self.assertEqual(r_session, self.SESSION_NAME) - self.assertEqual(sql, SQL_QUERY_WITH_PARAM) - self.assertIsInstance(transaction, TransactionSelector) if multi_use: if first: - self.assertTrue(transaction.begin.read_only.strong) + expected_transaction = TransactionSelector(begin=txn_options) else: - self.assertEqual(transaction.id, TXN_ID) + expected_transaction = TransactionSelector(id=TXN_ID) else: - self.assertTrue(transaction.single_use.read_only.strong) + expected_transaction = TransactionSelector(single_use=txn_options) + expected_params = Struct(fields={ key: _make_value_pb(value) for (key, value) in PARAMS.items()}) - self.assertEqual(params, expected_params) - self.assertEqual(param_types, PARAM_TYPES) - self.assertEqual(query_mode, MODE) - self.assertEqual(resume_token, b'') - self.assertEqual(partition_token, partition) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + + api.execute_streaming_sql.assert_called_once_with( + self.SESSION_NAME, + SQL_QUERY_WITH_PARAM, + transaction=expected_transaction, + params=expected_params, + param_types=PARAM_TYPES, + query_mode=MODE, + partition_token=partition, + seqno=sql_count, + metadata=[('google-cloud-resource-prefix', database.name)], + ) + + self.assertEqual(derived._execute_sql_count, sql_count + 1) def test_execute_sql_wo_multi_use(self): self._execute_sql_helper(multi_use=False) @@ -419,7 +429,7 @@ def test_execute_sql_wo_multi_use_w_read_request_count_gt_0(self): self._execute_sql_helper(multi_use=False, count=1) def test_execute_sql_w_multi_use_wo_first(self): - self._execute_sql_helper(multi_use=True, first=False) + self._execute_sql_helper(multi_use=True, first=False, sql_count=1) def test_execute_sql_w_multi_use_wo_first_w_count_gt_0(self): self._execute_sql_helper(multi_use=True, first=False, count=1) @@ -454,8 +464,8 @@ def _partition_read_helper( transaction=Transaction(id=new_txn_id), ) database = _Database() - api = database.spanner_api = _FauxSpannerAPI( - _partition_read_response=response) + api = database.spanner_api = self._make_spanner_api() + api.partition_read.return_value = response session = _Session(database) derived = self._makeDerived(session) derived._multi_use = multi_use @@ -471,23 +481,21 @@ def _partition_read_helper( self.assertEqual(tokens, [token_1, token_2]) - (r_session, table, key_set, transaction, r_index, columns, - partition_options, metadata) = api._partition_read_with - - self.assertEqual(r_session, self.SESSION_NAME) - self.assertEqual(table, TABLE_NAME) - self.assertEqual(key_set, keyset._to_pb()) - self.assertIsInstance(transaction, TransactionSelector) - self.assertEqual(transaction.id, TXN_ID) - self.assertFalse(transaction.HasField('begin')) - self.assertEqual(r_index, index) - self.assertEqual(columns, COLUMNS) - self.assertEqual( - partition_options, - PartitionOptions( - partition_size_bytes=size, max_partitions=max_partitions)) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + expected_txn_selector = TransactionSelector(id=TXN_ID) + + expected_partition_options = PartitionOptions( + partition_size_bytes=size, max_partitions=max_partitions) + + api.partition_read.assert_called_once_with( + session=self.SESSION_NAME, + table=TABLE_NAME, + columns=COLUMNS, + key_set=keyset._to_pb(), + transaction=expected_txn_selector, + index=index, + partition_options=expected_partition_options, + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_partition_read_single_use_raises(self): with self.assertRaises(ValueError): @@ -544,8 +552,8 @@ def _partition_query_helper( transaction=Transaction(id=new_txn_id), ) database = _Database() - api = database.spanner_api = _FauxSpannerAPI( - _partition_query_response=response) + api = database.spanner_api = self._make_spanner_api() + api.partition_query.return_value = response session = _Session(database) derived = self._makeDerived(session) derived._multi_use = multi_use @@ -560,24 +568,23 @@ def _partition_query_helper( self.assertEqual(tokens, [token_1, token_2]) - (r_session, sql, transaction, params, param_types, - partition_options, metadata) = api._partition_query_with - - self.assertEqual(r_session, self.SESSION_NAME) - self.assertEqual(sql, SQL_QUERY_WITH_PARAM) - self.assertIsInstance(transaction, TransactionSelector) - self.assertEqual(transaction.id, TXN_ID) - self.assertFalse(transaction.HasField('begin')) expected_params = Struct(fields={ key: _make_value_pb(value) for (key, value) in PARAMS.items()}) - self.assertEqual(params, expected_params) - self.assertEqual(param_types, PARAM_TYPES) - self.assertEqual( - partition_options, - PartitionOptions( - partition_size_bytes=size, max_partitions=max_partitions)) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + + expected_txn_selector = TransactionSelector(id=TXN_ID) + + expected_partition_options = PartitionOptions( + partition_size_bytes=size, max_partitions=max_partitions) + + api.partition_query.assert_called_once_with( + session=self.SESSION_NAME, + sql=SQL_QUERY_WITH_PARAM, + transaction=expected_txn_selector, + params=expected_params, + param_types=PARAM_TYPES, + partition_options=expected_partition_options, + metadata=[('google-cloud-resource-prefix', database.name)], + ) def test_partition_query_other_error(self): database = _Database() @@ -894,14 +901,15 @@ def test_begin_w_other_error(self): snapshot.begin() def test_begin_ok_exact_staleness(self): + from google.protobuf.duration_pb2 import Duration from google.cloud.spanner_v1.proto.transaction_pb2 import ( - Transaction as TransactionPB) + Transaction as TransactionPB, TransactionOptions) transaction_pb = TransactionPB(id=TXN_ID) database = _Database() - api = database.spanner_api = _FauxSpannerAPI( - _begin_transaction_response=transaction_pb) - duration = self._makeDuration(seconds=3, microseconds=123456) + api = database.spanner_api = self._make_spanner_api() + api.begin_transaction.return_value = transaction_pb + duration = self._makeDuration(seconds=SECONDS, microseconds=MICROS) session = _Session(database) snapshot = self._make_one( session, exact_staleness=duration, multi_use=True) @@ -911,22 +919,25 @@ def test_begin_ok_exact_staleness(self): self.assertEqual(txn_id, TXN_ID) self.assertEqual(snapshot._transaction_id, TXN_ID) - session_id, txn_options, metadata = api._begun - self.assertEqual(session_id, session.name) - read_only = txn_options.read_only - self.assertEqual(read_only.exact_staleness.seconds, 3) - self.assertEqual(read_only.exact_staleness.nanos, 123456000) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + expected_duration = Duration( + seconds=SECONDS, nanos=MICROS * 1000) + expected_txn_options = TransactionOptions( + read_only=TransactionOptions.ReadOnly( + exact_staleness=expected_duration)) + + api.begin_transaction.assert_called_once_with( + session.name, + expected_txn_options, + metadata=[('google-cloud-resource-prefix', database.name)]) def test_begin_ok_exact_strong(self): from google.cloud.spanner_v1.proto.transaction_pb2 import ( - Transaction as TransactionPB) + Transaction as TransactionPB, TransactionOptions) transaction_pb = TransactionPB(id=TXN_ID) database = _Database() - api = database.spanner_api = _FauxSpannerAPI( - _begin_transaction_response=transaction_pb) + api = database.spanner_api = self._make_spanner_api() + api.begin_transaction.return_value = transaction_pb session = _Session(database) snapshot = self._make_one(session, multi_use=True) @@ -935,11 +946,13 @@ def test_begin_ok_exact_strong(self): self.assertEqual(txn_id, TXN_ID) self.assertEqual(snapshot._transaction_id, TXN_ID) - session_id, txn_options, metadata = api._begun - self.assertEqual(session_id, session.name) - self.assertTrue(txn_options.read_only.strong) - self.assertEqual( - metadata, [('google-cloud-resource-prefix', database.name)]) + expected_txn_options = TransactionOptions( + read_only=TransactionOptions.ReadOnly(strong=True)) + + api.begin_transaction.assert_called_once_with( + session.name, + expected_txn_options, + metadata=[('google-cloud-resource-prefix', database.name)]) class _Session(object): @@ -953,63 +966,6 @@ class _Database(object): name = 'testing' -class _FauxSpannerAPI(object): - - _read_with = _begin = None - - def __init__(self, **kwargs): - self.__dict__.update(**kwargs) - - def begin_transaction(self, session, options_, metadata=None): - self._begun = (session, options_, metadata) - return self._begin_transaction_response - - # pylint: disable=too-many-arguments - def streaming_read(self, session, table, columns, key_set, - transaction=None, index='', limit=0, - resume_token=b'', partition_token=None, metadata=None): - self._streaming_read_with = ( - session, table, columns, key_set, transaction, index, - limit, resume_token, partition_token, metadata) - return self._streaming_read_response - # pylint: enable=too-many-arguments - - def execute_streaming_sql(self, session, sql, transaction=None, - params=None, param_types=None, - resume_token=b'', query_mode=None, - partition_token=None, metadata=None): - self._executed_streaming_sql_with = ( - session, sql, transaction, params, param_types, resume_token, - query_mode, partition_token, metadata) - return self._execute_streaming_sql_response - - # pylint: disable=too-many-arguments - def partition_read(self, session, table, key_set, - transaction=None, - index='', - columns=None, - partition_options=None, - metadata=None): - self._partition_read_with = ( - session, table, key_set, transaction, index, columns, - partition_options, metadata) - return self._partition_read_response - # pylint: enable=too-many-arguments - - # pylint: disable=too-many-arguments - def partition_query(self, session, sql, - transaction=None, - params=None, - param_types=None, - partition_options=None, - metadata=None): - self._partition_query_with = ( - session, sql, transaction, params, param_types, - partition_options, metadata) - return self._partition_query_response - # pylint: enable=too-many-arguments - - class _MockIterator(object): def __init__(self, *values, **kw): diff --git a/spanner/tests/unit/test_transaction.py b/spanner/tests/unit/test_transaction.py index 29c1e765888e..99c401cc7e10 100644 --- a/spanner/tests/unit/test_transaction.py +++ b/spanner/tests/unit/test_transaction.py @@ -24,6 +24,16 @@ ['phred@exammple.com', 'Phred', 'Phlyntstone', 32], ['bharney@example.com', 'Bharney', 'Rhubble', 31], ] +DML_QUERY = """\ +INSERT INTO citizens(first_name, last_name, age) +VALUES ("Phred", "Phlyntstone", 32) +""" +DML_QUERY_WITH_PARAM = """ +INSERT INTO citizens(first_name, last_name, age) +VALUES ("Phred", "Phlyntstone", @age) +""" +PARAMS = {'age': 30} +PARAM_TYPES = {'age': 'INT64'} class TestTransaction(unittest.TestCase): @@ -68,6 +78,7 @@ def test_ctor_defaults(self): self.assertIsNone(transaction.committed) self.assertFalse(transaction._rolled_back) self.assertTrue(transaction._multi_use) + self.assertEqual(transaction._execute_sql_count, 0) def test__check_state_not_begun(self): session = _Session() @@ -238,13 +249,6 @@ def test_commit_already_rolled_back(self): with self.assertRaises(ValueError): transaction.commit() - def test_commit_no_mutations(self): - session = _Session() - transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID - with self.assertRaises(ValueError): - transaction.commit() - def test_commit_w_other_error(self): database = _Database() database.spanner_api = self._make_spanner_api() @@ -259,7 +263,7 @@ def test_commit_w_other_error(self): self.assertIsNone(transaction.committed) - def test_commit_ok(self): + def _commit_helper(self, mutate=True): import datetime from google.cloud.spanner_v1.proto.spanner_pb2 import CommitResponse from google.cloud.spanner_v1.keyset import KeySet @@ -277,7 +281,9 @@ def test_commit_ok(self): session = _Session(database) transaction = self._make_one(session) transaction._transaction_id = self.TRANSACTION_ID - transaction.delete(TABLE_NAME, keyset) + + if mutate: + transaction.delete(TABLE_NAME, keyset) transaction.commit() @@ -291,6 +297,80 @@ def test_commit_ok(self): self.assertEqual( metadata, [('google-cloud-resource-prefix', database.name)]) + def test_commit_no_mutations(self): + self._commit_helper(mutate=False) + + def test_commit_w_mutations(self): + self._commit_helper(mutate=True) + + def test_execute_update_other_error(self): + database = _Database() + database.spanner_api = self._make_spanner_api() + database.spanner_api.execute_sql.side_effect = RuntimeError() + session = _Session(database) + transaction = self._make_one(session) + transaction._transaction_id = self.TRANSACTION_ID + + with self.assertRaises(RuntimeError): + transaction.execute_update(DML_QUERY) + + def test_execute_update_w_params_wo_param_types(self): + database = _Database() + database.spanner_api = self._make_spanner_api() + session = _Session(database) + session = _Session() + transaction = self._make_one(session) + transaction._transaction_id = self.TRANSACTION_ID + + with self.assertRaises(ValueError): + transaction.execute_update(DML_QUERY_WITH_PARAM, PARAMS) + + def _execute_update_helper(self, count=0): + from google.protobuf.struct_pb2 import Struct + from google.cloud.spanner_v1.proto.result_set_pb2 import ( + ResultSet, ResultSetStats) + from google.cloud.spanner_v1.proto.transaction_pb2 import ( + TransactionSelector) + from google.cloud.spanner_v1._helpers import _make_value_pb + + MODE = 2 # PROFILE + stats_pb = ResultSetStats(row_count_exact=1) + database = _Database() + api = database.spanner_api = self._make_spanner_api() + api.execute_sql.return_value = ResultSet(stats=stats_pb) + session = _Session(database) + transaction = self._make_one(session) + transaction._transaction_id = self.TRANSACTION_ID + transaction._execute_sql_count = count + + row_count = transaction.execute_update( + DML_QUERY_WITH_PARAM, PARAMS, PARAM_TYPES, query_mode=MODE) + + self.assertEqual(row_count, 1) + + expected_transaction = TransactionSelector(id=self.TRANSACTION_ID) + expected_params = Struct(fields={ + key: _make_value_pb(value) for (key, value) in PARAMS.items()}) + + api.execute_sql.assert_called_once_with( + self.SESSION_NAME, + DML_QUERY_WITH_PARAM, + transaction=expected_transaction, + params=expected_params, + param_types=PARAM_TYPES, + query_mode=MODE, + seqno=count, + metadata=[('google-cloud-resource-prefix', database.name)], + ) + + self.assertEqual(transaction._execute_sql_count, count + 1) + + def test_execute_update_new_transaction(self): + self._execute_update_helper() + + def test_execute_update_w_count(self): + self._execute_update_helper(count=1) + def test_context_mgr_success(self): import datetime from google.cloud.spanner_v1.proto.spanner_pb2 import CommitResponse