From 50069e9b70a3038c35870934ed270355eebed35d Mon Sep 17 00:00:00 2001 From: qadro87 <57414069+qadro87@users.noreply.github.com> Date: Mon, 11 Nov 2019 09:31:34 -0800 Subject: [PATCH 001/131] Fix some type annotations that diverged. (#72) * Fix some type annotations that diverged. Stop deferring type annotation validation to prevent further divergence. Prevent the migrations_test from modifying migrations in place. Update documentation on how to create the database. Don't validate migration_id in migrations/__init__.py. * Make migrations_test more robust. * Change import format. * Update one more reference. --- README.md | 2 +- spanner_orm/__init__.py | 6 ++++ spanner_orm/admin/api.py | 8 ++--- spanner_orm/admin/column.py | 2 -- spanner_orm/admin/index.py | 2 -- spanner_orm/admin/index_column.py | 2 -- spanner_orm/admin/metadata.py | 2 -- spanner_orm/admin/migration.py | 2 -- spanner_orm/admin/migration_executor.py | 2 -- spanner_orm/admin/migration_manager.py | 4 +-- spanner_orm/admin/migration_status.py | 2 -- spanner_orm/admin/schema.py | 2 -- spanner_orm/admin/scripts.py | 2 -- spanner_orm/admin/table.py | 2 -- spanner_orm/admin/update.py | 2 -- spanner_orm/api.py | 10 +++--- spanner_orm/condition.py | 2 -- spanner_orm/decorator.py | 2 -- spanner_orm/field.py | 40 ++++++++++----------- spanner_orm/index.py | 2 -- spanner_orm/metadata.py | 4 +-- spanner_orm/model.py | 27 +++++++------- spanner_orm/registry.py | 2 -- spanner_orm/relationship.py | 2 -- spanner_orm/table_apis.py | 2 -- spanner_orm/tests/migrations_test.py | 47 +++++++++++++++++++------ 26 files changed, 86 insertions(+), 96 deletions(-) diff --git a/README.md b/README.md index 75a7cbb..c3d9a02 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,7 @@ admin_api = spanner_orm.connect_admin( 'instance_name', 'database_name', create_ddl=spanner_orm.model_creation_ddl(TestModel)) -admin_api.create() +admin_api.create_database() ``` If the database already exists, we can execute a Migration where the upgrade diff --git a/spanner_orm/__init__.py b/spanner_orm/__init__.py index 335b864..9452974 100644 --- a/spanner_orm/__init__.py +++ b/spanner_orm/__init__.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + """Sets up shortcuts for imports from the library.""" import logging @@ -97,3 +98,8 @@ model_creation_ddl = update_module.model_creation_ddl MigrationExecutor = migration_executor.MigrationExecutor + +try: + __import__('pkg_resources').declare_namespace('spanner_orm') +except ImportError: + __path__ = __import__('pkgutil').extend_path(__path__, 'spanner_orm') diff --git a/spanner_orm/admin/api.py b/spanner_orm/admin/api.py index 59e5ae0..8a28dfc 100644 --- a/spanner_orm/admin/api.py +++ b/spanner_orm/admin/api.py @@ -14,15 +14,13 @@ # limitations under the License. """Class that handles API calls to Spanner that deal with table metadata.""" -from __future__ import annotations - from typing import Iterable, Optional from spanner_orm import api from spanner_orm import error from google.auth import credentials as auth_credentials -from google.cloud import spanner from google.cloud.spanner_v1 import database as spanner_database +from google.cloud.spanner_v1 import pool as spanner_pool class SpannerAdminApi(api.SpannerReadApi, api.SpannerWriteApi): @@ -32,7 +30,7 @@ def __init__(self, connection: api.SpannerConnection): self._spanner_connection = connection @property - def _connection(self) -> spanner_database.SpannerDatabase: + def _connection(self) -> spanner_database.Database: return self._spanner_connection.database def create_database(self) -> None: @@ -54,7 +52,7 @@ def connect(instance: str, database: str, project: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, - pool: Optional[spanner.Pool] = None, + pool: Optional[spanner_pool.AbstractSessionPool] = None, create_ddl: Optional[Iterable[str]] = None) -> SpannerAdminApi: """Connects the global Spanner admin API to a Spanner database.""" connection = api.SpannerConnection( diff --git a/spanner_orm/admin/column.py b/spanner_orm/admin/column.py index 7c22b9a..b559d4e 100644 --- a/spanner_orm/admin/column.py +++ b/spanner_orm/admin/column.py @@ -14,8 +14,6 @@ # limitations under the License. """Model for interacting with Spanner column schema table.""" -from __future__ import annotations - from typing import Type from spanner_orm import error diff --git a/spanner_orm/admin/index.py b/spanner_orm/admin/index.py index f4b0ba4..3ce6166 100644 --- a/spanner_orm/admin/index.py +++ b/spanner_orm/admin/index.py @@ -14,8 +14,6 @@ # limitations under the License. """Model for interacting with Spanner index schema table.""" -from __future__ import annotations - from spanner_orm import field from spanner_orm.admin import schema diff --git a/spanner_orm/admin/index_column.py b/spanner_orm/admin/index_column.py index b2fc736..d148639 100644 --- a/spanner_orm/admin/index_column.py +++ b/spanner_orm/admin/index_column.py @@ -14,8 +14,6 @@ # limitations under the License. """Model for interacting with Spanner index column schema table.""" -from __future__ import annotations - from spanner_orm import field from spanner_orm.admin import schema diff --git a/spanner_orm/admin/metadata.py b/spanner_orm/admin/metadata.py index 04bd89f..4156878 100644 --- a/spanner_orm/admin/metadata.py +++ b/spanner_orm/admin/metadata.py @@ -14,8 +14,6 @@ # limitations under the License. """Retrieves table metadata from Spanner.""" -from __future__ import annotations - import collections from typing import Any, Dict, Optional, Type diff --git a/spanner_orm/admin/migration.py b/spanner_orm/admin/migration.py index e189304..177e1a7 100644 --- a/spanner_orm/admin/migration.py +++ b/spanner_orm/admin/migration.py @@ -14,8 +14,6 @@ # limitations under the License. """Holds information about a specific migration.""" -from __future__ import annotations - from typing import Callable, Optional from spanner_orm.admin import update diff --git a/spanner_orm/admin/migration_executor.py b/spanner_orm/admin/migration_executor.py index 774fc16..1d9e9d5 100644 --- a/spanner_orm/admin/migration_executor.py +++ b/spanner_orm/admin/migration_executor.py @@ -14,8 +14,6 @@ # limitations under the License. """Handles execution of migrations.""" -from __future__ import annotations - import datetime import logging from typing import Iterable, List, Dict, Optional diff --git a/spanner_orm/admin/migration_manager.py b/spanner_orm/admin/migration_manager.py index 046d6d9..da1a29f 100644 --- a/spanner_orm/admin/migration_manager.py +++ b/spanner_orm/admin/migration_manager.py @@ -14,8 +14,6 @@ # limitations under the License. """Handles reading and writing of migration files.""" -from __future__ import annotations - import datetime import importlib import os @@ -91,7 +89,7 @@ def _all_migrations(self) -> List[migration.Migration]: migrations = [] for filename in os.listdir(self.basedir): _, ext = os.path.splitext(filename) - if ext == '.py': + if ext == '.py' and filename != '__init__.py': migrations.append(self._migration_from_file(filename)) return migrations diff --git a/spanner_orm/admin/migration_status.py b/spanner_orm/admin/migration_status.py index e580192..576593c 100644 --- a/spanner_orm/admin/migration_status.py +++ b/spanner_orm/admin/migration_status.py @@ -14,8 +14,6 @@ # limitations under the License. """Indicates whether a migration has been applied to the current database.""" -from __future__ import annotations - from spanner_orm import field from spanner_orm import model from spanner_orm.admin import api diff --git a/spanner_orm/admin/schema.py b/spanner_orm/admin/schema.py index 9887d55..c39d553 100644 --- a/spanner_orm/admin/schema.py +++ b/spanner_orm/admin/schema.py @@ -14,8 +14,6 @@ # limitations under the License. """Base model for schemas.""" -from __future__ import annotations - from typing import NoReturn from spanner_orm import error diff --git a/spanner_orm/admin/scripts.py b/spanner_orm/admin/scripts.py index 4dc4557..b57b2a7 100644 --- a/spanner_orm/admin/scripts.py +++ b/spanner_orm/admin/scripts.py @@ -14,8 +14,6 @@ # limitations under the License. """Entry point for spanner_orm scripts.""" -from __future__ import annotations - import argparse from typing import Any diff --git a/spanner_orm/admin/table.py b/spanner_orm/admin/table.py index 0c5838c..d718e33 100644 --- a/spanner_orm/admin/table.py +++ b/spanner_orm/admin/table.py @@ -14,8 +14,6 @@ # limitations under the License. """Model for interacting with Spanner column schema table.""" -from __future__ import annotations - from spanner_orm import field from spanner_orm.admin import schema diff --git a/spanner_orm/admin/update.py b/spanner_orm/admin/update.py index e2a73e8..4277e10 100644 --- a/spanner_orm/admin/update.py +++ b/spanner_orm/admin/update.py @@ -14,8 +14,6 @@ # limitations under the License. """Used with SpannerAdminApi to manage Spanner schema updates.""" -from __future__ import annotations - import abc from typing import Iterable, List, Optional, Type diff --git a/spanner_orm/api.py b/spanner_orm/api.py index 3254dc0..ccb0a27 100644 --- a/spanner_orm/api.py +++ b/spanner_orm/api.py @@ -14,8 +14,6 @@ # limitations under the License. """Class that handles API calls to Spanner.""" -from __future__ import annotations - import abc from typing import Any, Callable, Iterable, Optional, TypeVar @@ -24,6 +22,7 @@ from google.auth import credentials as auth_credentials from google.cloud import spanner from google.cloud.spanner_v1 import database as spanner_database +from google.cloud.spanner_v1 import pool as spanner_pool CallableReturn = TypeVar('CallableReturn') @@ -60,7 +59,7 @@ class SpannerWriteApi(abc.ABC): @property @abc.abstractmethod - def _connection(self) -> spanner_database.SpannerDatabase: + def _connection(self) -> spanner_database.Database: raise NotImplementedError def run_write(self, method: Callable[..., CallableReturn], *args: Any, @@ -92,7 +91,7 @@ def __init__(self, database: str, project: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, - pool: Optional[spanner.Pool] = None, + pool: Optional[spanner_pool.AbstractSessionPool] = None, create_ddl: Optional[Iterable[str]] = None): """Connects to the specified Spanner database.""" client = spanner.Client(project=project, credentials=credentials) @@ -119,7 +118,8 @@ def connect(instance: str, database: str, project: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, - pool: Optional[spanner.Pool] = None) -> SpannerApi: + pool: Optional[spanner_pool.AbstractSessionPool] = None + ) -> SpannerApi: """Connects to the Spanner database and sets the global spanner_api.""" connection = SpannerConnection( instance, database, project=project, credentials=credentials, pool=pool) diff --git a/spanner_orm/condition.py b/spanner_orm/condition.py index 2c9e43b..0174930 100644 --- a/spanner_orm/condition.py +++ b/spanner_orm/condition.py @@ -14,8 +14,6 @@ # limitations under the License. """Used with Model#where and Model#count to help create Spanner queries.""" -from __future__ import annotations - import abc import enum from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union diff --git a/spanner_orm/decorator.py b/spanner_orm/decorator.py index d492bc8..bd2a352 100644 --- a/spanner_orm/decorator.py +++ b/spanner_orm/decorator.py @@ -14,8 +14,6 @@ # limitations under the License. """Transaction decorators.""" -from __future__ import annotations - from typing import Callable, TypeVar from spanner_orm import api diff --git a/spanner_orm/field.py b/spanner_orm/field.py index 1d6e5b2..64f40c6 100644 --- a/spanner_orm/field.py +++ b/spanner_orm/field.py @@ -14,8 +14,6 @@ # limitations under the License. """Helper to deal with field types in Spanner interactions.""" -from __future__ import annotations - import abc import datetime from typing import Any, Type @@ -25,6 +23,25 @@ from google.cloud.spanner_v1.proto import type_pb2 +class FieldType(abc.ABC): + """Base class for column types for Spanner interactions.""" + + @staticmethod + @abc.abstractmethod + def ddl() -> str: + raise NotImplementedError + + @staticmethod + @abc.abstractmethod + def grpc_type() -> type_pb2.Type: + raise NotImplementedError + + @staticmethod + @abc.abstractmethod + def validate_type(value: Any) -> None: + raise NotImplementedError + + class Field(object): """Represents a column in a table as a field in a model.""" @@ -62,25 +79,6 @@ def validate(self, value) -> None: self._type.validate_type(value) -class FieldType(abc.ABC): - """Base class for column types for Spanner interactions.""" - - @staticmethod - @abc.abstractmethod - def ddl() -> str: - raise NotImplementedError - - @staticmethod - @abc.abstractmethod - def grpc_type() -> type_pb2.Type: - raise NotImplementedError - - @staticmethod - @abc.abstractmethod - def validate_type(value: Any) -> None: - raise NotImplementedError - - class Boolean(FieldType): """Represents a boolean type.""" diff --git a/spanner_orm/index.py b/spanner_orm/index.py index 807a61d..1a1ac35 100644 --- a/spanner_orm/index.py +++ b/spanner_orm/index.py @@ -14,8 +14,6 @@ # limitations under the License. """Represents an index on a Model.""" -from __future__ import annotations - from typing import List, Optional from spanner_orm import error diff --git a/spanner_orm/metadata.py b/spanner_orm/metadata.py index 9caaf95..8c7c5e2 100644 --- a/spanner_orm/metadata.py +++ b/spanner_orm/metadata.py @@ -28,8 +28,6 @@ # limitations under the License. """Hold information about a Model extracted from the class attributes.""" -from __future__ import annotations - from typing import Any, Dict, Type, Optional from spanner_orm import error @@ -86,7 +84,7 @@ def finalize(self) -> None: registry.model_registry().register(self.model_class) self._finalized = True - def add_metadata(self, metadata: ModelMetadata) -> None: + def add_metadata(self, metadata: 'ModelMetadata') -> None: self.table = metadata.table or self.table self.fields.update(metadata.fields) self.relations.update(metadata.relations) diff --git a/spanner_orm/model.py b/spanner_orm/model.py index 2dd428d..06060cb 100644 --- a/spanner_orm/model.py +++ b/spanner_orm/model.py @@ -14,8 +14,6 @@ # limitations under the License. """Holds table-specific information to make querying spanner eaiser.""" -from __future__ import annotations - import collections import copy from typing import Any, Callable, Dict, Iterable, List, Optional, Type, TypeVar, Union @@ -101,7 +99,7 @@ def indexes(cls) -> Dict[str, index.Index]: return cls.meta.indexes @property - def interleaved(cls) -> Optional[Type[Model]]: + def interleaved(cls) -> Optional[Type['Model']]: if cls.meta.interleaved: return registry.model_registry().get(cls.meta.interleaved) return None @@ -151,7 +149,7 @@ def spanner_api(cls) -> api.SpannerApi: def all( cls, transaction: Optional[spanner_transaction.Transaction] = None - ) -> List[ModelObject]: + ) -> List['ModelObject']: """Returns all objects of this type stored in Spanner. Note: this method should only be called on subclasses of Model that have @@ -219,7 +217,7 @@ def count_equal(cls, @classmethod def find(cls, transaction: Optional[spanner_transaction.Transaction] = None, - **keys: Any) -> Optional[ModelObject]: + **keys: Any) -> Optional['ModelObject']: """Retrieves an object from Spanner based on the provided key. Args: @@ -237,7 +235,7 @@ def find(cls, @classmethod def find_multi(cls, transaction: Optional[spanner_transaction.Transaction], - keys: Iterable[Dict[str, Any]]) -> List[ModelObject]: + keys: Iterable[Dict[str, Any]]) -> List['ModelObject']: """Retrieves objects from Spanner based on the provided keys. Args: @@ -262,7 +260,7 @@ def find_multi(cls, transaction: Optional[spanner_transaction.Transaction], @classmethod def where(cls, transaction: Optional[spanner_transaction.Transaction], - *conditions: condition.Condition) -> List[ModelObject]: + *conditions: condition.Condition) -> List['ModelObject']: """Retrieves objects from Spanner based on the provided conditions. Args: @@ -283,7 +281,7 @@ def where(cls, transaction: Optional[spanner_transaction.Transaction], @classmethod def where_equal(cls, transaction: Optional[spanner_transaction.Transaction] = None, - **constraints: Any) -> List[ModelObject]: + **constraints: Any) -> List['ModelObject']: """Retrieves objects from Spanner based on the provided constraints. Args: @@ -307,7 +305,7 @@ def where_equal(cls, @classmethod def _results_to_models(cls, - results: Iterable[Iterable[Any]]) -> List[ModelObject]: + results: Iterable[Iterable[Any]]) -> List['ModelObject']: items = [dict(zip(cls.columns, result)) for result in results] return [cls(item, persisted=True) for item in items] @@ -347,7 +345,7 @@ def create_or_update(cls, @classmethod def delete_batch(cls, transaction: Optional[spanner_transaction.Transaction], - models: List[ModelObject]) -> None: + models: List['ModelObject']) -> None: """Deletes rows from Spanner based on the provided models' primary keys. Args: @@ -370,7 +368,7 @@ def delete_batch(cls, transaction: Optional[spanner_transaction.Transaction], @classmethod def save_batch(cls, transaction: Optional[spanner_transaction.Transaction], - models: List[ModelObject], + models: List['ModelObject'], force_write: bool = False) -> None: """Writes rows to Spanner based on the provided model data. @@ -484,7 +482,7 @@ def __setattr__(self, name: str, value: Any) -> None: super().__setattr__(name, value) @property - def _metaclass(self) -> Type[Model]: + def _metaclass(self) -> Type['Model']: return type(self) @property @@ -558,7 +556,7 @@ def id(self) -> Dict[str, Any]: def reload( self, - transaction: spanner_transaction.Transaction = None) -> Optional[Model]: + transaction: spanner_transaction.Transaction = None) -> Optional['Model']: """Refreshes this object with information from Spanner. Args: @@ -585,7 +583,8 @@ def reload( self._persisted = True return self - def save(self, transaction: spanner_transaction.Transaction = None) -> Model: + def save(self, + transaction: spanner_transaction.Transaction = None) -> 'Model': """Persists this object to Spanner. Note: if the _persisted flag doesn't match whether this object is actually diff --git a/spanner_orm/registry.py b/spanner_orm/registry.py index 1c596cc..1e78894 100644 --- a/spanner_orm/registry.py +++ b/spanner_orm/registry.py @@ -14,8 +14,6 @@ # limitations under the License. """Registers Model classes so they can be referenced elsewhere.""" -from __future__ import annotations - from typing import Any, Dict, List, Type, Union import dataclasses diff --git a/spanner_orm/relationship.py b/spanner_orm/relationship.py index 54c320c..d929b8f 100644 --- a/spanner_orm/relationship.py +++ b/spanner_orm/relationship.py @@ -14,8 +14,6 @@ # limitations under the License. """Helps define a foreign key relationship between two models.""" -from __future__ import annotations - from typing import Any, Dict, List, Type, Union import dataclasses diff --git a/spanner_orm/table_apis.py b/spanner_orm/table_apis.py index b289f74..3ee4625 100644 --- a/spanner_orm/table_apis.py +++ b/spanner_orm/table_apis.py @@ -14,8 +14,6 @@ # limitations under the License. """Table-level API lambdas for Spanner transactions.""" -from __future__ import annotations - import logging from typing import Any, Dict, Iterable, List diff --git a/spanner_orm/tests/migrations_test.py b/spanner_orm/tests/migrations_test.py index 07c1610..1324e07 100644 --- a/spanner_orm/tests/migrations_test.py +++ b/spanner_orm/tests/migrations_test.py @@ -14,6 +14,9 @@ # limitations under the License. import logging import os +import shutil +import stat +import tempfile import unittest from unittest import mock @@ -25,11 +28,13 @@ class MigrationsTest(unittest.TestCase): - TEST_DIR = os.path.dirname(__file__) + TEST_DIR = tempfile.mkdtemp() TEST_MIGRATIONS_DIR = os.path.join(TEST_DIR, 'migrations') def test_retrieve(self): - manager = migration_manager.MigrationManager(self.TEST_MIGRATIONS_DIR) + testdata_filename = os.path.join(os.path.dirname(__file__), + 'migrations') + manager = migration_manager.MigrationManager(testdata_filename) migrations = manager.migrations self.assertEqual(len(migrations), 3) self.assertEqual(migrations[2].prev_migration_id, @@ -38,6 +43,17 @@ def test_retrieve(self): migrations[0].migration_id) def test_generate(self): + testdata_filename = os.path.join(os.path.dirname(__file__), + 'migrations') + shutil.rmtree(self.TEST_MIGRATIONS_DIR) + shutil.copytree(testdata_filename, self.TEST_MIGRATIONS_DIR) + os.chmod(self.TEST_MIGRATIONS_DIR, + stat.S_IRWXO | stat.S_IRWXU) + for f in os.listdir(self.TEST_MIGRATIONS_DIR): + file_path = os.path.join(self.TEST_MIGRATIONS_DIR, f) + if not os.path.isdir(file_path): + os.chmod(file_path, + stat.S_IROTH | stat.S_IWOTH | stat.S_IRUSR | stat.S_IWUSR) manager = migration_manager.MigrationManager(self.TEST_MIGRATIONS_DIR) path = manager.generate('test migration') try: @@ -47,7 +63,7 @@ def test_generate(self): self.assertIsInstance(migration_.upgrade(), update.NoUpdate) self.assertIsInstance(migration_.downgrade(), update.NoUpdate) finally: - os.remove(path) + shutil.rmtree(self.TEST_MIGRATIONS_DIR) def test_order_migrations(self): first = migration.Migration('1', None) @@ -110,7 +126,8 @@ def test_order_migrations_error_on_no_successor(self): def test_filter_migrations(self): connection = mock.Mock() - executor = migration_executor.MigrationExecutor(connection) + executor = migration_executor.MigrationExecutor( + connection, self.TEST_MIGRATIONS_DIR) first = migration.Migration('1', None) second = migration.Migration('2', '1') @@ -130,7 +147,8 @@ def test_filter_migrations(self): def test_filter_migrations_error_on_bad_last_migration(self): connection = mock.Mock() - executor = migration_executor.MigrationExecutor(connection) + executor = migration_executor.MigrationExecutor( + connection, self.TEST_MIGRATIONS_DIR) first = migration.Migration('1', None) second = migration.Migration('2', '1') @@ -147,7 +165,8 @@ def test_filter_migrations_error_on_bad_last_migration(self): def test_validate_migrations(self): connection = mock.Mock() - executor = migration_executor.MigrationExecutor(connection) + executor = migration_executor.MigrationExecutor( + connection, self.TEST_MIGRATIONS_DIR) first = migration.Migration('1', None) second = migration.Migration('2', '1') @@ -165,7 +184,8 @@ def test_validate_migrations(self): def test_validate_migrations_error_on_unmigrated_after_migrated(self): connection = mock.Mock() - executor = migration_executor.MigrationExecutor(connection) + executor = migration_executor.MigrationExecutor( + connection, self.TEST_MIGRATIONS_DIR) first = migration.Migration('1', None) second = migration.Migration('2', '1') @@ -185,7 +205,8 @@ def test_validate_migrations_error_on_unmigrated_after_migrated(self): def test_validate_migrations_error_on_unmigrated_first(self): connection = mock.Mock() - executor = migration_executor.MigrationExecutor(connection) + executor = migration_executor.MigrationExecutor( + connection, self.TEST_MIGRATIONS_DIR) first = migration.Migration('2', '1') with mock.patch.object(executor, 'migrations') as migrations: @@ -203,7 +224,8 @@ def test_validate_migrations_error_on_unmigrated_first(self): def test_migrate(self): connection = mock.Mock() - executor = migration_executor.MigrationExecutor(connection) + executor = migration_executor.MigrationExecutor( + connection, self.TEST_MIGRATIONS_DIR) first = migration.Migration('1', None) second = migration.Migration('2', '1') @@ -217,7 +239,8 @@ def test_migrate(self): def test_rollback(self): connection = mock.Mock() - executor = migration_executor.MigrationExecutor(connection) + executor = migration_executor.MigrationExecutor( + connection, self.TEST_MIGRATIONS_DIR) first = migration.Migration('1', None) second = migration.Migration('2', '1') @@ -229,6 +252,10 @@ def test_rollback(self): executor.rollback('1') self.assertEqual(migrated, {'1': False, '2': False, '3': False}) + @classmethod + def tearDownClass(cls): + super().tearDownClass() + shutil.rmtree(MigrationsTest.TEST_DIR) if __name__ == '__main__': logging.basicConfig() From a37503f5cdb3f16e0efc59fd6561810b2282824a Mon Sep 17 00:00:00 2001 From: qadro87 <57414069+qadro87@users.noreply.github.com> Date: Mon, 11 Nov 2019 12:59:51 -0800 Subject: [PATCH 002/131] Add argparse 3.7 compatibility. (#73) * Add argparse 3.7 compatibility. * Enforce subcommand manually in the admin script. --- spanner_orm/admin/scripts.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/spanner_orm/admin/scripts.py b/spanner_orm/admin/scripts.py index b57b2a7..dd445f9 100644 --- a/spanner_orm/admin/scripts.py +++ b/spanner_orm/admin/scripts.py @@ -42,11 +42,12 @@ def rollback(args: Any) -> None: def main(as_module: bool = False) -> None: prog = 'spanner-orm' if as_module else None parser = argparse.ArgumentParser(prog=prog) + # 'subcommand' is actually required, but required subparsers are not supported + # for Python < 3.7. subparsers = parser.add_subparsers( dest='subcommand', title='subcommands', - description='valid subcommands', - required=True) + description='valid subcommands') generate_parser = subparsers.add_parser( 'generate', help='Generate a new migration') @@ -73,7 +74,10 @@ def main(as_module: bool = False) -> None: rollback_parser.set_defaults(execute=rollback) args = parser.parse_args() - args.execute(args) + if args.subcommand is None: + parser.print_help() + else: + args.execute(args) if __name__ == '__main__': From e935c0be2435609e953dc1221fa69a44484ec998 Mon Sep 17 00:00:00 2001 From: Aniruddha Maru Date: Tue, 5 May 2020 11:16:37 -0700 Subject: [PATCH 003/131] Add float type --- .gitignore | 2 ++ README.md | 1 + spanner_orm/__init__.py | 1 + spanner_orm/field.py | 19 ++++++++++++++++++- spanner_orm/tests/model_api_test.py | 9 +++++---- spanner_orm/tests/model_test.py | 6 ++++-- spanner_orm/tests/models.py | 2 ++ spanner_orm/tests/update_test.py | 3 ++- 8 files changed, 35 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index 7c47c20..3aed126 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,5 @@ dist build __pycache__ +.eggs +*.egg-info diff --git a/README.md b/README.md index c3d9a02..db2eeb5 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,7 @@ class TestModel(spanner_orm.Model): # The name of the column is the same as the name of the class attribute id = spanner_orm.Field(spanner_orm.String, primary_key=True) value = spanner_orm.Field(spanner_orm.Integer, nullable=True) + number = spanner_orm.Field(spanner_orm.Float, nullable=True) # Secondary indexes are specified in a similar manner to fields: value_index = spanner_orm.Index(['value']) diff --git a/spanner_orm/__init__.py b/spanner_orm/__init__.py index 9452974..1a77d85 100644 --- a/spanner_orm/__init__.py +++ b/spanner_orm/__init__.py @@ -61,6 +61,7 @@ Boolean = field.Boolean Field = field.Field Integer = field.Integer +Float = field.Float Index = index.Index Relationship = relationship.Relationship String = field.String diff --git a/spanner_orm/field.py b/spanner_orm/field.py index 64f40c6..b8bc6b3 100644 --- a/spanner_orm/field.py +++ b/spanner_orm/field.py @@ -113,6 +113,23 @@ def validate_type(value: Any) -> None: raise error.ValidationError('{} is not of type int'.format(value)) +class Float(FieldType): + """Represents a float type.""" + + @staticmethod + def ddl() -> str: + return "FLOAT64" + + @staticmethod + def grpc_type() -> type_pb2.Type: + return type_pb2.Type(code=type_pb2.FLOAT64) + + @staticmethod + def validate_type(value: Any) -> None: + if not isinstance(value, (int, float)): + raise error.ValidationError("{} is not of type float".format(value)) + + class String(FieldType): """Represents a string type.""" @@ -167,4 +184,4 @@ def validate_type(value: Any) -> None: raise error.ValidationError('{} is not of type datetime'.format(value)) -ALL_TYPES = [Boolean, Integer, String, StringArray, Timestamp] +ALL_TYPES = [Boolean, Integer, Float, String, StringArray, Timestamp] diff --git a/spanner_orm/tests/model_api_test.py b/spanner_orm/tests/model_api_test.py index 4ce0a52..ff98a2e 100644 --- a/spanner_orm/tests/model_api_test.py +++ b/spanner_orm/tests/model_api_test.py @@ -25,14 +25,14 @@ class ModelApiTest(unittest.TestCase): @mock.patch('spanner_orm.table_apis.find') def test_find_calls_api(self, find): mock_transaction = mock.Mock() - models.UnittestModel.find(mock_transaction, string='string', int_=1) + models.UnittestModel.find(mock_transaction, string='string', int_=1, float_=2.3) find.assert_called_once() (transaction, table, columns, keyset), _ = find.call_args self.assertEqual(transaction, mock_transaction) self.assertEqual(table, models.UnittestModel.table) self.assertEqual(columns, models.UnittestModel.columns) - self.assertEqual(keyset.keys, [[1, 'string']]) + self.assertEqual(keyset.keys, [[1, 2.3, 'string']]) @mock.patch('spanner_orm.table_apis.find') def test_find_result(self, find): @@ -52,7 +52,8 @@ def test_find_multi_calls_api(self, find): mock_transaction = mock.Mock() models.UnittestModel.find_multi(mock_transaction, [{ 'string': 'string', - 'int_': 1 + 'int_': 1, + 'float_': 2.3 }]) find.assert_called_once() @@ -60,7 +61,7 @@ def test_find_multi_calls_api(self, find): self.assertEqual(transaction, mock_transaction) self.assertEqual(table, models.UnittestModel.table) self.assertEqual(columns, models.UnittestModel.columns) - self.assertEqual(keyset.keys, [[1, 'string']]) + self.assertEqual(keyset.keys, [[1, 2.3, 'string']]) @mock.patch('spanner_orm.table_apis.find') def test_find_multi_result(self, find): diff --git a/spanner_orm/tests/model_test.py b/spanner_orm/tests/model_test.py index 804368a..0cddd56 100644 --- a/spanner_orm/tests/model_test.py +++ b/spanner_orm/tests/model_test.py @@ -40,13 +40,14 @@ def test_set_error_on_primary_key(self): with self.assertRaises(AttributeError): test_model.key = 'error' - @parameterized.parameters(('int_2', 'foo'), ('string_2', 5), + @parameterized.parameters(('int_2', 'foo'), ('float_2', 'bar'), ('string_2', 5), ('string_array', 'foo'), ('timestamp', 5)) def test_set_error_on_invalid_type(self, attribute, value): string_array = ['foo', 'bar'] timestamp = datetime.datetime.now(tz=datetime.timezone.utc) test_model = models.UnittestModel({ 'int_': 0, + 'float_': 0, 'string': '', 'string_array': string_array, 'timestamp': timestamp @@ -61,7 +62,7 @@ def test_get_attr(self): self.assertEqual(test_model.value_2, None) def test_id(self): - primary_key = {'string': 'foo', 'int_': 5} + primary_key = {'string': 'foo', 'int_': 5, 'float_': 2.3} all_data = primary_key.copy() all_data.update({ 'timestamp': datetime.datetime.now(tz=datetime.timezone.utc), @@ -83,6 +84,7 @@ def test_object_changes(self): timestamp = datetime.datetime.now(tz=datetime.timezone.utc) test_model = models.UnittestModel({ 'int_': 0, + 'float_': 0, 'string': '', 'string_array': array, 'timestamp': timestamp diff --git a/spanner_orm/tests/models.py b/spanner_orm/tests/models.py index 2a8218a..9aef1ed 100644 --- a/spanner_orm/tests/models.py +++ b/spanner_orm/tests/models.py @@ -73,6 +73,8 @@ class UnittestModel(model.Model): __table__ = 'table' int_ = field.Field(field.Integer, primary_key=True) int_2 = field.Field(field.Integer, nullable=True) + float_ = field.Field(field.Float, primary_key=True) + float_2 = field.Field(field.Float, nullable=True) string = field.Field(field.String, primary_key=True) string_2 = field.Field(field.String, nullable=True) timestamp = field.Field(field.Timestamp) diff --git a/spanner_orm/tests/update_test.py b/spanner_orm/tests/update_test.py index 7b55965..633b0c3 100644 --- a/spanner_orm/tests/update_test.py +++ b/spanner_orm/tests/update_test.py @@ -66,9 +66,10 @@ def test_create_table(self, get_model): test_update.validate() test_model_ddl = ('CREATE TABLE table (int_ INT64 NOT NULL, int_2 INT64,' + ' float_ FLOAT64 NOT NULL, float_2 FLOAT64,' ' string STRING(MAX) NOT NULL, string_2 STRING(MAX),' ' timestamp TIMESTAMP NOT NULL, string_array' - ' ARRAY) PRIMARY KEY (int_, string)') + ' ARRAY) PRIMARY KEY (int_, float_, string)') self.assertEqual(test_update.ddl(), test_model_ddl) @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') From 5fcd327d0ebd9b109545caa98420ab95c4c98e36 Mon Sep 17 00:00:00 2001 From: Daniel Gorelik Date: Fri, 7 Aug 2020 15:01:22 -0400 Subject: [PATCH 004/131] Add infra for running tests with Spanner Emulator, and a simple test using Spanner Emulator --- setup.py | 2 +- spanner_orm/tests/migrations_emulator_test.py | 45 +++++++ .../create_small_test_model.py | 31 +++++ .../tests/spanner_emulator/__init__.py | 0 .../tests/spanner_emulator/emulator.py | 111 +++++++++++++++++ spanner_orm/tests/spanner_emulator/testlib.py | 117 ++++++++++++++++++ 6 files changed, 305 insertions(+), 1 deletion(-) create mode 100644 spanner_orm/tests/migrations_emulator_test.py create mode 100644 spanner_orm/tests/migrations_for_emulator_test/create_small_test_model.py create mode 100644 spanner_orm/tests/spanner_emulator/__init__.py create mode 100644 spanner_orm/tests/spanner_emulator/emulator.py create mode 100644 spanner_orm/tests/spanner_emulator/testlib.py diff --git a/setup.py b/setup.py index 57a6355..566b38b 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ include_package_data=True, python_requires='~=3.7', install_requires=['google-cloud-spanner >= 1.6, <2.0.0dev'], - tests_require=['absl-py'], + tests_require=['absl-py', 'portpicker'], entry_points={ 'console_scripts': ['spanner-orm = spanner_orm.admin.scripts:main'] }) diff --git a/spanner_orm/tests/migrations_emulator_test.py b/spanner_orm/tests/migrations_emulator_test.py new file mode 100644 index 0000000..aec14f3 --- /dev/null +++ b/spanner_orm/tests/migrations_emulator_test.py @@ -0,0 +1,45 @@ +# python3 +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os + + +import spanner_orm +from spanner_orm.tests import models +from spanner_orm.tests.spanner_emulator import testlib + +_EXAMPLE_TIMESTAMP = None + +class MigrationsEmulatorTest(testlib.TestCase): + TEST_MIGRATIONS_DIR = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + 'migrations_for_emulator_test', + ) + + def setUp(self): + super().setUp() + self.run_orm_migrations(self.TEST_MIGRATIONS_DIR) + + def test_basic(self): + test_model = models.SmallTestModel({'key': 'key', 'value_1': 'value'}) + test_model.save() + self.assertEqual( + [x.values for x in models.SmallTestModel.all()], + [{'key': 'key', 'value_1': 'value', 'value_2': None}], + ) + +if __name__ == '__main__': + logging.basicConfig() + unittest.main() diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_small_test_model.py b/spanner_orm/tests/migrations_for_emulator_test/create_small_test_model.py new file mode 100644 index 0000000..e4e8dd8 --- /dev/null +++ b/spanner_orm/tests/migrations_for_emulator_test/create_small_test_model.py @@ -0,0 +1,31 @@ +# Lint as: python3 +"""Creates table with SmallTestModel. + +Migration ID: 'f735d6b706d2' +Created: 2020-07-10 16:24 +""" + +import spanner_orm +from spanner_orm import field + +migration_id = 'f735d6b706d2' +prev_migration_id = None + + +class OriginalSmallTestModelsTable(spanner_orm.model.Model): + """ORM Model with the original schema for the DiabloVerdicts table.""" + + __table__ = 'SmallTestModel' + key = field.Field(field.String, primary_key=True) + value_1 = field.Field(field.String) + value_2 = field.Field(field.String, nullable=True) + + +def upgrade() -> spanner_orm.CreateTable: + """See ORM migrations interface.""" + return spanner_orm.CreateTable(OriginalSmallTestModelsTable) + + +def downgrade() -> spanner_orm.DropTable: + """See ORM migrations interface.""" + return spanner_orm.DropTable(OriginalSmallTestModelsTable.__table__) diff --git a/spanner_orm/tests/spanner_emulator/__init__.py b/spanner_orm/tests/spanner_emulator/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/spanner_orm/tests/spanner_emulator/emulator.py b/spanner_orm/tests/spanner_emulator/emulator.py new file mode 100644 index 0000000..fb1576e --- /dev/null +++ b/spanner_orm/tests/spanner_emulator/emulator.py @@ -0,0 +1,111 @@ +# Lint as: python3 +"""Python test wrapper for the cloud spanner emulator binary.""" + +import os +import subprocess +from typing import Mapping, Optional + +import portpicker + +from google.auth import credentials +from google.cloud.spanner_v1 import client + +# Environment variable with path to Spanner Emulator binary. +_EMULATOR_BINARY_PATH_ENV_VAR = "SPANNER_EMULATOR_BINARY_PATH" +# Environment variable used by the client library to set the correct URL. +_CLIENT_EMULATOR_ENV_VAR = "SPANNER_EMULATOR_HOST" + + +class Emulator: + """Spanner emulator python wrapper. + + Below is an example of how this wrapper can be used in a test class. + + class SpannerTest(googletest.TestCase): + + def setUp(self): + super().setUp() + self._spanner_emulator = emulator.Emulator() + self.addCleanup(self._spanner_emulator.stop) + + def test_something(self): + client = self._spanner_emulator.get_client() + # Create tables, add data, and retrieve it + """ + + def __init__(self, + *, + spanner_emulator_port: Optional[int] = None, + log_emulator_requests: bool = False) -> None: + """Initializer. + + Args: + spanner_emulator_port: The port to start the emulator on. A random unused + port is picked if this value is None. + log_emulator_requests: If true, the emulator subprocess will log each + request and response message. + """ + + self._spanner_emulator_port = spanner_emulator_port + self._log_emulator_requests = log_emulator_requests + + self._process = None + self._host_port = None + + self._start() + self._wait_for_ready() + + def get_client( + self, + project: str = "test-project", + client_options: Optional[Mapping[str, str]] = None) -> client.Client: + """Returns a spanner client for interacting with the emulator. + + Args: + project: Name of the project that the client should point to. + client_options: Any client options that the client should be created with. + """ + return client.Client( + project=project, + credentials=credentials.AnonymousCredentials(), + client_options=client_options) + + def _start(self) -> None: + """Starts the emulator as a subprocess.""" + port = self._spanner_emulator_port or portpicker.pick_unused_port() + self._host_port = f"localhost:{port}" + + # Used by the client library to point to the correct spanner endpoint. + os.environ[_CLIENT_EMULATOR_ENV_VAR] = self._host_port + + try: + emulator_binary_path = os.environ[_EMULATOR_BINARY_PATH_ENV_VAR] + except KeyError as key_error: + raise ValueError( + f'Please set the environment variable {_EMULATOR_BINARY_PATH_ENV_VAR} ' + 'to a binary with the Cloud Spanner Emulator. For more info, see ' + 'https://github.com/GoogleCloudPlatform/cloud-spanner-emulator.' + ) from key_error + + self._process = subprocess.Popen([ + emulator_binary_path, + "--log_requests" if self._log_emulator_requests else "--nolog_requests", + "--host_port", + self._host_port, + ]) + + def _wait_for_ready(self) -> None: + """Waits for the emulator to become ready.""" + emulator_client = self.get_client() + + # This will not return until the emulator is running. + for _ in emulator_client.list_instance_configs(): + return + + def stop(self) -> None: + """If there is an emulator process, stops it and waits for it to stop.""" + if self._process is not None: + self._process.terminate() + self._process.wait() + self._process = None + self._host_port = None diff --git a/spanner_orm/tests/spanner_emulator/testlib.py b/spanner_orm/tests/spanner_emulator/testlib.py new file mode 100644 index 0000000..0b7002e --- /dev/null +++ b/spanner_orm/tests/spanner_emulator/testlib.py @@ -0,0 +1,117 @@ +# Lint as: python3 +"""Superclass and helpers for tests that use the spanner emulator.""" + +import os +import unittest +import uuid + +import spanner_orm + +from google.cloud.spanner_v1 import client +from google.cloud.spanner_v1 import database +from google.cloud.spanner_v1 import instance +from spanner_orm.tests.spanner_emulator import emulator + + +def _make_emulator_spanner_orm_connection( + db: database.Database, inst: instance.Instance, + spanner_client: client.Client) -> spanner_orm.SpannerConnection: + """Returns an spanner_orm.connection to a spanner database. + + Args: + db: database that already exists in spanner + inst: instance that already exists in spanner + spanner_client: client with access to the database and instance provided + """ + # project/project-name -> project-name. + project_name = spanner_client.project_name.split('/')[1] + return spanner_orm.SpannerConnection( + inst.instance_id, + db.database_id, + project=project_name, + credentials=spanner_client.credentials) + + +def _get_instance(spanner_client: client.Client) -> instance.Instance: + """Returns a spanner instance from the client. + + First, checks if there is an existing instance that can be re-used, returning + it if one exists. Otherwise, create a new instance, waits for it to be created + and then returns it. + + Args: + spanner_client: An initialized spanner client. + """ + existing_instances = list(spanner_client.list_instances()) + if existing_instances: + return existing_instances[0] + + # The emulator has one default config. + config = list(spanner_client.list_instance_configs())[0] + inst = spanner_client.instance( + 'qwiklabs-spanner-instance-name', configuration_name=config.name) + inst.create().result() + return inst + + +def _migrate_database_at_connection(connection: spanner_orm.SpannerConnection, + migrations_dir: str) -> None: + """Applies the migrations to the provided connection.""" + spanner_orm.from_connection(connection) + executor = spanner_orm.MigrationExecutor(connection, basedir=migrations_dir) + executor.migrate() + + +def _database_id() -> str: + """Returns a new database ID that's unlikely to conflict with any other.""" + random_string = str(uuid.uuid4()).split('-')[0] + return 'qwiklabs-db-' + random_string + + +class TestCase(unittest.TestCase): + """Sets up a spanner emulator database for each test case. + + Any test class that subclasses this class will have a spanner database + setup for it. That database is empty by default; most subclasses will want to + call run_orm_migrations() in their setUp() method. + + Attributes: + spanner_emulator_client: Client, for use by non-ORM tests. + spanner_emulator_instance: Instance, for use by non-ORM tests. + spanner_emulator_database: Database, for use by non-ORM tests. + """ + + _spanner_emulator: emulator.Emulator + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._spanner_emulator = emulator.Emulator() + + def setUp(self): + super().setUp() + self.spanner_emulator_client = self._spanner_emulator.get_client() + self.spanner_emulator_instance = _get_instance(self.spanner_emulator_client) + self.spanner_emulator_database = self.spanner_emulator_instance.database( + _database_id()) + self.spanner_emulator_database.create().result() + + @classmethod + def tearDownClass(cls): + cls._spanner_emulator.stop() + super().tearDownClass() + + def run_orm_migrations(self, migrations_folder: str) -> None: + """Runs ORM migrations in the given directory and connects the ORM.""" + _migrate_database_at_connection( + _make_emulator_spanner_orm_connection(self.spanner_emulator_database, + self.spanner_emulator_instance, + self.spanner_emulator_client), + migrations_folder) + # spanner_orm closes the connection to Spanner after migrating so we need to + # reconnect before making other Spanner calls. + spanner_orm.from_connection( + _make_emulator_spanner_orm_connection(self.spanner_emulator_database, + self.spanner_emulator_instance, + self.spanner_emulator_client)) + From b4ac27f8edb1ef2ac4f05d151aa5587018eea76c Mon Sep 17 00:00:00 2001 From: Daniel Gorelik Date: Fri, 7 Aug 2020 17:36:32 -0400 Subject: [PATCH 005/131] move non-test files out of tests/ --- .../spanner_emulator => testlib}/__init__.py | 0 .../spanner_emulator/emulator.py | 0 .../spanner_emulator/testlib.py | 2 +- spanner_orm/tests/migrations_emulator_test.py | 7 +++---- .../create_small_test_model.py | 21 +++++++++++++++---- 5 files changed, 21 insertions(+), 9 deletions(-) rename spanner_orm/{tests/spanner_emulator => testlib}/__init__.py (100%) rename spanner_orm/{tests => testlib}/spanner_emulator/emulator.py (100%) rename spanner_orm/{tests => testlib}/spanner_emulator/testlib.py (98%) diff --git a/spanner_orm/tests/spanner_emulator/__init__.py b/spanner_orm/testlib/__init__.py similarity index 100% rename from spanner_orm/tests/spanner_emulator/__init__.py rename to spanner_orm/testlib/__init__.py diff --git a/spanner_orm/tests/spanner_emulator/emulator.py b/spanner_orm/testlib/spanner_emulator/emulator.py similarity index 100% rename from spanner_orm/tests/spanner_emulator/emulator.py rename to spanner_orm/testlib/spanner_emulator/emulator.py diff --git a/spanner_orm/tests/spanner_emulator/testlib.py b/spanner_orm/testlib/spanner_emulator/testlib.py similarity index 98% rename from spanner_orm/tests/spanner_emulator/testlib.py rename to spanner_orm/testlib/spanner_emulator/testlib.py index 0b7002e..840f683 100644 --- a/spanner_orm/tests/spanner_emulator/testlib.py +++ b/spanner_orm/testlib/spanner_emulator/testlib.py @@ -10,7 +10,7 @@ from google.cloud.spanner_v1 import client from google.cloud.spanner_v1 import database from google.cloud.spanner_v1 import instance -from spanner_orm.tests.spanner_emulator import emulator +from spanner_orm.testlib.spanner_emulator import emulator def _make_emulator_spanner_orm_connection( diff --git a/spanner_orm/tests/migrations_emulator_test.py b/spanner_orm/tests/migrations_emulator_test.py index aec14f3..3e4f8a1 100644 --- a/spanner_orm/tests/migrations_emulator_test.py +++ b/spanner_orm/tests/migrations_emulator_test.py @@ -14,15 +14,14 @@ # limitations under the License. import logging import os +import unittest import spanner_orm from spanner_orm.tests import models -from spanner_orm.tests.spanner_emulator import testlib +from spanner_orm.testlib.spanner_emulator import testlib as spanner_emulator_testlib -_EXAMPLE_TIMESTAMP = None - -class MigrationsEmulatorTest(testlib.TestCase): +class MigrationsEmulatorTest(spanner_emulator_testlib.TestCase): TEST_MIGRATIONS_DIR = os.path.join( os.path.dirname(os.path.realpath(__file__)), 'migrations_for_emulator_test', diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_small_test_model.py b/spanner_orm/tests/migrations_for_emulator_test/create_small_test_model.py index e4e8dd8..0902bdc 100644 --- a/spanner_orm/tests/migrations_for_emulator_test/create_small_test_model.py +++ b/spanner_orm/tests/migrations_for_emulator_test/create_small_test_model.py @@ -1,4 +1,17 @@ # Lint as: python3 +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Creates table with SmallTestModel. Migration ID: 'f735d6b706d2' @@ -12,8 +25,8 @@ prev_migration_id = None -class OriginalSmallTestModelsTable(spanner_orm.model.Model): - """ORM Model with the original schema for the DiabloVerdicts table.""" +class OriginalSmallTestModelTable(spanner_orm.model.Model): + """ORM Model with the original schema for the SmallTestModel table.""" __table__ = 'SmallTestModel' key = field.Field(field.String, primary_key=True) @@ -23,9 +36,9 @@ class OriginalSmallTestModelsTable(spanner_orm.model.Model): def upgrade() -> spanner_orm.CreateTable: """See ORM migrations interface.""" - return spanner_orm.CreateTable(OriginalSmallTestModelsTable) + return spanner_orm.CreateTable(OriginalSmallTestModelTable) def downgrade() -> spanner_orm.DropTable: """See ORM migrations interface.""" - return spanner_orm.DropTable(OriginalSmallTestModelsTable.__table__) + return spanner_orm.DropTable(OriginalSmallTestModelTable.__table__) From 8fe9dd06a263d6a6db27cbdbb1a75e366896e4b2 Mon Sep 17 00:00:00 2001 From: Daniel Gorelik Date: Fri, 7 Aug 2020 17:38:31 -0400 Subject: [PATCH 006/131] Add licenses --- spanner_orm/testlib/spanner_emulator/emulator.py | 13 +++++++++++++ spanner_orm/testlib/spanner_emulator/testlib.py | 13 +++++++++++++ 2 files changed, 26 insertions(+) diff --git a/spanner_orm/testlib/spanner_emulator/emulator.py b/spanner_orm/testlib/spanner_emulator/emulator.py index fb1576e..cd29669 100644 --- a/spanner_orm/testlib/spanner_emulator/emulator.py +++ b/spanner_orm/testlib/spanner_emulator/emulator.py @@ -1,4 +1,17 @@ # Lint as: python3 +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Python test wrapper for the cloud spanner emulator binary.""" import os diff --git a/spanner_orm/testlib/spanner_emulator/testlib.py b/spanner_orm/testlib/spanner_emulator/testlib.py index 840f683..4b14850 100644 --- a/spanner_orm/testlib/spanner_emulator/testlib.py +++ b/spanner_orm/testlib/spanner_emulator/testlib.py @@ -1,4 +1,17 @@ # Lint as: python3 +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Superclass and helpers for tests that use the spanner emulator.""" import os From 19e8e3d9986ef8fd6bce00931d9344bf58a59dc9 Mon Sep 17 00:00:00 2001 From: Daniel Gorelik Date: Mon, 10 Aug 2020 19:00:30 -0400 Subject: [PATCH 007/131] fix some names used in tests --- spanner_orm/testlib/spanner_emulator/testlib.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spanner_orm/testlib/spanner_emulator/testlib.py b/spanner_orm/testlib/spanner_emulator/testlib.py index 4b14850..aac4049 100644 --- a/spanner_orm/testlib/spanner_emulator/testlib.py +++ b/spanner_orm/testlib/spanner_emulator/testlib.py @@ -62,7 +62,7 @@ def _get_instance(spanner_client: client.Client) -> instance.Instance: # The emulator has one default config. config = list(spanner_client.list_instance_configs())[0] inst = spanner_client.instance( - 'qwiklabs-spanner-instance-name', configuration_name=config.name) + 'spanner-instance-name', configuration_name=config.name) inst.create().result() return inst @@ -78,7 +78,7 @@ def _migrate_database_at_connection(connection: spanner_orm.SpannerConnection, def _database_id() -> str: """Returns a new database ID that's unlikely to conflict with any other.""" random_string = str(uuid.uuid4()).split('-')[0] - return 'qwiklabs-db-' + random_string + return 'spanner-db-' + random_string class TestCase(unittest.TestCase): From 3f72113b16ee60fa053f159dd5afab70889b3eca Mon Sep 17 00:00:00 2001 From: Daniel Gorelik Date: Wed, 12 Aug 2020 11:18:36 -0400 Subject: [PATCH 008/131] First pass at adding support for foreign keys --- spanner_orm/admin/update.py | 15 +++- spanner_orm/foreign_key_relationship.py | 85 +++++++++++++++++++ spanner_orm/metadata.py | 11 +++ spanner_orm/model.py | 16 ++++ .../create_foreign_key_test_model.py | 44 ++++++++++ spanner_orm/tests/models.py | 13 +++ spanner_orm/tests/update_test.py | 17 ++++ 7 files changed, 199 insertions(+), 2 deletions(-) create mode 100644 spanner_orm/foreign_key_relationship.py create mode 100644 spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py diff --git a/spanner_orm/admin/update.py b/spanner_orm/admin/update.py index 4277e10..d804b33 100644 --- a/spanner_orm/admin/update.py +++ b/spanner_orm/admin/update.py @@ -49,13 +49,24 @@ def __init__(self, model_: Type[model.Model]): self._model = model_ def ddl(self) -> str: - fields = [ + key_fields = [ '{} {}'.format(name, field.ddl()) for name, field in self._model.fields.items() ] + key_fields_ddl = ', '.join(key_fields) + if self._model.foreign_key_relations: + fk = list(self._model.foreign_key_relations.values())[0] + for referencing_table_col, referenced_table_col in fk.constraints.items(): + key_fields_ddl += ( + ', FOREIGN KEY ({referencing_table_col}) REFERENCES' + ' {parent} ({referenced_table_col})').format( + parent=fk.destination, + referencing_table_col=referencing_table_col, + referenced_table_col=referenced_table_col, + ) index_ddl = 'PRIMARY KEY ({})'.format(', '.join(self._model.primary_keys)) statement = 'CREATE TABLE {} ({}) {}'.format(self._model.table, - ', '.join(fields), index_ddl) + key_fields_ddl, index_ddl) if self._model.interleaved: statement += ', INTERLEAVE IN PARENT {parent} ON DELETE CASCADE'.format( diff --git a/spanner_orm/foreign_key_relationship.py b/spanner_orm/foreign_key_relationship.py new file mode 100644 index 0000000..f73249a --- /dev/null +++ b/spanner_orm/foreign_key_relationship.py @@ -0,0 +1,85 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Helps define a foreign key relationship between two models.""" + +from typing import Any, List, Mapping, Type, Union + +import dataclasses +from spanner_orm import error +from spanner_orm import registry + + +@dataclasses.dataclass +class RelationshipConstraint: + destination_class: Type[Any] + destination_column: str + origin_class: Type[Any] + origin_column: str + + +class ForeignKeyRelationship(object): + """Helps define a foreign key relationship between two models.""" + + def __init__(self, + referenced_table_name: str, + constraints: Mapping[str, str]): + """Creates a ForeignKeyRelationship. + + Args: + referenced_table_name: Destination model class or fully qualified class + name of the destination model. + constraints: Dictionary where the keys are names of columns from the + referencing table and the values are the names of the columns in the + referenced table. + """ + self.origin = None + self.name = None + self._referenced_table_name = referenced_table_name + self._constraints = constraints + + @property + def constraints(self) -> List[RelationshipConstraint]: + return self._constraints + + @property + def destination(self) -> Type[Any]: + return self._referenced_table_name + if not self._destination: + self._destination = registry.model_registry().get( + self._referenced_table_name) + return self._destination + + @property + def single(self) -> bool: + return self._single + + def _parse_constraints(self) -> List[RelationshipConstraint]: + """Validates the dictionary of constraints and turns it into Conditions.""" + constraints = [] + for origin_column, destination_column in self._constraints.items(): + if origin_column not in self.origin.fields: + raise error.ValidationError( + 'Origin column must be present in origin model') + + if destination_column not in self.destination.fields: + raise error.ValidationError( + 'Destination column must be present in destination model') + + # TODO(dbrandao): remove when pytype #234 is fixed + constraints.append( + RelationshipConstraint(self.destination, destination_column, + self.origin, origin_column)) # type: ignore + + return constraints diff --git a/spanner_orm/metadata.py b/spanner_orm/metadata.py index 8c7c5e2..1fd10b1 100644 --- a/spanner_orm/metadata.py +++ b/spanner_orm/metadata.py @@ -32,6 +32,7 @@ from spanner_orm import error from spanner_orm import field +from spanner_orm import foreign_key_relationship from spanner_orm import index from spanner_orm import registry from spanner_orm import relationship @@ -44,6 +45,7 @@ def __init__(self, table: Optional[str] = None, fields: Optional[Dict[str, field.Field]] = None, relations: Optional[Dict[str, relationship.Relationship]] = None, + foreign_key_relations: Optional[Dict[str, foreign_key_relationship.ForeignKeyRelationship]] = None, indexes: Optional[Dict[str, index.Index]] = None, interleaved: Optional[str] = None, model_class: Optional[Type[Any]] = None): @@ -55,6 +57,7 @@ def __init__(self, self.model_class = model_class self.primary_keys = [] self.relations = dict(relations or {}) + self.foreign_key_relations = dict(foreign_key_relations or {}) self.table = table or '' def finalize(self) -> None: @@ -101,6 +104,14 @@ def add_relation(self, name: str, new_relation.name = name self.relations[name] = new_relation + def add_foreign_key_relation( + self, + name: str, + new_relation: foreign_key_relationship.ForeignKeyRelationship, + ) -> None: + new_relation.name = name + self.foreign_key_relations[name] = new_relation + def add_index(self, name: str, new_index: index.Index) -> None: new_index.name = name self.indexes[name] = new_index diff --git a/spanner_orm/model.py b/spanner_orm/model.py index 06060cb..c4dc130 100644 --- a/spanner_orm/model.py +++ b/spanner_orm/model.py @@ -21,6 +21,7 @@ from spanner_orm import api from spanner_orm import condition from spanner_orm import error +from spanner_orm import foreign_key_relationship from spanner_orm import field from spanner_orm import index from spanner_orm import metadata @@ -52,12 +53,19 @@ def __new__(mcs, name: str, bases: Any, attrs: Dict[str, Any], **kwargs: Any): model_metadata.table = value elif key == '__interleaved__': model_metadata.interleaved = value + elif key == '__foreign_key__': + model_metadata.foreign_key = value if isinstance(value, field.Field): model_metadata.add_field(key, value) elif isinstance(value, index.Index): model_metadata.add_index(key, value) elif isinstance(value, relationship.Relationship): model_metadata.add_relation(key, value) + elif isinstance( + value, + foreign_key_relationship.ForeignKeyRelationship, + ): + model_metadata.add_foreign_key_relation(key, value) else: non_model_attrs[key] = value @@ -112,6 +120,14 @@ def primary_keys(cls) -> List[str]: def relations(cls) -> Dict[str, relationship.Relationship]: return cls.meta.relations + @property + def foreign_key_relations(cls) -> Dict[str, foreign_key_relationship.ForeignKeyRelationship]: + return cls.meta.foreign_key_relations + #if cls.meta.foreign_key: + # return registry.model_registry().get(cls.meta.foreign_key) + #return None + + @property def fields(cls) -> Dict[str, field.Field]: return cls.meta.fields diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py b/spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py new file mode 100644 index 0000000..4c36500 --- /dev/null +++ b/spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py @@ -0,0 +1,44 @@ +# Lint as: python3 +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Creates table with SmallTestModel. + +Migration ID: 'f735d6b706d3' +Created: 2020-07-10 16:24 +""" + +import spanner_orm +from spanner_orm import field + +migration_id = 'f735d6b706d3' +prev_migration_id = 'f735d6b706d2' + + +class OriginalForeignKeyTestModelTable(spanner_orm.model.Model): + """ORM Model with the original schema for the ForeignKeyTestModel table.""" + + __table__ = 'ForeignKeyTestModel' + __foreign_key__ = 'SmallTestModel' + key = field.Field(field.String, primary_key=True) + child_key = field.Field(field.String, primary_key=True) + + +def upgrade() -> spanner_orm.CreateTable: + """See ORM migrations interface.""" + return spanner_orm.CreateTable(OriginalForeignKeyTestModelTable) + + +def downgrade() -> spanner_orm.DropTable: + """See ORM migrations interface.""" + return spanner_orm.DropTable(OriginalForeignKeyTestModelTable.__table__) diff --git a/spanner_orm/tests/models.py b/spanner_orm/tests/models.py index 9aef1ed..fde210a 100644 --- a/spanner_orm/tests/models.py +++ b/spanner_orm/tests/models.py @@ -15,6 +15,7 @@ """Models used by unit tests.""" from spanner_orm import field +from spanner_orm import foreign_key_relationship from spanner_orm import index from spanner_orm import model from spanner_orm import relationship @@ -61,6 +62,18 @@ class RelationshipTestModel(model.Model): parents = relationship.Relationship('spanner_orm.tests.models.SmallTestModel', {'parent_key': 'key'}) +class ForeignKeyTestModel(model.Model): + """Model class for testing foreign keys.""" + + __table__ = 'ForeignKeyTestModel' + # __foreign_key__ = 'SmallTestModel' + referencing_key = field.Field(field.String, primary_key=True) + value = field.Field(field.String) + foreign_key_relationship = foreign_key_relationship.ForeignKeyRelationship( + 'SmallTestModel', {'referencing_key': 'key'}) + # single=True) + #parents = relationship.Relationship('spanner_orm.tests.models.SmallTestModel', + # {'parent_key': 'key'}) class InheritanceTestModel(SmallTestModel): """Model class used for testing model inheritance.""" diff --git a/spanner_orm/tests/update_test.py b/spanner_orm/tests/update_test.py index 633b0c3..dca7e41 100644 --- a/spanner_orm/tests/update_test.py +++ b/spanner_orm/tests/update_test.py @@ -86,6 +86,23 @@ def test_create_table_interleaved(self, get_model): 'INTERLEAVE IN PARENT SmallTestModel ON DELETE CASCADE') self.assertEqual(test_update.ddl(), test_model_ddl) + @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') + def test_create_table_foreign_key(self, get_model): + self.maxDiff = 1000 + + get_model.return_value = None + new_model = models.ForeignKeyTestModel + test_update = update.CreateTable(new_model) + test_update.validate() + + test_model_ddl = ( + 'CREATE TABLE ForeignKeyTestModel (' + 'referencing_key STRING(MAX) NOT NULL, ' + 'value STRING(MAX) NOT NULL, ' + 'FOREIGN KEY (referencing_key) REFERENCES SmallTestModel (key)) ' + 'PRIMARY KEY (referencing_key)') + self.assertEqual(test_update.ddl(), test_model_ddl) + @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') def test_create_table_error_on_existing_table(self, get_model): get_model.return_value = models.SmallTestModel From 33d4bf5813f8d72b3a41af31b94c3116cdb6140f Mon Sep 17 00:00:00 2001 From: Daniel Gorelik Date: Wed, 12 Aug 2020 13:35:00 -0400 Subject: [PATCH 009/131] Add support for multiple foreign keys in one table --- setup.py | 2 +- spanner_orm/admin/update.py | 7 ++--- spanner_orm/foreign_key_relationship.py | 21 ++++++------- spanner_orm/tests/migrations_emulator_test.py | 31 +++++++++++++++++++ .../create_foreign_key_test_model.py | 20 ++++++++---- spanner_orm/tests/models.py | 16 +++++----- spanner_orm/tests/update_test.py | 12 ++++--- 7 files changed, 75 insertions(+), 34 deletions(-) diff --git a/setup.py b/setup.py index 566b38b..011756a 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ include_package_data=True, python_requires='~=3.7', install_requires=['google-cloud-spanner >= 1.6, <2.0.0dev'], - tests_require=['absl-py', 'portpicker'], + tests_require=['absl-py', 'google-api-core', 'portpicker'], entry_points={ 'console_scripts': ['spanner-orm = spanner_orm.admin.scripts:main'] }) diff --git a/spanner_orm/admin/update.py b/spanner_orm/admin/update.py index d804b33..2fbaa4a 100644 --- a/spanner_orm/admin/update.py +++ b/spanner_orm/admin/update.py @@ -54,13 +54,12 @@ def ddl(self) -> str: for name, field in self._model.fields.items() ] key_fields_ddl = ', '.join(key_fields) - if self._model.foreign_key_relations: - fk = list(self._model.foreign_key_relations.values())[0] - for referencing_table_col, referenced_table_col in fk.constraints.items(): + for relation in self._model.foreign_key_relations.values(): + for referencing_table_col, referenced_table_col in relation.constraints.items(): key_fields_ddl += ( ', FOREIGN KEY ({referencing_table_col}) REFERENCES' ' {parent} ({referenced_table_col})').format( - parent=fk.destination, + parent=relation.destination, referencing_table_col=referencing_table_col, referenced_table_col=referenced_table_col, ) diff --git a/spanner_orm/foreign_key_relationship.py b/spanner_orm/foreign_key_relationship.py index f73249a..2561f12 100644 --- a/spanner_orm/foreign_key_relationship.py +++ b/spanner_orm/foreign_key_relationship.py @@ -22,11 +22,11 @@ @dataclasses.dataclass -class RelationshipConstraint: - destination_class: Type[Any] - destination_column: str - origin_class: Type[Any] - origin_column: str +class ForeignKeyRelationshipConstraint: + referencing_column: str + referenced_columns: str + referenced_table_name: str + class ForeignKeyRelationship(object): @@ -43,6 +43,7 @@ def __init__(self, constraints: Dictionary where the keys are names of columns from the referencing table and the values are the names of the columns in the referenced table. + # TODO(dgorelik): Allow constraints to have custom names. """ self.origin = None self.name = None @@ -50,22 +51,18 @@ def __init__(self, self._constraints = constraints @property - def constraints(self) -> List[RelationshipConstraint]: + def constraints(self) -> List[ForeignKeyRelationshipConstraint]: return self._constraints @property def destination(self) -> Type[Any]: - return self._referenced_table_name + return registry.model_registry().get(self._referenced_table_name).table if not self._destination: self._destination = registry.model_registry().get( self._referenced_table_name) return self._destination - @property - def single(self) -> bool: - return self._single - - def _parse_constraints(self) -> List[RelationshipConstraint]: + def _parse_constraints(self) -> List[ForeignKeyRelationshipConstraint]: """Validates the dictionary of constraints and turns it into Conditions.""" constraints = [] for origin_column, destination_column in self._constraints.items(): diff --git a/spanner_orm/tests/migrations_emulator_test.py b/spanner_orm/tests/migrations_emulator_test.py index 3e4f8a1..8c9efb7 100644 --- a/spanner_orm/tests/migrations_emulator_test.py +++ b/spanner_orm/tests/migrations_emulator_test.py @@ -12,11 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import datetime import logging import os import unittest +from google.api_core import exceptions as google_api_exceptions import spanner_orm from spanner_orm.tests import models from spanner_orm.testlib.spanner_emulator import testlib as spanner_emulator_testlib @@ -39,6 +41,35 @@ def test_basic(self): [{'key': 'key', 'value_1': 'value', 'value_2': None}], ) + def test_error_with_missing_referencing_key(self): + with self.assertRaisesRegex( + google_api_exceptions.FailedPrecondition, + 'Cannot find referenced key', + ): + models.ForeignKeyTestModel({ + 'referencing_key_1': 'key', + 'referencing_key_2': 'key', + 'referencing_key_3': 42, + 'value': 'value' + }).save() + + def test_key(self): + test_model = models.SmallTestModel({'key': 'key', 'value_1': 'value'}) + test_model.save() + models.UnittestModel( + {'string': 'string', + 'int_': 42, + 'float_': 4.2, + 'timestamp': datetime.datetime.now(tz=datetime.timezone.utc), + }).save() + test_model_2 = models.ForeignKeyTestModel({ + 'referencing_key_1': 'key', + 'referencing_key_2': 'string', + 'referencing_key_3': 42, + 'value': 'value' + }) + test_model_2.save() + if __name__ == '__main__': logging.basicConfig() unittest.main() diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py b/spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py index 4c36500..a88811d 100644 --- a/spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py +++ b/spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Creates table with SmallTestModel. +"""Creates table with ForeignKeyTestModel. Migration ID: 'f735d6b706d3' Created: 2020-07-10 16:24 @@ -20,18 +20,26 @@ import spanner_orm from spanner_orm import field +from spanner_orm import foreign_key_relationship -migration_id = 'f735d6b706d3' -prev_migration_id = 'f735d6b706d2' +migration_id = 'f735d6b706d4' +prev_migration_id = 'f735d6b706d3' class OriginalForeignKeyTestModelTable(spanner_orm.model.Model): """ORM Model with the original schema for the ForeignKeyTestModel table.""" __table__ = 'ForeignKeyTestModel' - __foreign_key__ = 'SmallTestModel' - key = field.Field(field.String, primary_key=True) - child_key = field.Field(field.String, primary_key=True) + referencing_key_1 = field.Field(field.String, primary_key=True) + referencing_key_2 = field.Field(field.String, primary_key=True) + referencing_key_3 = field.Field(field.Integer, primary_key=True) + value = field.Field(field.String) + foreign_key_1 = foreign_key_relationship.ForeignKeyRelationship( + 'SmallTestModel', {'referencing_key_1': 'key'}) + foreign_key_2 = foreign_key_relationship.ForeignKeyRelationship( + 'UnittestModel', + {'referencing_key_2': 'string', 'referencing_key_3': 'int_'}, + ) def upgrade() -> spanner_orm.CreateTable: diff --git a/spanner_orm/tests/models.py b/spanner_orm/tests/models.py index fde210a..e82c071 100644 --- a/spanner_orm/tests/models.py +++ b/spanner_orm/tests/models.py @@ -66,14 +66,16 @@ class ForeignKeyTestModel(model.Model): """Model class for testing foreign keys.""" __table__ = 'ForeignKeyTestModel' - # __foreign_key__ = 'SmallTestModel' - referencing_key = field.Field(field.String, primary_key=True) + referencing_key_1 = field.Field(field.String, primary_key=True) + referencing_key_2 = field.Field(field.String, primary_key=True) + referencing_key_3 = field.Field(field.Integer, primary_key=True) value = field.Field(field.String) - foreign_key_relationship = foreign_key_relationship.ForeignKeyRelationship( - 'SmallTestModel', {'referencing_key': 'key'}) - # single=True) - #parents = relationship.Relationship('spanner_orm.tests.models.SmallTestModel', - # {'parent_key': 'key'}) + foreign_key_1 = foreign_key_relationship.ForeignKeyRelationship( + 'SmallTestModel', {'referencing_key_1': 'key'}) + foreign_key_2 = foreign_key_relationship.ForeignKeyRelationship( + 'UnittestModel', + {'referencing_key_2': 'string', 'referencing_key_3': 'int_'}, + ) class InheritanceTestModel(SmallTestModel): """Model class used for testing model inheritance.""" diff --git a/spanner_orm/tests/update_test.py b/spanner_orm/tests/update_test.py index dca7e41..9b84b31 100644 --- a/spanner_orm/tests/update_test.py +++ b/spanner_orm/tests/update_test.py @@ -88,7 +88,7 @@ def test_create_table_interleaved(self, get_model): @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') def test_create_table_foreign_key(self, get_model): - self.maxDiff = 1000 + self.maxDiff = 2000 get_model.return_value = None new_model = models.ForeignKeyTestModel @@ -97,10 +97,14 @@ def test_create_table_foreign_key(self, get_model): test_model_ddl = ( 'CREATE TABLE ForeignKeyTestModel (' - 'referencing_key STRING(MAX) NOT NULL, ' + 'referencing_key_1 STRING(MAX) NOT NULL, ' + 'referencing_key_2 STRING(MAX) NOT NULL, ' + 'referencing_key_3 INT64 NOT NULL, ' 'value STRING(MAX) NOT NULL, ' - 'FOREIGN KEY (referencing_key) REFERENCES SmallTestModel (key)) ' - 'PRIMARY KEY (referencing_key)') + 'FOREIGN KEY (referencing_key_1) REFERENCES SmallTestModel (key), ' + 'FOREIGN KEY (referencing_key_2) REFERENCES table (string), ' + 'FOREIGN KEY (referencing_key_3) REFERENCES table (int_)) ' + 'PRIMARY KEY (referencing_key_1, referencing_key_2, referencing_key_3)') self.assertEqual(test_update.ddl(), test_model_ddl) @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') From c87ab42e1f1a33ca69238b01b2721fcf255dc885 Mon Sep 17 00:00:00 2001 From: Daniel Gorelik Date: Wed, 12 Aug 2020 14:23:06 -0400 Subject: [PATCH 010/131] Use a dataclass to encapsulate foreign key constraints --- spanner_orm/admin/update.py | 14 ++++----- spanner_orm/foreign_key_relationship.py | 42 +++++++++---------------- spanner_orm/metadata.py | 6 +++- spanner_orm/model.py | 8 ++--- 4 files changed, 28 insertions(+), 42 deletions(-) diff --git a/spanner_orm/admin/update.py b/spanner_orm/admin/update.py index 2fbaa4a..a77a101 100644 --- a/spanner_orm/admin/update.py +++ b/spanner_orm/admin/update.py @@ -55,14 +55,14 @@ def ddl(self) -> str: ] key_fields_ddl = ', '.join(key_fields) for relation in self._model.foreign_key_relations.values(): - for referencing_table_col, referenced_table_col in relation.constraints.items(): + for constraint in relation.constraints: key_fields_ddl += ( - ', FOREIGN KEY ({referencing_table_col}) REFERENCES' - ' {parent} ({referenced_table_col})').format( - parent=relation.destination, - referencing_table_col=referencing_table_col, - referenced_table_col=referenced_table_col, - ) + ', FOREIGN KEY ({referencing_column}) REFERENCES' + ' {referenced_table} ({referenced_column})').format( + referencing_column=constraint.referencing_column, + referenced_table=constraint.referenced_table_name, + referenced_column=constraint.referenced_column, + ) index_ddl = 'PRIMARY KEY ({})'.format(', '.join(self._model.primary_keys)) statement = 'CREATE TABLE {} ({}) {}'.format(self._model.table, key_fields_ddl, index_ddl) diff --git a/spanner_orm/foreign_key_relationship.py b/spanner_orm/foreign_key_relationship.py index 2561f12..6dc7cb4 100644 --- a/spanner_orm/foreign_key_relationship.py +++ b/spanner_orm/foreign_key_relationship.py @@ -14,19 +14,17 @@ # limitations under the License. """Helps define a foreign key relationship between two models.""" -from typing import Any, List, Mapping, Type, Union +from typing import List, Mapping import dataclasses -from spanner_orm import error from spanner_orm import registry @dataclasses.dataclass class ForeignKeyRelationshipConstraint: referencing_column: str - referenced_columns: str + referenced_column: str referenced_table_name: str - class ForeignKeyRelationship(object): @@ -38,8 +36,7 @@ def __init__(self, """Creates a ForeignKeyRelationship. Args: - referenced_table_name: Destination model class or fully qualified class - name of the destination model. + referenced_table_name: Name of the table which the foreign key references. constraints: Dictionary where the keys are names of columns from the referencing table and the values are the names of the columns in the referenced table. @@ -52,31 +49,20 @@ def __init__(self, @property def constraints(self) -> List[ForeignKeyRelationshipConstraint]: - return self._constraints - - @property - def destination(self) -> Type[Any]: - return registry.model_registry().get(self._referenced_table_name).table - if not self._destination: - self._destination = registry.model_registry().get( - self._referenced_table_name) - return self._destination + return self._parse_constraints() def _parse_constraints(self) -> List[ForeignKeyRelationshipConstraint]: - """Validates the dictionary of constraints and turns it into Conditions.""" + """Returns a list of Constraints for the relationship.""" constraints = [] - for origin_column, destination_column in self._constraints.items(): - if origin_column not in self.origin.fields: - raise error.ValidationError( - 'Origin column must be present in origin model') - - if destination_column not in self.destination.fields: - raise error.ValidationError( - 'Destination column must be present in destination model') - - # TODO(dbrandao): remove when pytype #234 is fixed + referenced_table = registry.model_registry().get( + self._referenced_table_name) + for referencing_column, referenced_column in self._constraints.items(): constraints.append( - RelationshipConstraint(self.destination, destination_column, - self.origin, origin_column)) # type: ignore + ForeignKeyRelationshipConstraint( + referencing_column, + referenced_column, + referenced_table.table, + ) + ) return constraints diff --git a/spanner_orm/metadata.py b/spanner_orm/metadata.py index 1fd10b1..251dfc2 100644 --- a/spanner_orm/metadata.py +++ b/spanner_orm/metadata.py @@ -45,7 +45,11 @@ def __init__(self, table: Optional[str] = None, fields: Optional[Dict[str, field.Field]] = None, relations: Optional[Dict[str, relationship.Relationship]] = None, - foreign_key_relations: Optional[Dict[str, foreign_key_relationship.ForeignKeyRelationship]] = None, + foreign_key_relations: Optional[ + Dict[ + str, + foreign_key_relationship.ForeignKeyRelationship] + ] = None, indexes: Optional[Dict[str, index.Index]] = None, interleaved: Optional[str] = None, model_class: Optional[Type[Any]] = None): diff --git a/spanner_orm/model.py b/spanner_orm/model.py index c4dc130..c21999d 100644 --- a/spanner_orm/model.py +++ b/spanner_orm/model.py @@ -53,8 +53,6 @@ def __new__(mcs, name: str, bases: Any, attrs: Dict[str, Any], **kwargs: Any): model_metadata.table = value elif key == '__interleaved__': model_metadata.interleaved = value - elif key == '__foreign_key__': - model_metadata.foreign_key = value if isinstance(value, field.Field): model_metadata.add_field(key, value) elif isinstance(value, index.Index): @@ -121,11 +119,9 @@ def relations(cls) -> Dict[str, relationship.Relationship]: return cls.meta.relations @property - def foreign_key_relations(cls) -> Dict[str, foreign_key_relationship.ForeignKeyRelationship]: + def foreign_key_relations( + cls) -> Dict[str, foreign_key_relationship.ForeignKeyRelationship]: return cls.meta.foreign_key_relations - #if cls.meta.foreign_key: - # return registry.model_registry().get(cls.meta.foreign_key) - #return None @property From a95d8dc9c779f0aeba6158e1f799582f1924e426 Mon Sep 17 00:00:00 2001 From: Daniel Gorelik Date: Wed, 12 Aug 2020 14:27:42 -0400 Subject: [PATCH 011/131] Minor style clean up --- spanner_orm/tests/migrations_emulator_test.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/spanner_orm/tests/migrations_emulator_test.py b/spanner_orm/tests/migrations_emulator_test.py index 8c9efb7..75d03ce 100644 --- a/spanner_orm/tests/migrations_emulator_test.py +++ b/spanner_orm/tests/migrations_emulator_test.py @@ -17,11 +17,11 @@ import os import unittest +from spanner_orm.testlib.spanner_emulator import testlib as spanner_emulator_testlib +from spanner_orm.tests import models from google.api_core import exceptions as google_api_exceptions -import spanner_orm -from spanner_orm.tests import models -from spanner_orm.testlib.spanner_emulator import testlib as spanner_emulator_testlib + class MigrationsEmulatorTest(spanner_emulator_testlib.TestCase): TEST_MIGRATIONS_DIR = os.path.join( @@ -34,8 +34,7 @@ def setUp(self): self.run_orm_migrations(self.TEST_MIGRATIONS_DIR) def test_basic(self): - test_model = models.SmallTestModel({'key': 'key', 'value_1': 'value'}) - test_model.save() + models.SmallTestModel({'key': 'key', 'value_1': 'value'}).save() self.assertEqual( [x.values for x in models.SmallTestModel.all()], [{'key': 'key', 'value_1': 'value', 'value_2': None}], @@ -54,21 +53,19 @@ def test_error_with_missing_referencing_key(self): }).save() def test_key(self): - test_model = models.SmallTestModel({'key': 'key', 'value_1': 'value'}) - test_model.save() + models.SmallTestModel({'key': 'key', 'value_1': 'value'}).save() models.UnittestModel( {'string': 'string', 'int_': 42, 'float_': 4.2, 'timestamp': datetime.datetime.now(tz=datetime.timezone.utc), }).save() - test_model_2 = models.ForeignKeyTestModel({ + models.ForeignKeyTestModel({ 'referencing_key_1': 'key', 'referencing_key_2': 'string', 'referencing_key_3': 42, 'value': 'value' - }) - test_model_2.save() + }).save() if __name__ == '__main__': logging.basicConfig() From 45d480cba2c47b69ae72da74e69e9e283a40e08d Mon Sep 17 00:00:00 2001 From: Daniel Gorelik Date: Thu, 13 Aug 2020 12:39:28 -0400 Subject: [PATCH 012/131] Revert "Add support for setting foreign keys" --- setup.py | 2 +- spanner_orm/admin/update.py | 14 +--- spanner_orm/foreign_key_relationship.py | 68 ------------------- spanner_orm/metadata.py | 15 ---- spanner_orm/model.py | 12 ---- spanner_orm/tests/migrations_emulator_test.py | 38 ++--------- .../create_foreign_key_test_model.py | 52 -------------- spanner_orm/tests/models.py | 15 ---- spanner_orm/tests/update_test.py | 21 ------ 9 files changed, 8 insertions(+), 229 deletions(-) delete mode 100644 spanner_orm/foreign_key_relationship.py delete mode 100644 spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py diff --git a/setup.py b/setup.py index 011756a..566b38b 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ include_package_data=True, python_requires='~=3.7', install_requires=['google-cloud-spanner >= 1.6, <2.0.0dev'], - tests_require=['absl-py', 'google-api-core', 'portpicker'], + tests_require=['absl-py', 'portpicker'], entry_points={ 'console_scripts': ['spanner-orm = spanner_orm.admin.scripts:main'] }) diff --git a/spanner_orm/admin/update.py b/spanner_orm/admin/update.py index a77a101..4277e10 100644 --- a/spanner_orm/admin/update.py +++ b/spanner_orm/admin/update.py @@ -49,23 +49,13 @@ def __init__(self, model_: Type[model.Model]): self._model = model_ def ddl(self) -> str: - key_fields = [ + fields = [ '{} {}'.format(name, field.ddl()) for name, field in self._model.fields.items() ] - key_fields_ddl = ', '.join(key_fields) - for relation in self._model.foreign_key_relations.values(): - for constraint in relation.constraints: - key_fields_ddl += ( - ', FOREIGN KEY ({referencing_column}) REFERENCES' - ' {referenced_table} ({referenced_column})').format( - referencing_column=constraint.referencing_column, - referenced_table=constraint.referenced_table_name, - referenced_column=constraint.referenced_column, - ) index_ddl = 'PRIMARY KEY ({})'.format(', '.join(self._model.primary_keys)) statement = 'CREATE TABLE {} ({}) {}'.format(self._model.table, - key_fields_ddl, index_ddl) + ', '.join(fields), index_ddl) if self._model.interleaved: statement += ', INTERLEAVE IN PARENT {parent} ON DELETE CASCADE'.format( diff --git a/spanner_orm/foreign_key_relationship.py b/spanner_orm/foreign_key_relationship.py deleted file mode 100644 index 6dc7cb4..0000000 --- a/spanner_orm/foreign_key_relationship.py +++ /dev/null @@ -1,68 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Helps define a foreign key relationship between two models.""" - -from typing import List, Mapping - -import dataclasses -from spanner_orm import registry - - -@dataclasses.dataclass -class ForeignKeyRelationshipConstraint: - referencing_column: str - referenced_column: str - referenced_table_name: str - - -class ForeignKeyRelationship(object): - """Helps define a foreign key relationship between two models.""" - - def __init__(self, - referenced_table_name: str, - constraints: Mapping[str, str]): - """Creates a ForeignKeyRelationship. - - Args: - referenced_table_name: Name of the table which the foreign key references. - constraints: Dictionary where the keys are names of columns from the - referencing table and the values are the names of the columns in the - referenced table. - # TODO(dgorelik): Allow constraints to have custom names. - """ - self.origin = None - self.name = None - self._referenced_table_name = referenced_table_name - self._constraints = constraints - - @property - def constraints(self) -> List[ForeignKeyRelationshipConstraint]: - return self._parse_constraints() - - def _parse_constraints(self) -> List[ForeignKeyRelationshipConstraint]: - """Returns a list of Constraints for the relationship.""" - constraints = [] - referenced_table = registry.model_registry().get( - self._referenced_table_name) - for referencing_column, referenced_column in self._constraints.items(): - constraints.append( - ForeignKeyRelationshipConstraint( - referencing_column, - referenced_column, - referenced_table.table, - ) - ) - - return constraints diff --git a/spanner_orm/metadata.py b/spanner_orm/metadata.py index 251dfc2..8c7c5e2 100644 --- a/spanner_orm/metadata.py +++ b/spanner_orm/metadata.py @@ -32,7 +32,6 @@ from spanner_orm import error from spanner_orm import field -from spanner_orm import foreign_key_relationship from spanner_orm import index from spanner_orm import registry from spanner_orm import relationship @@ -45,11 +44,6 @@ def __init__(self, table: Optional[str] = None, fields: Optional[Dict[str, field.Field]] = None, relations: Optional[Dict[str, relationship.Relationship]] = None, - foreign_key_relations: Optional[ - Dict[ - str, - foreign_key_relationship.ForeignKeyRelationship] - ] = None, indexes: Optional[Dict[str, index.Index]] = None, interleaved: Optional[str] = None, model_class: Optional[Type[Any]] = None): @@ -61,7 +55,6 @@ def __init__(self, self.model_class = model_class self.primary_keys = [] self.relations = dict(relations or {}) - self.foreign_key_relations = dict(foreign_key_relations or {}) self.table = table or '' def finalize(self) -> None: @@ -108,14 +101,6 @@ def add_relation(self, name: str, new_relation.name = name self.relations[name] = new_relation - def add_foreign_key_relation( - self, - name: str, - new_relation: foreign_key_relationship.ForeignKeyRelationship, - ) -> None: - new_relation.name = name - self.foreign_key_relations[name] = new_relation - def add_index(self, name: str, new_index: index.Index) -> None: new_index.name = name self.indexes[name] = new_index diff --git a/spanner_orm/model.py b/spanner_orm/model.py index c21999d..06060cb 100644 --- a/spanner_orm/model.py +++ b/spanner_orm/model.py @@ -21,7 +21,6 @@ from spanner_orm import api from spanner_orm import condition from spanner_orm import error -from spanner_orm import foreign_key_relationship from spanner_orm import field from spanner_orm import index from spanner_orm import metadata @@ -59,11 +58,6 @@ def __new__(mcs, name: str, bases: Any, attrs: Dict[str, Any], **kwargs: Any): model_metadata.add_index(key, value) elif isinstance(value, relationship.Relationship): model_metadata.add_relation(key, value) - elif isinstance( - value, - foreign_key_relationship.ForeignKeyRelationship, - ): - model_metadata.add_foreign_key_relation(key, value) else: non_model_attrs[key] = value @@ -118,12 +112,6 @@ def primary_keys(cls) -> List[str]: def relations(cls) -> Dict[str, relationship.Relationship]: return cls.meta.relations - @property - def foreign_key_relations( - cls) -> Dict[str, foreign_key_relationship.ForeignKeyRelationship]: - return cls.meta.foreign_key_relations - - @property def fields(cls) -> Dict[str, field.Field]: return cls.meta.fields diff --git a/spanner_orm/tests/migrations_emulator_test.py b/spanner_orm/tests/migrations_emulator_test.py index 75d03ce..3e4f8a1 100644 --- a/spanner_orm/tests/migrations_emulator_test.py +++ b/spanner_orm/tests/migrations_emulator_test.py @@ -12,16 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import datetime import logging import os import unittest -from spanner_orm.testlib.spanner_emulator import testlib as spanner_emulator_testlib -from spanner_orm.tests import models - -from google.api_core import exceptions as google_api_exceptions +import spanner_orm +from spanner_orm.tests import models +from spanner_orm.testlib.spanner_emulator import testlib as spanner_emulator_testlib class MigrationsEmulatorTest(spanner_emulator_testlib.TestCase): TEST_MIGRATIONS_DIR = os.path.join( @@ -34,39 +32,13 @@ def setUp(self): self.run_orm_migrations(self.TEST_MIGRATIONS_DIR) def test_basic(self): - models.SmallTestModel({'key': 'key', 'value_1': 'value'}).save() + test_model = models.SmallTestModel({'key': 'key', 'value_1': 'value'}) + test_model.save() self.assertEqual( [x.values for x in models.SmallTestModel.all()], [{'key': 'key', 'value_1': 'value', 'value_2': None}], ) - def test_error_with_missing_referencing_key(self): - with self.assertRaisesRegex( - google_api_exceptions.FailedPrecondition, - 'Cannot find referenced key', - ): - models.ForeignKeyTestModel({ - 'referencing_key_1': 'key', - 'referencing_key_2': 'key', - 'referencing_key_3': 42, - 'value': 'value' - }).save() - - def test_key(self): - models.SmallTestModel({'key': 'key', 'value_1': 'value'}).save() - models.UnittestModel( - {'string': 'string', - 'int_': 42, - 'float_': 4.2, - 'timestamp': datetime.datetime.now(tz=datetime.timezone.utc), - }).save() - models.ForeignKeyTestModel({ - 'referencing_key_1': 'key', - 'referencing_key_2': 'string', - 'referencing_key_3': 42, - 'value': 'value' - }).save() - if __name__ == '__main__': logging.basicConfig() unittest.main() diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py b/spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py deleted file mode 100644 index a88811d..0000000 --- a/spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py +++ /dev/null @@ -1,52 +0,0 @@ -# Lint as: python3 -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Creates table with ForeignKeyTestModel. - -Migration ID: 'f735d6b706d3' -Created: 2020-07-10 16:24 -""" - -import spanner_orm -from spanner_orm import field -from spanner_orm import foreign_key_relationship - -migration_id = 'f735d6b706d4' -prev_migration_id = 'f735d6b706d3' - - -class OriginalForeignKeyTestModelTable(spanner_orm.model.Model): - """ORM Model with the original schema for the ForeignKeyTestModel table.""" - - __table__ = 'ForeignKeyTestModel' - referencing_key_1 = field.Field(field.String, primary_key=True) - referencing_key_2 = field.Field(field.String, primary_key=True) - referencing_key_3 = field.Field(field.Integer, primary_key=True) - value = field.Field(field.String) - foreign_key_1 = foreign_key_relationship.ForeignKeyRelationship( - 'SmallTestModel', {'referencing_key_1': 'key'}) - foreign_key_2 = foreign_key_relationship.ForeignKeyRelationship( - 'UnittestModel', - {'referencing_key_2': 'string', 'referencing_key_3': 'int_'}, - ) - - -def upgrade() -> spanner_orm.CreateTable: - """See ORM migrations interface.""" - return spanner_orm.CreateTable(OriginalForeignKeyTestModelTable) - - -def downgrade() -> spanner_orm.DropTable: - """See ORM migrations interface.""" - return spanner_orm.DropTable(OriginalForeignKeyTestModelTable.__table__) diff --git a/spanner_orm/tests/models.py b/spanner_orm/tests/models.py index e82c071..9aef1ed 100644 --- a/spanner_orm/tests/models.py +++ b/spanner_orm/tests/models.py @@ -15,7 +15,6 @@ """Models used by unit tests.""" from spanner_orm import field -from spanner_orm import foreign_key_relationship from spanner_orm import index from spanner_orm import model from spanner_orm import relationship @@ -62,20 +61,6 @@ class RelationshipTestModel(model.Model): parents = relationship.Relationship('spanner_orm.tests.models.SmallTestModel', {'parent_key': 'key'}) -class ForeignKeyTestModel(model.Model): - """Model class for testing foreign keys.""" - - __table__ = 'ForeignKeyTestModel' - referencing_key_1 = field.Field(field.String, primary_key=True) - referencing_key_2 = field.Field(field.String, primary_key=True) - referencing_key_3 = field.Field(field.Integer, primary_key=True) - value = field.Field(field.String) - foreign_key_1 = foreign_key_relationship.ForeignKeyRelationship( - 'SmallTestModel', {'referencing_key_1': 'key'}) - foreign_key_2 = foreign_key_relationship.ForeignKeyRelationship( - 'UnittestModel', - {'referencing_key_2': 'string', 'referencing_key_3': 'int_'}, - ) class InheritanceTestModel(SmallTestModel): """Model class used for testing model inheritance.""" diff --git a/spanner_orm/tests/update_test.py b/spanner_orm/tests/update_test.py index 9b84b31..633b0c3 100644 --- a/spanner_orm/tests/update_test.py +++ b/spanner_orm/tests/update_test.py @@ -86,27 +86,6 @@ def test_create_table_interleaved(self, get_model): 'INTERLEAVE IN PARENT SmallTestModel ON DELETE CASCADE') self.assertEqual(test_update.ddl(), test_model_ddl) - @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') - def test_create_table_foreign_key(self, get_model): - self.maxDiff = 2000 - - get_model.return_value = None - new_model = models.ForeignKeyTestModel - test_update = update.CreateTable(new_model) - test_update.validate() - - test_model_ddl = ( - 'CREATE TABLE ForeignKeyTestModel (' - 'referencing_key_1 STRING(MAX) NOT NULL, ' - 'referencing_key_2 STRING(MAX) NOT NULL, ' - 'referencing_key_3 INT64 NOT NULL, ' - 'value STRING(MAX) NOT NULL, ' - 'FOREIGN KEY (referencing_key_1) REFERENCES SmallTestModel (key), ' - 'FOREIGN KEY (referencing_key_2) REFERENCES table (string), ' - 'FOREIGN KEY (referencing_key_3) REFERENCES table (int_)) ' - 'PRIMARY KEY (referencing_key_1, referencing_key_2, referencing_key_3)') - self.assertEqual(test_update.ddl(), test_model_ddl) - @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') def test_create_table_error_on_existing_table(self, get_model): get_model.return_value = models.SmallTestModel From e38e1129d33c04738447d8419dce4d8b6ec73d47 Mon Sep 17 00:00:00 2001 From: Daniel Gorelik Date: Thu, 13 Aug 2020 13:06:33 -0400 Subject: [PATCH 013/131] Revert "Revert "Add support for setting foreign keys"" --- setup.py | 2 +- spanner_orm/admin/update.py | 14 +++- spanner_orm/foreign_key_relationship.py | 68 +++++++++++++++++++ spanner_orm/metadata.py | 15 ++++ spanner_orm/model.py | 12 ++++ spanner_orm/tests/migrations_emulator_test.py | 38 +++++++++-- .../create_foreign_key_test_model.py | 52 ++++++++++++++ spanner_orm/tests/models.py | 15 ++++ spanner_orm/tests/update_test.py | 21 ++++++ 9 files changed, 229 insertions(+), 8 deletions(-) create mode 100644 spanner_orm/foreign_key_relationship.py create mode 100644 spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py diff --git a/setup.py b/setup.py index 566b38b..011756a 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ include_package_data=True, python_requires='~=3.7', install_requires=['google-cloud-spanner >= 1.6, <2.0.0dev'], - tests_require=['absl-py', 'portpicker'], + tests_require=['absl-py', 'google-api-core', 'portpicker'], entry_points={ 'console_scripts': ['spanner-orm = spanner_orm.admin.scripts:main'] }) diff --git a/spanner_orm/admin/update.py b/spanner_orm/admin/update.py index 4277e10..a77a101 100644 --- a/spanner_orm/admin/update.py +++ b/spanner_orm/admin/update.py @@ -49,13 +49,23 @@ def __init__(self, model_: Type[model.Model]): self._model = model_ def ddl(self) -> str: - fields = [ + key_fields = [ '{} {}'.format(name, field.ddl()) for name, field in self._model.fields.items() ] + key_fields_ddl = ', '.join(key_fields) + for relation in self._model.foreign_key_relations.values(): + for constraint in relation.constraints: + key_fields_ddl += ( + ', FOREIGN KEY ({referencing_column}) REFERENCES' + ' {referenced_table} ({referenced_column})').format( + referencing_column=constraint.referencing_column, + referenced_table=constraint.referenced_table_name, + referenced_column=constraint.referenced_column, + ) index_ddl = 'PRIMARY KEY ({})'.format(', '.join(self._model.primary_keys)) statement = 'CREATE TABLE {} ({}) {}'.format(self._model.table, - ', '.join(fields), index_ddl) + key_fields_ddl, index_ddl) if self._model.interleaved: statement += ', INTERLEAVE IN PARENT {parent} ON DELETE CASCADE'.format( diff --git a/spanner_orm/foreign_key_relationship.py b/spanner_orm/foreign_key_relationship.py new file mode 100644 index 0000000..6dc7cb4 --- /dev/null +++ b/spanner_orm/foreign_key_relationship.py @@ -0,0 +1,68 @@ +# python3 +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Helps define a foreign key relationship between two models.""" + +from typing import List, Mapping + +import dataclasses +from spanner_orm import registry + + +@dataclasses.dataclass +class ForeignKeyRelationshipConstraint: + referencing_column: str + referenced_column: str + referenced_table_name: str + + +class ForeignKeyRelationship(object): + """Helps define a foreign key relationship between two models.""" + + def __init__(self, + referenced_table_name: str, + constraints: Mapping[str, str]): + """Creates a ForeignKeyRelationship. + + Args: + referenced_table_name: Name of the table which the foreign key references. + constraints: Dictionary where the keys are names of columns from the + referencing table and the values are the names of the columns in the + referenced table. + # TODO(dgorelik): Allow constraints to have custom names. + """ + self.origin = None + self.name = None + self._referenced_table_name = referenced_table_name + self._constraints = constraints + + @property + def constraints(self) -> List[ForeignKeyRelationshipConstraint]: + return self._parse_constraints() + + def _parse_constraints(self) -> List[ForeignKeyRelationshipConstraint]: + """Returns a list of Constraints for the relationship.""" + constraints = [] + referenced_table = registry.model_registry().get( + self._referenced_table_name) + for referencing_column, referenced_column in self._constraints.items(): + constraints.append( + ForeignKeyRelationshipConstraint( + referencing_column, + referenced_column, + referenced_table.table, + ) + ) + + return constraints diff --git a/spanner_orm/metadata.py b/spanner_orm/metadata.py index 8c7c5e2..251dfc2 100644 --- a/spanner_orm/metadata.py +++ b/spanner_orm/metadata.py @@ -32,6 +32,7 @@ from spanner_orm import error from spanner_orm import field +from spanner_orm import foreign_key_relationship from spanner_orm import index from spanner_orm import registry from spanner_orm import relationship @@ -44,6 +45,11 @@ def __init__(self, table: Optional[str] = None, fields: Optional[Dict[str, field.Field]] = None, relations: Optional[Dict[str, relationship.Relationship]] = None, + foreign_key_relations: Optional[ + Dict[ + str, + foreign_key_relationship.ForeignKeyRelationship] + ] = None, indexes: Optional[Dict[str, index.Index]] = None, interleaved: Optional[str] = None, model_class: Optional[Type[Any]] = None): @@ -55,6 +61,7 @@ def __init__(self, self.model_class = model_class self.primary_keys = [] self.relations = dict(relations or {}) + self.foreign_key_relations = dict(foreign_key_relations or {}) self.table = table or '' def finalize(self) -> None: @@ -101,6 +108,14 @@ def add_relation(self, name: str, new_relation.name = name self.relations[name] = new_relation + def add_foreign_key_relation( + self, + name: str, + new_relation: foreign_key_relationship.ForeignKeyRelationship, + ) -> None: + new_relation.name = name + self.foreign_key_relations[name] = new_relation + def add_index(self, name: str, new_index: index.Index) -> None: new_index.name = name self.indexes[name] = new_index diff --git a/spanner_orm/model.py b/spanner_orm/model.py index 06060cb..c21999d 100644 --- a/spanner_orm/model.py +++ b/spanner_orm/model.py @@ -21,6 +21,7 @@ from spanner_orm import api from spanner_orm import condition from spanner_orm import error +from spanner_orm import foreign_key_relationship from spanner_orm import field from spanner_orm import index from spanner_orm import metadata @@ -58,6 +59,11 @@ def __new__(mcs, name: str, bases: Any, attrs: Dict[str, Any], **kwargs: Any): model_metadata.add_index(key, value) elif isinstance(value, relationship.Relationship): model_metadata.add_relation(key, value) + elif isinstance( + value, + foreign_key_relationship.ForeignKeyRelationship, + ): + model_metadata.add_foreign_key_relation(key, value) else: non_model_attrs[key] = value @@ -112,6 +118,12 @@ def primary_keys(cls) -> List[str]: def relations(cls) -> Dict[str, relationship.Relationship]: return cls.meta.relations + @property + def foreign_key_relations( + cls) -> Dict[str, foreign_key_relationship.ForeignKeyRelationship]: + return cls.meta.foreign_key_relations + + @property def fields(cls) -> Dict[str, field.Field]: return cls.meta.fields diff --git a/spanner_orm/tests/migrations_emulator_test.py b/spanner_orm/tests/migrations_emulator_test.py index 3e4f8a1..75d03ce 100644 --- a/spanner_orm/tests/migrations_emulator_test.py +++ b/spanner_orm/tests/migrations_emulator_test.py @@ -12,14 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import datetime import logging import os import unittest - -import spanner_orm -from spanner_orm.tests import models from spanner_orm.testlib.spanner_emulator import testlib as spanner_emulator_testlib +from spanner_orm.tests import models + +from google.api_core import exceptions as google_api_exceptions + class MigrationsEmulatorTest(spanner_emulator_testlib.TestCase): TEST_MIGRATIONS_DIR = os.path.join( @@ -32,13 +34,39 @@ def setUp(self): self.run_orm_migrations(self.TEST_MIGRATIONS_DIR) def test_basic(self): - test_model = models.SmallTestModel({'key': 'key', 'value_1': 'value'}) - test_model.save() + models.SmallTestModel({'key': 'key', 'value_1': 'value'}).save() self.assertEqual( [x.values for x in models.SmallTestModel.all()], [{'key': 'key', 'value_1': 'value', 'value_2': None}], ) + def test_error_with_missing_referencing_key(self): + with self.assertRaisesRegex( + google_api_exceptions.FailedPrecondition, + 'Cannot find referenced key', + ): + models.ForeignKeyTestModel({ + 'referencing_key_1': 'key', + 'referencing_key_2': 'key', + 'referencing_key_3': 42, + 'value': 'value' + }).save() + + def test_key(self): + models.SmallTestModel({'key': 'key', 'value_1': 'value'}).save() + models.UnittestModel( + {'string': 'string', + 'int_': 42, + 'float_': 4.2, + 'timestamp': datetime.datetime.now(tz=datetime.timezone.utc), + }).save() + models.ForeignKeyTestModel({ + 'referencing_key_1': 'key', + 'referencing_key_2': 'string', + 'referencing_key_3': 42, + 'value': 'value' + }).save() + if __name__ == '__main__': logging.basicConfig() unittest.main() diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py b/spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py new file mode 100644 index 0000000..a88811d --- /dev/null +++ b/spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py @@ -0,0 +1,52 @@ +# Lint as: python3 +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Creates table with ForeignKeyTestModel. + +Migration ID: 'f735d6b706d3' +Created: 2020-07-10 16:24 +""" + +import spanner_orm +from spanner_orm import field +from spanner_orm import foreign_key_relationship + +migration_id = 'f735d6b706d4' +prev_migration_id = 'f735d6b706d3' + + +class OriginalForeignKeyTestModelTable(spanner_orm.model.Model): + """ORM Model with the original schema for the ForeignKeyTestModel table.""" + + __table__ = 'ForeignKeyTestModel' + referencing_key_1 = field.Field(field.String, primary_key=True) + referencing_key_2 = field.Field(field.String, primary_key=True) + referencing_key_3 = field.Field(field.Integer, primary_key=True) + value = field.Field(field.String) + foreign_key_1 = foreign_key_relationship.ForeignKeyRelationship( + 'SmallTestModel', {'referencing_key_1': 'key'}) + foreign_key_2 = foreign_key_relationship.ForeignKeyRelationship( + 'UnittestModel', + {'referencing_key_2': 'string', 'referencing_key_3': 'int_'}, + ) + + +def upgrade() -> spanner_orm.CreateTable: + """See ORM migrations interface.""" + return spanner_orm.CreateTable(OriginalForeignKeyTestModelTable) + + +def downgrade() -> spanner_orm.DropTable: + """See ORM migrations interface.""" + return spanner_orm.DropTable(OriginalForeignKeyTestModelTable.__table__) diff --git a/spanner_orm/tests/models.py b/spanner_orm/tests/models.py index 9aef1ed..e82c071 100644 --- a/spanner_orm/tests/models.py +++ b/spanner_orm/tests/models.py @@ -15,6 +15,7 @@ """Models used by unit tests.""" from spanner_orm import field +from spanner_orm import foreign_key_relationship from spanner_orm import index from spanner_orm import model from spanner_orm import relationship @@ -61,6 +62,20 @@ class RelationshipTestModel(model.Model): parents = relationship.Relationship('spanner_orm.tests.models.SmallTestModel', {'parent_key': 'key'}) +class ForeignKeyTestModel(model.Model): + """Model class for testing foreign keys.""" + + __table__ = 'ForeignKeyTestModel' + referencing_key_1 = field.Field(field.String, primary_key=True) + referencing_key_2 = field.Field(field.String, primary_key=True) + referencing_key_3 = field.Field(field.Integer, primary_key=True) + value = field.Field(field.String) + foreign_key_1 = foreign_key_relationship.ForeignKeyRelationship( + 'SmallTestModel', {'referencing_key_1': 'key'}) + foreign_key_2 = foreign_key_relationship.ForeignKeyRelationship( + 'UnittestModel', + {'referencing_key_2': 'string', 'referencing_key_3': 'int_'}, + ) class InheritanceTestModel(SmallTestModel): """Model class used for testing model inheritance.""" diff --git a/spanner_orm/tests/update_test.py b/spanner_orm/tests/update_test.py index 633b0c3..9b84b31 100644 --- a/spanner_orm/tests/update_test.py +++ b/spanner_orm/tests/update_test.py @@ -86,6 +86,27 @@ def test_create_table_interleaved(self, get_model): 'INTERLEAVE IN PARENT SmallTestModel ON DELETE CASCADE') self.assertEqual(test_update.ddl(), test_model_ddl) + @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') + def test_create_table_foreign_key(self, get_model): + self.maxDiff = 2000 + + get_model.return_value = None + new_model = models.ForeignKeyTestModel + test_update = update.CreateTable(new_model) + test_update.validate() + + test_model_ddl = ( + 'CREATE TABLE ForeignKeyTestModel (' + 'referencing_key_1 STRING(MAX) NOT NULL, ' + 'referencing_key_2 STRING(MAX) NOT NULL, ' + 'referencing_key_3 INT64 NOT NULL, ' + 'value STRING(MAX) NOT NULL, ' + 'FOREIGN KEY (referencing_key_1) REFERENCES SmallTestModel (key), ' + 'FOREIGN KEY (referencing_key_2) REFERENCES table (string), ' + 'FOREIGN KEY (referencing_key_3) REFERENCES table (int_)) ' + 'PRIMARY KEY (referencing_key_1, referencing_key_2, referencing_key_3)') + self.assertEqual(test_update.ddl(), test_model_ddl) + @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') def test_create_table_error_on_existing_table(self, get_model): get_model.return_value = models.SmallTestModel From 5249af6ef4b598d38b28b67fcd30fc30f264c451 Mon Sep 17 00:00:00 2001 From: Daniel Gorelik Date: Fri, 14 Aug 2020 10:02:51 -0400 Subject: [PATCH 014/131] Add support for setting foreign key names --- spanner_orm/admin/update.py | 18 +++---- spanner_orm/foreign_key_relationship.py | 37 ++++++-------- .../create_foreign_key_test_model.py | 6 ++- .../create_unittest_model.py | 48 +++++++++++++++++++ spanner_orm/tests/models.py | 8 +++- spanner_orm/tests/update_test.py | 14 ++++-- 6 files changed, 91 insertions(+), 40 deletions(-) create mode 100644 spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py diff --git a/spanner_orm/admin/update.py b/spanner_orm/admin/update.py index a77a101..44050ba 100644 --- a/spanner_orm/admin/update.py +++ b/spanner_orm/admin/update.py @@ -55,14 +55,16 @@ def ddl(self) -> str: ] key_fields_ddl = ', '.join(key_fields) for relation in self._model.foreign_key_relations.values(): - for constraint in relation.constraints: - key_fields_ddl += ( - ', FOREIGN KEY ({referencing_column}) REFERENCES' - ' {referenced_table} ({referenced_column})').format( - referencing_column=constraint.referencing_column, - referenced_table=constraint.referenced_table_name, - referenced_column=constraint.referenced_column, - ) + referencing_columns_ddl = ', '.join(relation.constraint.columns.keys()) + referenced_columns_ddl = ', '.join(relation.constraint.columns.values()) + key_fields_ddl += ( + ', CONSTRAINT {fk_name} FOREIGN KEY ({referencing_columns}) REFERENCES' + ' {referenced_table} ({referenced_columns})').format( + fk_name=relation.name, + referencing_columns=referencing_columns_ddl, + referenced_table=relation.constraint.referenced_table_name, + referenced_columns=referenced_columns_ddl, + ) index_ddl = 'PRIMARY KEY ({})'.format(', '.join(self._model.primary_keys)) statement = 'CREATE TABLE {} ({}) {}'.format(self._model.table, key_fields_ddl, index_ddl) diff --git a/spanner_orm/foreign_key_relationship.py b/spanner_orm/foreign_key_relationship.py index 6dc7cb4..ed4b579 100644 --- a/spanner_orm/foreign_key_relationship.py +++ b/spanner_orm/foreign_key_relationship.py @@ -1,5 +1,5 @@ # python3 -# Copyright 2019 Google LLC +# Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ # limitations under the License. """Helps define a foreign key relationship between two models.""" -from typing import List, Mapping +from typing import Mapping import dataclasses from spanner_orm import registry @@ -22,8 +22,7 @@ @dataclasses.dataclass class ForeignKeyRelationshipConstraint: - referencing_column: str - referenced_column: str + columns: Mapping[str, str] referenced_table_name: str @@ -32,37 +31,29 @@ class ForeignKeyRelationship(object): def __init__(self, referenced_table_name: str, - constraints: Mapping[str, str]): + columns: Mapping[str, str]): """Creates a ForeignKeyRelationship. Args: referenced_table_name: Name of the table which the foreign key references. - constraints: Dictionary where the keys are names of columns from the + columns: Dictionary where the keys are names of columns from the referencing table and the values are the names of the columns in the referenced table. - # TODO(dgorelik): Allow constraints to have custom names. """ self.origin = None self.name = None self._referenced_table_name = referenced_table_name - self._constraints = constraints + self._columns = columns @property - def constraints(self) -> List[ForeignKeyRelationshipConstraint]: - return self._parse_constraints() + def constraint(self) -> ForeignKeyRelationshipConstraint: + return self._parse_constraint() - def _parse_constraints(self) -> List[ForeignKeyRelationshipConstraint]: - """Returns a list of Constraints for the relationship.""" - constraints = [] + def _parse_constraint(self) -> ForeignKeyRelationshipConstraint: + """Return the relationship constraint.""" referenced_table = registry.model_registry().get( self._referenced_table_name) - for referencing_column, referenced_column in self._constraints.items(): - constraints.append( - ForeignKeyRelationshipConstraint( - referencing_column, - referenced_column, - referenced_table.table, - ) - ) - - return constraints + return ForeignKeyRelationshipConstraint( + self._columns, + referenced_table.table, + ) diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py b/spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py index a88811d..b8e776b 100644 --- a/spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py +++ b/spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py @@ -33,13 +33,15 @@ class OriginalForeignKeyTestModelTable(spanner_orm.model.Model): referencing_key_1 = field.Field(field.String, primary_key=True) referencing_key_2 = field.Field(field.String, primary_key=True) referencing_key_3 = field.Field(field.Integer, primary_key=True) - value = field.Field(field.String) + self_referencing_key = field.Field(field.String, nullable=True) foreign_key_1 = foreign_key_relationship.ForeignKeyRelationship( - 'SmallTestModel', {'referencing_key_1': 'key'}) + 'SmallTestModel', {'referencing_key_1': 'key'}) foreign_key_2 = foreign_key_relationship.ForeignKeyRelationship( 'UnittestModel', {'referencing_key_2': 'string', 'referencing_key_3': 'int_'}, ) + foreign_key_3 = foreign_key_relationship.ForeignKeyRelationship( + 'ForeignKeyTestModel', {'self_referencing_key': 'referencing_key_1'}) def upgrade() -> spanner_orm.CreateTable: diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py b/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py new file mode 100644 index 0000000..91fc5fe --- /dev/null +++ b/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py @@ -0,0 +1,48 @@ +# Lint as: python3 +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Creates table with UnittestModel. + +Migration ID: 'f735d6b706d3' +Created: 2020-07-10 16:24 +""" + +import spanner_orm +from spanner_orm import field + +migration_id = 'f735d6b706d3' +prev_migration_id = 'f735d6b706d2' + + +class OriginalUnittestModelTable(spanner_orm.model.Model): + """ORM Model with the original schema for the UnittestModel table.""" + + __table__ = 'table' + int_ = field.Field(field.Integer, primary_key=True) + int_2 = field.Field(field.Integer, nullable=True) + float_ = field.Field(field.Float, primary_key=True) + float_2 = field.Field(field.Float, nullable=True) + string = field.Field(field.String, primary_key=True) + string_2 = field.Field(field.String, nullable=True) + timestamp = field.Field(field.Timestamp) + string_array = field.Field(field.StringArray, nullable=True) + +def upgrade() -> spanner_orm.CreateTable: + """See ORM migrations interface.""" + return spanner_orm.CreateTable(OriginalUnittestModelTable) + + +def downgrade() -> spanner_orm.DropTable: + """See ORM migrations interface.""" + return spanner_orm.DropTable(OriginalUnittestModelTable.__table__) diff --git a/spanner_orm/tests/models.py b/spanner_orm/tests/models.py index e82c071..f8cc307 100644 --- a/spanner_orm/tests/models.py +++ b/spanner_orm/tests/models.py @@ -62,6 +62,7 @@ class RelationshipTestModel(model.Model): parents = relationship.Relationship('spanner_orm.tests.models.SmallTestModel', {'parent_key': 'key'}) + class ForeignKeyTestModel(model.Model): """Model class for testing foreign keys.""" @@ -69,13 +70,16 @@ class ForeignKeyTestModel(model.Model): referencing_key_1 = field.Field(field.String, primary_key=True) referencing_key_2 = field.Field(field.String, primary_key=True) referencing_key_3 = field.Field(field.Integer, primary_key=True) - value = field.Field(field.String) + self_referencing_key = field.Field(field.String, nullable=True) foreign_key_1 = foreign_key_relationship.ForeignKeyRelationship( - 'SmallTestModel', {'referencing_key_1': 'key'}) + 'SmallTestModel', {'referencing_key_1': 'key'}) foreign_key_2 = foreign_key_relationship.ForeignKeyRelationship( 'UnittestModel', {'referencing_key_2': 'string', 'referencing_key_3': 'int_'}, ) + foreign_key_3 = foreign_key_relationship.ForeignKeyRelationship( + 'ForeignKeyTestModel', {'self_referencing_key': 'referencing_key_1'}) + class InheritanceTestModel(SmallTestModel): """Model class used for testing model inheritance.""" diff --git a/spanner_orm/tests/update_test.py b/spanner_orm/tests/update_test.py index 9b84b31..a96190f 100644 --- a/spanner_orm/tests/update_test.py +++ b/spanner_orm/tests/update_test.py @@ -89,7 +89,7 @@ def test_create_table_interleaved(self, get_model): @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') def test_create_table_foreign_key(self, get_model): self.maxDiff = 2000 - + get_model.return_value = None new_model = models.ForeignKeyTestModel test_update = update.CreateTable(new_model) @@ -100,10 +100,14 @@ def test_create_table_foreign_key(self, get_model): 'referencing_key_1 STRING(MAX) NOT NULL, ' 'referencing_key_2 STRING(MAX) NOT NULL, ' 'referencing_key_3 INT64 NOT NULL, ' - 'value STRING(MAX) NOT NULL, ' - 'FOREIGN KEY (referencing_key_1) REFERENCES SmallTestModel (key), ' - 'FOREIGN KEY (referencing_key_2) REFERENCES table (string), ' - 'FOREIGN KEY (referencing_key_3) REFERENCES table (int_)) ' + 'self_referencing_key STRING(MAX), ' + 'CONSTRAINT foreign_key_1 FOREIGN KEY (referencing_key_1) ' + 'REFERENCES SmallTestModel (key), ' + 'CONSTRAINT foreign_key_2 ' + 'FOREIGN KEY (referencing_key_2, referencing_key_3) ' + 'REFERENCES table (string, int_), ' + 'CONSTRAINT foreign_key_3 FOREIGN KEY (self_referencing_key) ' + 'REFERENCES ForeignKeyTestModel (referencing_key_1)) ' 'PRIMARY KEY (referencing_key_1, referencing_key_2, referencing_key_3)') self.assertEqual(test_update.ddl(), test_model_ddl) From 5aedcc696fb54ed1d1f3e5888adabea3fbb72c40 Mon Sep 17 00:00:00 2001 From: Daniel Gorelik Date: Mon, 14 Sep 2020 15:11:30 -0400 Subject: [PATCH 015/131] Implement most of functionality for includes --- spanner_orm/admin/update.py | 64 +++++++-- spanner_orm/condition.py | 107 ++++++++++---- spanner_orm/foreign_key_relationship.py | 21 ++- spanner_orm/model.py | 17 ++- spanner_orm/tests/models.py | 35 +++-- spanner_orm/tests/query_test.py | 179 +++++++++++++++++++----- 6 files changed, 332 insertions(+), 91 deletions(-) diff --git a/spanner_orm/admin/update.py b/spanner_orm/admin/update.py index 44050ba..5f961cf 100644 --- a/spanner_orm/admin/update.py +++ b/spanner_orm/admin/update.py @@ -20,6 +20,7 @@ from spanner_orm import condition from spanner_orm import error from spanner_orm import field +from spanner_orm import foreign_key_relationship from spanner_orm import model from spanner_orm.admin import api from spanner_orm.admin import index_column @@ -55,16 +56,7 @@ def ddl(self) -> str: ] key_fields_ddl = ', '.join(key_fields) for relation in self._model.foreign_key_relations.values(): - referencing_columns_ddl = ', '.join(relation.constraint.columns.keys()) - referenced_columns_ddl = ', '.join(relation.constraint.columns.values()) - key_fields_ddl += ( - ', CONSTRAINT {fk_name} FOREIGN KEY ({referencing_columns}) REFERENCES' - ' {referenced_table} ({referenced_columns})').format( - fk_name=relation.name, - referencing_columns=referencing_columns_ddl, - referenced_table=relation.constraint.referenced_table_name, - referenced_columns=referenced_columns_ddl, - ) + key_fields_ddl += f', {relation.ddl}' index_ddl = 'PRIMARY KEY ({})'.format(', '.join(self._model.primary_keys)) statement = 'CREATE TABLE {} ({}) {}'.format(self._model.table, key_fields_ddl, index_ddl) @@ -327,6 +319,58 @@ def validate(self) -> None: raise error.SpannerError('Index {} is the primary index'.format( self._index)) +class AddForeignKeyRelationship(SchemaUpdate): + """Update for adding a column to an existing table. + + Only supports adding nullable columns + """ + + def __init__( + self, + referencing_table_name: str, + referenced_table_name: str, + column_mapping, + ): + self._table = table_name + self._column = column_name + self._field = field_ + + def ddl(self) -> str: + return 'ALTER TABLE {} ADD'.format(self._table, self._column, + self._field.ddl()) + + def validate(self) -> None: + model_ = metadata.SpannerMetadata.model(self._table) + if not model_: + raise error.SpannerError('Table {} does not exist'.format(self._table)) + + +class DropForeignKeyRelationship(SchemaUpdate): + """Update for dropping a column from an existing table.""" + + def __init__(self, table_name: str, column_name: str): + self._table = table_name + self._column = column_name + + def ddl(self) -> str: + return 'ALTER TABLE {} DROP COLUMN {}'.format(self._table, self._column) + + def validate(self) -> None: + model_ = metadata.SpannerMetadata.model(self._table) + if not model_: + raise error.SpannerError('Table {} does not exist'.format(self._table)) + + if self._column not in model_.fields: + raise error.SpannerError('Column {} does not exist on {}'.format( + self._column, self._table)) + + # Verify no indices exist on the column we're trying to drop + num_indexed_columns = index_column.IndexColumnSchema.count( + None, condition.equal_to('column_name', self._column), + condition.equal_to('table_name', self._table)) + if num_indexed_columns > 0: + raise error.SpannerError('Column {} is indexed'.format(self._column)) + class NoUpdate(SchemaUpdate): """Update that does nothing, for migrations that don't update db schemas.""" diff --git a/spanner_orm/condition.py b/spanner_orm/condition.py index 0174930..c28efec 100644 --- a/spanner_orm/condition.py +++ b/spanner_orm/condition.py @@ -20,6 +20,7 @@ from spanner_orm import error from spanner_orm import field +from spanner_orm import foreign_key_relationship from spanner_orm import index from spanner_orm import relationship @@ -200,11 +201,38 @@ def _validate(self, model_class: Type[Any]) -> None: class IncludesCondition(Condition): """Used to include related model_classs via a relation in a Spanner query.""" - def __init__(self, - relation_or_name: Union[relationship.Relationship, str], - conditions: List[Condition] = None): + def __init__( + self, + relation_or_name: Union[relationship.Relationship, + foreign_key_relationship.ForeignKeyRelationship, + str], + conditions: List[Condition] = None, + foreign_key_relation=False, + ): + """Initializer. + + + Args: + relation: Name of the relationship on the origin model or the Relationship/ + ForeignKeyRelationship on the origin model class used to retrieve + associated objects + conditions: Conditions to apply on the subquery + foreign_key_relation: True if the relation is a foreign key relation, + False if it is a legacy relation (eg not enforced in Spanner) + """ super().__init__() + self.foreign_key_relation = foreign_key_relation if isinstance(relation_or_name, relationship.Relationship): + if foreign_key_relation: + raise ValueError( + 'Must pass foreign key relation if ''`foreign_key_relation=True`.') + self.name = relation_or_name.name + self.relation = relation_or_name + elif isinstance(relation_or_name, + foreign_key_relationship.ForeignKeyRelationship): + if not foreign_key_relation: + raise ValueError( + 'Must pass legacy relation if `foreign_key_relation=False`.') self.name = relation_or_name.name self.relation = relation_or_name else: @@ -214,7 +242,10 @@ def __init__(self, def bind(self, model_class: Type[Any]) -> None: super().bind(model_class) - self.relation = self.model_class.relations[self.name] + if self.foreign_key_relation: + self.relation = self.model_class.foreign_key_relations[self.name] + else: + self.relation = self.model_class.relations[self.name] @property def conditions(self) -> List[Condition]: @@ -223,13 +254,20 @@ def conditions(self) -> List[Condition]: raise error.SpannerError( 'Condition must be bound before conditions is called') relation_conditions = [] - for constraint in self.relation.constraints: - # This is backward from what you might imagine because the condition will - # be processed from the context of the destination model - relation_conditions.append( - ColumnsEqualCondition(constraint.destination_column, - constraint.origin_class, - constraint.origin_column)) + if not self.foreign_key_relation: + for constraint in self.relation.constraints: + # This is backward from what you might imagine because the condition + # will be processed from the context of the destination model. + relation_conditions.append( + ColumnsEqualCondition(constraint.destination_column, + constraint.origin_class, + constraint.origin_column)) + else: + for pair in self.relation.constraint.columns.items(): + referencing_column, referenced_column = pair + relation_conditions.append( + ColumnsEqualCondition(referenced_column, self.model_class, + referencing_column)) return relation_conditions + self._conditions @property @@ -237,7 +275,10 @@ def destination(self) -> Type[Any]: if not self.relation: raise error.SpannerError( 'Condition must be bound before destination is called') - return self.relation.destination + if self.foreign_key_relation: + return self.relation.constraint.referenced_table + else: + return self.relation.destination @property def relation_name(self) -> str: @@ -263,14 +304,25 @@ def _types(self) -> Dict[str, type_pb2.Type]: return {} def _validate(self, model_class: Type[Any]) -> None: - if self.name not in model_class.relations: - raise error.ValidationError('{} is not a relation on {}'.format( - self.name, model_class.table)) - if self.relation and self.relation != model_class.relations[self.name]: - raise error.ValidationError('{} does not belong to {}'.format( - self.relation.name, model_class.table)) + if self.foreign_key_relation: + if self.name not in model_class.foreign_key_relations: + raise error.ValidationError('{} is not a relation on {}'.format( + self.name, model_class.table)) + if self.relation and self.relation != model_class.foreign_key_relations[ + self.name]: + raise error.ValidationError('{} does not belong to {}'.format( + self.relation.name, model_class.table)) + other_model_class = model_class.foreign_key_relations[ + self.name].constraint.referenced_table + else: + if self.name not in model_class.relations: + raise error.ValidationError('{} is not a relation on {}'.format( + self.name, model_class.table)) + if self.relation and self.relation != model_class.relations[self.name]: + raise error.ValidationError('{} does not belong to {}'.format( + self.relation.name, model_class.table)) + other_model_class = model_class.relations[self.name].destination - other_model_class = model_class.relations[self.name].destination for condition in self._conditions: condition._validate(other_model_class) # pylint: disable=protected-access @@ -629,8 +681,11 @@ def greater_than_or_equal_to(column: Union[field.Field, str], return ComparisonCondition('>=', column, value) -def includes(relation: Union[relationship.Relationship, str], - conditions: List[Condition] = None) -> IncludesCondition: +def includes(relation: Union[relationship.Relationship, + foreign_key_relationship.ForeignKeyRelationship, + str], + conditions: List[Condition] = None, + foreign_key_relation: bool = False) -> IncludesCondition: """Condition where the objects associated with a relationship are retrieved. Note that the query formed by this call is not a JOIN, but instead a @@ -639,14 +694,18 @@ def includes(relation: Union[relationship.Relationship, str], subquery may be included, but not all conditions may apply Args: - relation: Name of the relationship on the origin model or the Relationship - on the origin model class used to retrievec associated objects + relation: Name of the relationship on the origin model or the Relationship/ + ForeignKeyRelationship on the origin model class used to retrieve + associated objects conditions: Conditions to apply on the subquery + foreign_key_relation: True if the relation is a foreign key relation, + False if it is a legacy relation (eg not enforced in Spanner) Returns: A Condition subclass that will be used in the query """ - return IncludesCondition(relation, conditions) + return IncludesCondition( + relation, conditions, foreign_key_relation) def in_list(column: Union[field.Field, str], diff --git a/spanner_orm/foreign_key_relationship.py b/spanner_orm/foreign_key_relationship.py index ed4b579..d03ebb0 100644 --- a/spanner_orm/foreign_key_relationship.py +++ b/spanner_orm/foreign_key_relationship.py @@ -14,7 +14,7 @@ # limitations under the License. """Helps define a foreign key relationship between two models.""" -from typing import Mapping +from typing import Any, Mapping import dataclasses from spanner_orm import registry @@ -24,6 +24,7 @@ class ForeignKeyRelationshipConstraint: columns: Mapping[str, str] referenced_table_name: str + referenced_table: Any class ForeignKeyRelationship(object): @@ -49,6 +50,19 @@ def __init__(self, def constraint(self) -> ForeignKeyRelationshipConstraint: return self._parse_constraint() + @property + def ddl(self) -> str: + referencing_columns_ddl = ', '.join(self.constraint.columns.keys()) + referenced_columns_ddl = ', '.join(self.constraint.columns.values()) + return ( + 'CONSTRAINT {fk_name} FOREIGN KEY ({referencing_columns}) REFERENCES' + ' {referenced_table} ({referenced_columns})').format( + fk_name=self.name, + referencing_columns=referencing_columns_ddl, + referenced_table=self.constraint.referenced_table_name, + referenced_columns=referenced_columns_ddl, + ) + def _parse_constraint(self) -> ForeignKeyRelationshipConstraint: """Return the relationship constraint.""" referenced_table = registry.model_registry().get( @@ -56,4 +70,9 @@ def _parse_constraint(self) -> ForeignKeyRelationshipConstraint: return ForeignKeyRelationshipConstraint( self._columns, referenced_table.table, + referenced_table, ) + + @property + def single(self) -> bool: + return True diff --git a/spanner_orm/model.py b/spanner_orm/model.py index c21999d..258b269 100644 --- a/spanner_orm/model.py +++ b/spanner_orm/model.py @@ -79,13 +79,19 @@ def __new__(mcs, name: str, bases: Any, attrs: Dict[str, Any], **kwargs: Any): def __getattr__( cls, - name: str) -> Union[field.Field, relationship.Relationship, index.Index]: + name: str) -> Union[ + field.Field, + relationship.Relationship, + foreign_key_relationship.ForeignKeyRelationship, + index.Index]: # Unclear why pylint doesn't like this # pylint: disable=unsupported-membership-test if name in cls.fields: return cls.fields[name] elif name in cls.relations: return cls.relations[name] + elif name in cls.foreign_key_relations: + return cls.foreign_key_relations[name] elif name in cls.indexes: return cls.indexes[name] # pylint: enable=unsupported-membership-test @@ -484,6 +490,10 @@ def __init__(self, values: Dict[str, Any], persisted: bool = False): if relation in values: self.__dict__[relation] = values[relation] + for foreign_key_relation in self._foreign_key_relations: + if foreign_key_relation in values: + self.__dict__[foreign_key_relation] = values[foreign_key_relation] + def __setattr__(self, name: str, value: Any) -> None: if name in self._relations: raise AttributeError(name) @@ -513,6 +523,11 @@ def _primary_keys(self) -> List[str]: def _relations(self) -> Dict[str, relationship.Relationship]: return self._metaclass.relations + @property + def _foreign_key_relations( + self) -> Dict[str, foreign_key_relationship.ForeignKeyRelationship]: + return self._metaclass.foreign_key_relations + @property def _table(self) -> str: return self._metaclass.table diff --git a/spanner_orm/tests/models.py b/spanner_orm/tests/models.py index f8cc307..b257fb2 100644 --- a/spanner_orm/tests/models.py +++ b/spanner_orm/tests/models.py @@ -62,6 +62,20 @@ class RelationshipTestModel(model.Model): parents = relationship.Relationship('spanner_orm.tests.models.SmallTestModel', {'parent_key': 'key'}) +class UnittestModel(model.Model): + """Model class used for model testing.""" + + __table__ = 'table' + int_ = field.Field(field.Integer, primary_key=True) + int_2 = field.Field(field.Integer, nullable=True) + float_ = field.Field(field.Float, primary_key=True) + float_2 = field.Field(field.Float, nullable=True) + string = field.Field(field.String, primary_key=True) + string_2 = field.Field(field.String, nullable=True) + timestamp = field.Field(field.Timestamp) + string_array = field.Field(field.StringArray, nullable=True) + + test_index = index.Index(['string_2']) class ForeignKeyTestModel(model.Model): """Model class for testing foreign keys.""" @@ -72,31 +86,16 @@ class ForeignKeyTestModel(model.Model): referencing_key_3 = field.Field(field.Integer, primary_key=True) self_referencing_key = field.Field(field.String, nullable=True) foreign_key_1 = foreign_key_relationship.ForeignKeyRelationship( - 'SmallTestModel', {'referencing_key_1': 'key'}) + 'spanner_orm.tests.models.SmallTestModel', {'referencing_key_1': 'key'}) foreign_key_2 = foreign_key_relationship.ForeignKeyRelationship( - 'UnittestModel', + 'spanner_orm.tests.models.UnittestModel', {'referencing_key_2': 'string', 'referencing_key_3': 'int_'}, ) foreign_key_3 = foreign_key_relationship.ForeignKeyRelationship( - 'ForeignKeyTestModel', {'self_referencing_key': 'referencing_key_1'}) + 'spanner_orm.tests.models.ForeignKeyTestModel', {'self_referencing_key': 'referencing_key_1'}) class InheritanceTestModel(SmallTestModel): """Model class used for testing model inheritance.""" value_3 = field.Field(field.String, nullable=True) - -class UnittestModel(model.Model): - """Model class used for model testing.""" - - __table__ = 'table' - int_ = field.Field(field.Integer, primary_key=True) - int_2 = field.Field(field.Integer, nullable=True) - float_ = field.Field(field.Float, primary_key=True) - float_2 = field.Field(field.Float, nullable=True) - string = field.Field(field.String, primary_key=True) - string_2 = field.Field(field.String, nullable=True) - timestamp = field.Field(field.Timestamp) - string_array = field.Field(field.StringArray, nullable=True) - - test_index = index.Index(['string_2']) diff --git a/spanner_orm/tests/query_test.py b/spanner_orm/tests/query_test.py index 02ffa28..e476602 100644 --- a/spanner_orm/tests/query_test.py +++ b/spanner_orm/tests/query_test.py @@ -207,36 +207,70 @@ def test_force_index_with_object(self): expected_sql = 'FROM table@{FORCE_INDEX=test_index}' self.assertEndsWith(select_query.sql(), expected_sql) - def includes(self, relation, *conditions): - include_condition = condition.includes(relation, list(conditions)) - return query.SelectQuery(models.RelationshipTestModel, [include_condition]) - - def test_includes(self): - select_query = self.includes('parent') - - # The column order varies between test runs - expected_sql = ( - r'SELECT RelationshipTestModel\S* RelationshipTestModel\S* ' - r'ARRAY\(SELECT AS STRUCT SmallTestModel\S* SmallTestModel\S* ' - r'SmallTestModel\S* FROM SmallTestModel WHERE SmallTestModel.key = ' - r'RelationshipTestModel.parent_key\)') - self.assertRegex(select_query.sql(), expected_sql) - self.assertEmpty(select_query.parameters()) - self.assertEmpty(select_query.types()) - - def test_includes_with_object(self): - select_query = self.includes(models.RelationshipTestModel.parent) + def includes( + self, relation, *conditions, foreign_key_relation=False): + include_condition = condition.includes( + relation, list(conditions), foreign_key_relation) + return query.SelectQuery( + models.ForeignKeyTestModel if foreign_key_relation + else models.RelationshipTestModel, + [include_condition], + ) + + @parameterized.named_parameters( + ( + 'legacy_relationship', + {'relation': 'parent'}, + r'SELECT RelationshipTestModel\S* RelationshipTestModel\S* ' + r'ARRAY\(SELECT AS STRUCT SmallTestModel\S* SmallTestModel\S* ' + r'SmallTestModel\S* FROM SmallTestModel WHERE SmallTestModel.key = ' + r'RelationshipTestModel.parent_key\)', + ), + ( + 'legacy_relationship_with_object_arg', + {'relation': models.RelationshipTestModel.parent}, + r'SELECT RelationshipTestModel\S* RelationshipTestModel\S* ' + r'ARRAY\(SELECT AS STRUCT SmallTestModel\S* SmallTestModel\S* ' + r'SmallTestModel\S* FROM SmallTestModel WHERE SmallTestModel.key = ' + r'RelationshipTestModel.parent_key\)', + ), + ( + 'foreign_key_relationship', + {'relation': 'foreign_key_1', 'foreign_key_relation': True}, + r'SELECT ForeignKeyTestModel\S* ForeignKeyTestModel\S* ForeignKeyTestModel\S* ForeignKeyTestModel\S* ' + r'ARRAY\(SELECT AS STRUCT SmallTestModel\S* SmallTestModel\S* ' + r'SmallTestModel\S* FROM SmallTestModel WHERE SmallTestModel.key = ' + r'ForeignKeyTestModel.referencing_key_1\)', + ), + ( + 'foreign_key_relationship_with_object_arg', + {'relation': models.ForeignKeyTestModel.foreign_key_1, 'foreign_key_relation': True}, + r'SELECT ForeignKeyTestModel\S* ForeignKeyTestModel\S* ForeignKeyTestModel\S* ForeignKeyTestModel\S* ' + r'ARRAY\(SELECT AS STRUCT SmallTestModel\S* SmallTestModel\S* ' + r'SmallTestModel\S* FROM SmallTestModel WHERE SmallTestModel.key = ' + r'ForeignKeyTestModel.referencing_key_1\)', + ), + ) + def test_includes(self, includes_kwargs, expected_sql): + select_query = self.includes(**includes_kwargs) # The column order varies between test runs - expected_sql = ( - r'SELECT RelationshipTestModel\S* RelationshipTestModel\S* ' - r'ARRAY\(SELECT AS STRUCT SmallTestModel\S* SmallTestModel\S* ' - r'SmallTestModel\S* FROM SmallTestModel WHERE SmallTestModel.key = ' - r'RelationshipTestModel.parent_key\)') self.assertRegex(select_query.sql(), expected_sql) self.assertEmpty(select_query.parameters()) self.assertEmpty(select_query.types()) + @parameterized.parameters( + ( + {'relation': models.RelationshipTestModel.parent, 'foreign_key_relation': True}, + ), + ( + {'relation': models.ForeignKeyTestModel.foreign_key_1, 'foreign_key_relation': False}, + ) + ) + def test_error_mismatched_params(self, includes_kwargs): + with self.assertRaisesRegex(ValueError, 'Must pass'): + self.includes(**includes_kwargs) + def test_includes_subconditions_query(self): select_query = self.includes('parents', condition.equal_to('key', 'value')) expected_sql = ( @@ -254,6 +288,20 @@ def includes_result(self, related=1): result.append(parents) return child, parent, [result] + def fk_includes_result(self, related=1): + child = {'referencing_key_1': 'parent_key', + 'referencing_key_2': 'child', + 'referencing_key_3': 'child', + 'self_referencing_key': 'child'} + result = [child[name] for name in models.ForeignKeyTestModel.columns] + parent = {'key': 'key', 'value_1': 'value_1', 'value_2': None} + parents = [] + for _ in range(related): + parents.append([parent[name] for name in models.SmallTestModel.columns]) + result.append(parents) + return child, parent, [result] + + def test_includes_single_related_object_result(self): select_query = self.includes('parent') child_values, parent_values, rows = self.includes_result(related=1) @@ -266,12 +314,38 @@ def test_includes_single_related_object_result(self): for name, value in parent_values.items(): self.assertEqual(getattr(result.parent, name), value) - def test_includes_single_no_related_object_result(self): - select_query = self.includes('parent') - child_values, _, rows = self.includes_result(related=0) + def test_fk_includes_single_related_object_result(self): + select_query = self.includes('foreign_key_1', foreign_key_relation = True) + child_values, parent_values, rows = self.fk_includes_result(related=1) result = select_query.process_results(rows)[0] - self.assertIsNone(result.parent) + self.assertIsInstance(result.foreign_key_1, models.SmallTestModel) + for name, value in child_values.items(): + self.assertEqual(getattr(result, name), value) + + for name, value in parent_values.items(): + self.assertEqual(getattr(result.foreign_key_1, name), value) + + @parameterized.named_parameters( + ( + 'legacy_relationship', + {'relation': 'parent'}, + lambda x: x.parent, + lambda x: x.includes_result(related=0), + ), + ( + 'foreign_key_relationship', + {'relation': 'foreign_key_1', 'foreign_key_relation': True}, + lambda x: x.foreign_key_1, + lambda x: x.fk_includes_result(related=0), + ), + ) + def test_includes_single_no_related_object_result(self, includes_kwargs, x, y): + select_query = self.includes(**includes_kwargs) + child_values, _, rows = y(self) + result = select_query.process_results(rows)[0] + + self.assertIsNone(x(result)) for name, value in child_values.items(): self.assertEqual(getattr(result, name), value) @@ -288,21 +362,52 @@ def test_includes_subcondition_result(self): for name, value in parent_values.items(): self.assertEqual(getattr(result.parents[0], name), value) - def test_includes_error_on_multiple_results_for_single(self): - select_query = self.includes('parent') - _, _, rows = self.includes_result(related=2) + @parameterized.named_parameters( + ( + 'legacy_relationship', + {'relation': 'parent'}, + lambda x: x.includes_result(related=2), + ), + ( + 'foreign_key_relationship', + {'relation': 'foreign_key_1', 'foreign_key_relation': True}, + lambda x: x.fk_includes_result(related=2), + ), + ) + def test_includes_error_on_multiple_results_for_single( + self, includes_kwargs, x): + select_query = self.includes(**includes_kwargs) + _, _, rows = x(self) with self.assertRaises(error.SpannerError): _ = select_query.process_results(rows) - def test_includes_error_on_invalid_relation(self): + @parameterized.parameters(True, False) + def test_includes_error_on_invalid_relation( + self, foreign_key_relation): with self.assertRaises(error.ValidationError): - self.includes('bad_relation') + self.includes('bad_relation', foreign_key_relation=foreign_key_relation) - @parameterized.parameters(('bad_column', 0), ('child_key', 'good value'), - ('key', ['bad value'])) - def test_includes_error_on_invalid_subconditions(self, column, value): + @parameterized.parameters( + ('bad_column', 0, 'parent', False), + ('bad_column', 0, 'foreign_key_1', True), + ('child_key', 'good value', 'parent', False), + ('child_key', 'good value', 'foreign_key_1', False), + ('key', ['bad value'], 'parent', False), + ('key', ['bad value'], 'foreign_key_1', False), + ) + def test_includes_error_on_invalid_subconditions( + self, + column, + value, + relation, + foreign_key_relation + ): with self.assertRaises(error.ValidationError): - self.includes('parent', condition.equal_to(column, value)) + self.includes( + relation, + condition.equal_to(column, value), + foreign_key_relation, + ) def test_or(self): condition_1 = condition.equal_to('int_', 1) From 6f6df42ae4e494c9eb4d143ac2667d10559060b4 Mon Sep 17 00:00:00 2001 From: Daniel Gorelik Date: Mon, 5 Oct 2020 08:43:43 -0400 Subject: [PATCH 016/131] Add some type hints, and clean up style. --- spanner_orm/foreign_key_relationship.py | 33 ++++++++++++++----------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/spanner_orm/foreign_key_relationship.py b/spanner_orm/foreign_key_relationship.py index d03ebb0..f915418 100644 --- a/spanner_orm/foreign_key_relationship.py +++ b/spanner_orm/foreign_key_relationship.py @@ -14,7 +14,7 @@ # limitations under the License. """Helps define a foreign key relationship between two models.""" -from typing import Any, Mapping +from typing import Any, Mapping, Type import dataclasses from spanner_orm import registry @@ -24,7 +24,7 @@ class ForeignKeyRelationshipConstraint: columns: Mapping[str, str] referenced_table_name: str - referenced_table: Any + referenced_table: Type[Any] class ForeignKeyRelationship(object): @@ -32,7 +32,8 @@ class ForeignKeyRelationship(object): def __init__(self, referenced_table_name: str, - columns: Mapping[str, str]): + columns: Mapping[str, str], + single: bool = False): """Creates a ForeignKeyRelationship. Args: @@ -40,11 +41,14 @@ def __init__(self, columns: Dictionary where the keys are names of columns from the referencing table and the values are the names of the columns in the referenced table. + single: True if the referenced table should be treated as a single object + instead of a list of objects. """ self.origin = None self.name = None self._referenced_table_name = referenced_table_name self._columns = columns + self._single = single @property def constraint(self) -> ForeignKeyRelationshipConstraint: @@ -55,24 +59,25 @@ def ddl(self) -> str: referencing_columns_ddl = ', '.join(self.constraint.columns.keys()) referenced_columns_ddl = ', '.join(self.constraint.columns.values()) return ( - 'CONSTRAINT {fk_name} FOREIGN KEY ({referencing_columns}) REFERENCES' - ' {referenced_table} ({referenced_columns})').format( - fk_name=self.name, - referencing_columns=referencing_columns_ddl, - referenced_table=self.constraint.referenced_table_name, - referenced_columns=referenced_columns_ddl, - ) + 'CONSTRAINT {fk_name} FOREIGN KEY ({referencing_columns}) REFERENCES' + ' {referenced_table} ({referenced_columns})').format( + fk_name=self.name, + referencing_columns=referencing_columns_ddl, + referenced_table=self.constraint.referenced_table_name, + referenced_columns=referenced_columns_ddl, + ) def _parse_constraint(self) -> ForeignKeyRelationshipConstraint: """Return the relationship constraint.""" referenced_table = registry.model_registry().get( - self._referenced_table_name) + self._referenced_table_name) return ForeignKeyRelationshipConstraint( - self._columns, - referenced_table.table, - referenced_table, + self._columns, + referenced_table.table, + referenced_table, ) @property def single(self) -> bool: + # self._single return True From cb585727d7959dd2b4d5c494ad3b9307e6ba122c Mon Sep 17 00:00:00 2001 From: Sam Marder Date: Mon, 5 Oct 2020 14:38:43 -0400 Subject: [PATCH 017/131] Ignore pytype folder --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 3aed126..ca6c5cf 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ build __pycache__ .eggs *.egg-info +.pytype From 9dba955b725ba4e38b17b46cc55bbc03652f02f6 Mon Sep 17 00:00:00 2001 From: Sam Marder Date: Mon, 5 Oct 2020 14:47:06 -0400 Subject: [PATCH 018/131] Add test + pytype instructions to the README --- README.md | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/README.md b/README.md index db2eeb5..f01775c 100644 --- a/README.md +++ b/README.md @@ -176,3 +176,34 @@ multiple times, so try not to do that. If a migration needs to be rolled back, ```spanner-orm rollback ``` or the corresponding ```MigrationExecutor``` method should be used. + +## Tests + +Before running any tests, you'll need to download the Cloud Spanner Emulator. +See https://github.com/GoogleCloudPlatform/cloud-spanner-emulator for several +options. If you're on Linux, we recommend: + +``` +VERSION=1.0.0 +wget https://storage.googleapis.com/cloud-spanner-emulator/releases/${VERSION}/cloud-spanner-emulator_linux_amd64-${VERSION}.tar.gz +tar zxvf cloud-spanner-emulator_linux_amd64-${VERSION}.tar.gz +chmod u+x gateway_main emulator_main +``` + +``` +git clone git@github.com:GoogleCloudPlatform/cloud-spanner-emulator.git +``` + +To check type annotations, run: + +``` +pip install pytype +# https://github.com/google/pytype/issues/80#issuecomment-385128856 +pytype -V 3.7 spanner_orm -d import-error +``` + +Then run tests with: + +``` +SPANNER_EMULATOR_BINARY_PATH=$(pwd)/emulator_main python3 setup.py test +``` From 9402131c9ba10af269573619b5bc11359d5e9e39 Mon Sep 17 00:00:00 2001 From: Sam Marder Date: Mon, 5 Oct 2020 14:47:15 -0400 Subject: [PATCH 019/131] Ignore Spanner emulator files --- .gitignore | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.gitignore b/.gitignore index ca6c5cf..0e9dea3 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,9 @@ __pycache__ .eggs *.egg-info .pytype + +# Files that may or may not be added to the repo while acquiring the Spanner +# emulator. +cloud-spanner-emulator* +emulator_main +gateway_main From 0f3b46cad7e85d149c099fa8eaf62fae89bb09c3 Mon Sep 17 00:00:00 2001 From: Sam Marder Date: Mon, 5 Oct 2020 15:40:59 -0400 Subject: [PATCH 020/131] Re-establish connection if needed --- spanner_orm/api.py | 16 +++++++++++++--- spanner_orm/decorator.py | 9 ++++++++- spanner_orm/tests/decorator_test.py | 27 +++++++++++++++++++++++++++ 3 files changed, 48 insertions(+), 4 deletions(-) diff --git a/spanner_orm/api.py b/spanner_orm/api.py index ccb0a27..437a268 100644 --- a/spanner_orm/api.py +++ b/spanner_orm/api.py @@ -94,10 +94,20 @@ def __init__(self, pool: Optional[spanner_pool.AbstractSessionPool] = None, create_ddl: Optional[Iterable[str]] = None): """Connects to the specified Spanner database.""" - client = spanner.Client(project=project, credentials=credentials) - instance = client.instance(instance) + self._instance = instance + self._database = database + self._project = project + self._credentials = credentials + self._pool = pool + self._create_ddl = create_ddl + self.connect() + + def connect(self): + """Establish a new connection to the specified Spanner database.""" + client = spanner.Client(project=self._project, credentials=self._credentials) + instance = client.instance(self._instance) self.database = instance.database( - database, pool=pool, ddl_statements=create_ddl or ()) + self._database, pool=self._pool, ddl_statements=self._create_ddl or ()) class SpannerApi(SpannerReadApi, SpannerWriteApi): diff --git a/spanner_orm/decorator.py b/spanner_orm/decorator.py index bd2a352..5980101 100644 --- a/spanner_orm/decorator.py +++ b/spanner_orm/decorator.py @@ -16,6 +16,7 @@ from typing import Callable, TypeVar +from google.api_core import exceptions from spanner_orm import api T = TypeVar('T') @@ -103,6 +104,12 @@ def wrapper(*args, **kwargs) -> T: return func(*args, **kwargs) spanner_api_method = spanner_api_method_lambda() - return spanner_api_method(spanner_wrapper, *args, **kwargs) + try: + return spanner_api_method(spanner_wrapper, *args, **kwargs) + except exceptions.NotFound: + # https://cloud.google.com/spanner/docs/sessions#handle_deleted_sessions + # states that Spanner may delete existing sessions for various reasons. + api.spanner_api().connect() + return spanner_api_method(spanner_wrapper, *args, **kwargs) return wrapper diff --git a/spanner_orm/tests/decorator_test.py b/spanner_orm/tests/decorator_test.py index ba30d25..31bf23d 100644 --- a/spanner_orm/tests/decorator_test.py +++ b/spanner_orm/tests/decorator_test.py @@ -17,6 +17,7 @@ from unittest import mock from absl.testing import parameterized +from google.api_core import exceptions from spanner_orm import decorator @@ -64,6 +65,32 @@ def get_book(book_id, genre=None, transaction=None): self.assertEqual(200, result) + @parameterized.parameters( + (decorator.transactional_read, 'run_read_only'), + (decorator.transactional_write, 'run_write'), + ) + @mock.patch('spanner_orm.api.spanner_api') + def test_reconnect_on_expected_error(self, decorator_in_test, + method_name_to_mock, + mock_spanner_api): + mock_api_method = getattr(mock_spanner_api.return_value, + method_name_to_mock) + mock_api_method.side_effect = [ + exceptions.NotFound('404 Session not found'), + 'Anything other than an exception' + ] + mock_connect = mock_spanner_api.return_value.connect + + @decorator_in_test + def get_book(book_id, genre=None, transaction=None): + pass + + get_book(123, genre='horror') + + mock_connect.assert_called_once() + mock_api_method.assert_has_calls( + [mock.call(mock.ANY, 123, genre='horror')] * 2) + def mock_spanner_method(mock_transaction): From 0261536113c80c9fff9744075b0e9a1e888c8c9a Mon Sep 17 00:00:00 2001 From: Daniel Gorelik Date: Mon, 5 Oct 2020 16:48:22 -0400 Subject: [PATCH 021/131] Remove changes that will be added in follow-up PR --- spanner_orm/admin/update.py | 52 ------------------------------------- 1 file changed, 52 deletions(-) diff --git a/spanner_orm/admin/update.py b/spanner_orm/admin/update.py index 5f961cf..54f7ec8 100644 --- a/spanner_orm/admin/update.py +++ b/spanner_orm/admin/update.py @@ -318,58 +318,6 @@ def validate(self) -> None: if db_index.primary: raise error.SpannerError('Index {} is the primary index'.format( self._index)) - -class AddForeignKeyRelationship(SchemaUpdate): - """Update for adding a column to an existing table. - - Only supports adding nullable columns - """ - - def __init__( - self, - referencing_table_name: str, - referenced_table_name: str, - column_mapping, - ): - self._table = table_name - self._column = column_name - self._field = field_ - - def ddl(self) -> str: - return 'ALTER TABLE {} ADD'.format(self._table, self._column, - self._field.ddl()) - - def validate(self) -> None: - model_ = metadata.SpannerMetadata.model(self._table) - if not model_: - raise error.SpannerError('Table {} does not exist'.format(self._table)) - - -class DropForeignKeyRelationship(SchemaUpdate): - """Update for dropping a column from an existing table.""" - - def __init__(self, table_name: str, column_name: str): - self._table = table_name - self._column = column_name - - def ddl(self) -> str: - return 'ALTER TABLE {} DROP COLUMN {}'.format(self._table, self._column) - - def validate(self) -> None: - model_ = metadata.SpannerMetadata.model(self._table) - if not model_: - raise error.SpannerError('Table {} does not exist'.format(self._table)) - - if self._column not in model_.fields: - raise error.SpannerError('Column {} does not exist on {}'.format( - self._column, self._table)) - - # Verify no indices exist on the column we're trying to drop - num_indexed_columns = index_column.IndexColumnSchema.count( - None, condition.equal_to('column_name', self._column), - condition.equal_to('table_name', self._table)) - if num_indexed_columns > 0: - raise error.SpannerError('Column {} is indexed'.format(self._column)) class NoUpdate(SchemaUpdate): From b1bb9264767c893eba4069eb2b542b9e47acaf00 Mon Sep 17 00:00:00 2001 From: Daniel Gorelik Date: Mon, 5 Oct 2020 16:49:03 -0400 Subject: [PATCH 022/131] Clarify single must be set --- spanner_orm/foreign_key_relationship.py | 3 ++- spanner_orm/tests/query_test.py | 36 +++++++++++++------------ 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/spanner_orm/foreign_key_relationship.py b/spanner_orm/foreign_key_relationship.py index f915418..44b3244 100644 --- a/spanner_orm/foreign_key_relationship.py +++ b/spanner_orm/foreign_key_relationship.py @@ -79,5 +79,6 @@ def _parse_constraint(self) -> ForeignKeyRelationshipConstraint: @property def single(self) -> bool: - # self._single + # Spanner enforces uniqueness for values of fields referenced by + # foreign keys. return True diff --git a/spanner_orm/tests/query_test.py b/spanner_orm/tests/query_test.py index e476602..857edd8 100644 --- a/spanner_orm/tests/query_test.py +++ b/spanner_orm/tests/query_test.py @@ -301,30 +301,32 @@ def fk_includes_result(self, related=1): result.append(parents) return child, parent, [result] - - def test_includes_single_related_object_result(self): - select_query = self.includes('parent') - child_values, parent_values, rows = self.includes_result(related=1) - result = select_query.process_results(rows)[0] - - self.assertIsInstance(result.parent, models.SmallTestModel) - for name, value in child_values.items(): - self.assertEqual(getattr(result, name), value) - - for name, value in parent_values.items(): - self.assertEqual(getattr(result.parent, name), value) - def test_fk_includes_single_related_object_result(self): - select_query = self.includes('foreign_key_1', foreign_key_relation = True) - child_values, parent_values, rows = self.fk_includes_result(related=1) + @parameterized.named_parameters( + ( + 'legacy_relationship', + {'relation': 'parent'}, + lambda x: x.parent, + lambda x: x.includes_result(related=1), + ), + ( + 'foreign_key_relationship', + {'relation': 'foreign_key_1', 'foreign_key_relation': True}, + lambda x: x.foreign_key_1, + lambda x: x.fk_includes_result(related=1), + ), + ) + def test_includes_single_related_object_result(self, includes_kwargs, x, y): + select_query = self.includes(**includes_kwargs) + child_values, parent_values, rows = y(self) result = select_query.process_results(rows)[0] - self.assertIsInstance(result.foreign_key_1, models.SmallTestModel) + self.assertIsInstance(x(result), models.SmallTestModel) for name, value in child_values.items(): self.assertEqual(getattr(result, name), value) for name, value in parent_values.items(): - self.assertEqual(getattr(result.foreign_key_1, name), value) + self.assertEqual(getattr(x(result), name), value) @parameterized.named_parameters( ( From 9bcc90518f2c0197ee69e8aca9319a38319f50bb Mon Sep 17 00:00:00 2001 From: Daniel Gorelik Date: Mon, 5 Oct 2020 20:54:41 -0400 Subject: [PATCH 023/131] Suggest the use of virtualenv Suggest the use of virtualenv for local development, running tests, and type checking. --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index f01775c..bed79bd 100644 --- a/README.md +++ b/README.md @@ -179,6 +179,11 @@ or the corresponding ```MigrationExecutor``` method should be used. ## Tests +Note: we suggest using a Python 3.7 +[virtualenv](https://docs.python.org/3/library/venv.html) +for running tests and type checking. + + Before running any tests, you'll need to download the Cloud Spanner Emulator. See https://github.com/GoogleCloudPlatform/cloud-spanner-emulator for several options. If you're on Linux, we recommend: From 0be29e9166ba80aed7adfa947d22a695bd402128 Mon Sep 17 00:00:00 2001 From: Daniel Gorelik Date: Tue, 6 Oct 2020 12:45:54 -0400 Subject: [PATCH 024/131] Get pytype to pass --- spanner_orm/admin/migration_manager.py | 2 +- spanner_orm/model.py | 2 +- spanner_orm/query.py | 8 ++++---- spanner_orm/relationship.py | 3 +-- spanner_orm/table_apis.py | 6 +++--- spanner_orm/tests/metadata_test.py | 3 +++ 6 files changed, 13 insertions(+), 11 deletions(-) diff --git a/spanner_orm/admin/migration_manager.py b/spanner_orm/admin/migration_manager.py index da1a29f..620a53f 100644 --- a/spanner_orm/admin/migration_manager.py +++ b/spanner_orm/admin/migration_manager.py @@ -74,7 +74,7 @@ def _migration_from_file(self, filename: str) -> migration.Migration: path = os.path.join(self.basedir, filename) spec = importlib.util.spec_from_file_location(module_name, path) module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) + spec.loader.exec_module(module) # type: ignore try: result = migration.Migration(module.migration_id, module.prev_migration_id, diff --git a/spanner_orm/model.py b/spanner_orm/model.py index 06060cb..17442a2 100644 --- a/spanner_orm/model.py +++ b/spanner_orm/model.py @@ -307,7 +307,7 @@ def where_equal(cls, def _results_to_models(cls, results: Iterable[Iterable[Any]]) -> List['ModelObject']: items = [dict(zip(cls.columns, result)) for result in results] - return [cls(item, persisted=True) for item in items] + return [cls(item, persisted=True) for item in items] # type: ignore @classmethod def _execute_read(cls, db_api: Callable[..., CallableReturn], diff --git a/spanner_orm/query.py b/spanner_orm/query.py index 75c461e..8b4273c 100644 --- a/spanner_orm/query.py +++ b/spanner_orm/query.py @@ -15,7 +15,7 @@ """Helps build SQL for complex Spanner queries.""" import abc -from typing import Any, Dict, Iterable, List, Tuple, Type +from typing import Any, Dict, Iterable, List, Sequence, Tuple, Type from spanner_orm import condition from spanner_orm import error @@ -47,7 +47,7 @@ def types(self) -> Dict[str, Any]: return self._types @abc.abstractmethod - def process_results(self, results: List[List[Any]]) -> None: + def process_results(self, results: List[Sequence[Any]]) -> None: pass def _segments(self, @@ -148,7 +148,7 @@ def __init__(self, model: Type[Any], def _select(self) -> Tuple[str, Dict[str, Any], Dict[str, Any]]: return ('SELECT COUNT(*)', {}, {}) - def process_results(self, results: List[List[Any]]) -> int: + def process_results(self, results: List[Sequence[Any]]) -> int: return int(results[0][0]) @@ -186,7 +186,7 @@ def _select(self) -> Tuple[str, Dict[str, Any], Dict[str, Any]]: prefix=self._select_prefix(), columns=', '.join(columns)), parameters, types) - def process_results(self, results: List[List[Any]]) -> List[Type[Any]]: + def process_results(self, results: List[Sequence[Any]]) -> List[Type[Any]]: return [self._process_row(result) for result in results] def _process_row(self, row: List[Any]) -> Type[Any]: diff --git a/spanner_orm/relationship.py b/spanner_orm/relationship.py index d929b8f..d4947ef 100644 --- a/spanner_orm/relationship.py +++ b/spanner_orm/relationship.py @@ -88,9 +88,8 @@ def _parse_constraints(self) -> List[RelationshipConstraint]: raise error.ValidationError( 'Destination column must be present in destination model') - # TODO(dbrandao): remove when pytype #234 is fixed constraints.append( RelationshipConstraint(self.destination, destination_column, - self.origin, origin_column)) # type: ignore + self.origin, origin_column)) return constraints diff --git a/spanner_orm/table_apis.py b/spanner_orm/table_apis.py index 3ee4625..0b64df3 100644 --- a/spanner_orm/table_apis.py +++ b/spanner_orm/table_apis.py @@ -15,7 +15,7 @@ """Table-level API lambdas for Spanner transactions.""" import logging -from typing import Any, Dict, Iterable, List +from typing import Any, Dict, Iterable, List, Sequence from google.cloud import spanner from google.cloud.spanner_v1 import transaction as spanner_transaction @@ -26,7 +26,7 @@ # Read methods def find(transaction: spanner_transaction.Transaction, table_name: str, - columns: Iterable[str], keyset: spanner.KeySet) -> List[Iterable[Any]]: + columns: Iterable[str], keyset: spanner.KeySet) -> List[Sequence[Any]]: """Retrieves rows from the given table based on the provided KeySet. Args: @@ -51,7 +51,7 @@ def find(transaction: spanner_transaction.Transaction, table_name: str, def sql_query(transaction: spanner_transaction.Transaction, query: str, parameters: Dict[str, Any], - parameter_types: Dict[str, type_pb2.Type]) -> List[Iterable[Any]]: + parameter_types: Dict[str, type_pb2.Type]) -> List[Sequence[Any]]: """Executes a given SQL query against the Spanner database. This isn't technically read-only, but it's necessary to implement the read- diff --git a/spanner_orm/tests/metadata_test.py b/spanner_orm/tests/metadata_test.py index 8d6b46f..df1259f 100644 --- a/spanner_orm/tests/metadata_test.py +++ b/spanner_orm/tests/metadata_test.py @@ -12,6 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +# type: ignore + import logging import unittest From 93e39cb6fe50b068773442bc2c78ccf9da744e45 Mon Sep 17 00:00:00 2001 From: Daniel Gorelik Date: Tue, 6 Oct 2020 14:05:09 -0400 Subject: [PATCH 025/131] Add TODOs --- spanner_orm/admin/migration_manager.py | 1 + spanner_orm/model.py | 1 + spanner_orm/tests/metadata_test.py | 1 + 3 files changed, 3 insertions(+) diff --git a/spanner_orm/admin/migration_manager.py b/spanner_orm/admin/migration_manager.py index 620a53f..8f68806 100644 --- a/spanner_orm/admin/migration_manager.py +++ b/spanner_orm/admin/migration_manager.py @@ -74,6 +74,7 @@ def _migration_from_file(self, filename: str) -> migration.Migration: path = os.path.join(self.basedir, filename) spec = importlib.util.spec_from_file_location(module_name, path) module = importlib.util.module_from_spec(spec) + # TODO(#93): Remove pytype disable below. spec.loader.exec_module(module) # type: ignore try: result = migration.Migration(module.migration_id, diff --git a/spanner_orm/model.py b/spanner_orm/model.py index 17442a2..c3997a7 100644 --- a/spanner_orm/model.py +++ b/spanner_orm/model.py @@ -307,6 +307,7 @@ def where_equal(cls, def _results_to_models(cls, results: Iterable[Iterable[Any]]) -> List['ModelObject']: items = [dict(zip(cls.columns, result)) for result in results] + # TODO(#93): Remove pytype disable below. return [cls(item, persisted=True) for item in items] # type: ignore @classmethod diff --git a/spanner_orm/tests/metadata_test.py b/spanner_orm/tests/metadata_test.py index df1259f..308789d 100644 --- a/spanner_orm/tests/metadata_test.py +++ b/spanner_orm/tests/metadata_test.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +# TODO(#93): Remove pytype disable below. # type: ignore import logging From 3af5eca9a82de6de607bf32c0aa81f86c8658afe Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Wed, 7 Oct 2020 18:24:57 -0400 Subject: [PATCH 026/131] Fix pytype error when dynamically importing a python file. See https://github.com/google/pytype/issues/319 for context on the error. Addresses #93. --- spanner_orm/admin/migration_manager.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/spanner_orm/admin/migration_manager.py b/spanner_orm/admin/migration_manager.py index 8f68806..c7f9fde 100644 --- a/spanner_orm/admin/migration_manager.py +++ b/spanner_orm/admin/migration_manager.py @@ -72,10 +72,9 @@ def _migration_from_file(self, filename: str) -> migration.Migration: """Loads a single migration from the given filename in the base dir.""" module_name = re.sub(r'\W', '_', filename) path = os.path.join(self.basedir, filename) - spec = importlib.util.spec_from_file_location(module_name, path) - module = importlib.util.module_from_spec(spec) - # TODO(#93): Remove pytype disable below. - spec.loader.exec_module(module) # type: ignore + module = importlib.util.module_from_spec( + importlib.util.spec_from_file_location(module_name, path)) + importlib.machinery.SourceFileLoader(module_name, path).exec_module(module) try: result = migration.Migration(module.migration_id, module.prev_migration_id, From 08223d3ec39d542653136cda49c4f72813b36af3 Mon Sep 17 00:00:00 2001 From: Sam Marder Date: Thu, 8 Oct 2020 09:54:00 -0400 Subject: [PATCH 027/131] Only catch Session Not Found errors --- spanner_orm/decorator.py | 5 ++++- spanner_orm/tests/decorator_test.py | 19 +++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/spanner_orm/decorator.py b/spanner_orm/decorator.py index 5980101..e0ed499 100644 --- a/spanner_orm/decorator.py +++ b/spanner_orm/decorator.py @@ -106,9 +106,12 @@ def wrapper(*args, **kwargs) -> T: spanner_api_method = spanner_api_method_lambda() try: return spanner_api_method(spanner_wrapper, *args, **kwargs) - except exceptions.NotFound: + except exceptions.NotFound as e: # https://cloud.google.com/spanner/docs/sessions#handle_deleted_sessions # states that Spanner may delete existing sessions for various reasons. + if not 'Session not found' in e.message: + raise + api.spanner_api().connect() return spanner_api_method(spanner_wrapper, *args, **kwargs) diff --git a/spanner_orm/tests/decorator_test.py b/spanner_orm/tests/decorator_test.py index 31bf23d..448c6f1 100644 --- a/spanner_orm/tests/decorator_test.py +++ b/spanner_orm/tests/decorator_test.py @@ -91,6 +91,25 @@ def get_book(book_id, genre=None, transaction=None): mock_api_method.assert_has_calls( [mock.call(mock.ANY, 123, genre='horror')] * 2) + @parameterized.parameters( + (decorator.transactional_read, 'run_read_only'), + (decorator.transactional_write, 'run_write'), + ) + @mock.patch('spanner_orm.api.spanner_api') + def test_raise_on_expected_error(self, decorator_in_test, + method_name_to_mock, + mock_spanner_api): + mock_api_method = getattr(mock_spanner_api.return_value, + method_name_to_mock) + mock_api_method.side_effect = exceptions.NotFound('404 Database not found') + + @decorator_in_test + def get_book(book_id, genre=None, transaction=None): + pass + + with self.assertRaises(exceptions.NotFound): + get_book(123, genre='horror') + def mock_spanner_method(mock_transaction): From de5f32f49a16f6581b0663e696e42f525697cbbf Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Thu, 8 Oct 2020 13:55:16 -0400 Subject: [PATCH 028/131] Add trailing comma to some parameter lists. I'm going to be moving these around and changing type annotations in a future commit; this change should make it easier to see those changes. --- spanner_orm/model.py | 113 ++++++++++++++++++++++++++++--------------- 1 file changed, 73 insertions(+), 40 deletions(-) diff --git a/spanner_orm/model.py b/spanner_orm/model.py index c3997a7..1758863 100644 --- a/spanner_orm/model.py +++ b/spanner_orm/model.py @@ -148,7 +148,7 @@ def spanner_api(cls) -> api.SpannerApi: @classmethod def all( cls, - transaction: Optional[spanner_transaction.Transaction] = None + transaction: Optional[spanner_transaction.Transaction] = None, ) -> List['ModelObject']: """Returns all objects of this type stored in Spanner. @@ -168,8 +168,11 @@ def all( return cls._results_to_models(results) @classmethod - def count(cls, transaction: Optional[spanner_transaction.Transaction], - *conditions: condition.Condition) -> int: + def count( + cls, + transaction: Optional[spanner_transaction.Transaction], + *conditions: condition.Condition, + ) -> int: """Returns the number of objects in Spanner that match the given conditions. Args: @@ -188,9 +191,11 @@ def count(cls, transaction: Optional[spanner_transaction.Transaction], return builder.process_results(results) @classmethod - def count_equal(cls, - transaction: Optional[spanner_transaction.Transaction] = None, - **constraints: Any) -> int: + def count_equal( + cls, + transaction: Optional[spanner_transaction.Transaction] = None, + **constraints: Any, + ) -> int: """Returns the number of objects in Spanner that match the given conditions. Convenience method that generates equality conditions based on the keyword @@ -215,9 +220,11 @@ def count_equal(cls, return cls.count(transaction, *conditions) @classmethod - def find(cls, - transaction: Optional[spanner_transaction.Transaction] = None, - **keys: Any) -> Optional['ModelObject']: + def find( + cls, + transaction: Optional[spanner_transaction.Transaction] = None, + **keys: Any, + ) -> Optional['ModelObject']: """Retrieves an object from Spanner based on the provided key. Args: @@ -234,8 +241,11 @@ def find(cls, return resources[0] if resources else None @classmethod - def find_multi(cls, transaction: Optional[spanner_transaction.Transaction], - keys: Iterable[Dict[str, Any]]) -> List['ModelObject']: + def find_multi( + cls, + transaction: Optional[spanner_transaction.Transaction], + keys: Iterable[Dict[str, Any]], + ) -> List['ModelObject']: """Retrieves objects from Spanner based on the provided keys. Args: @@ -259,8 +269,11 @@ def find_multi(cls, transaction: Optional[spanner_transaction.Transaction], return cls._results_to_models(results) @classmethod - def where(cls, transaction: Optional[spanner_transaction.Transaction], - *conditions: condition.Condition) -> List['ModelObject']: + def where( + cls, + transaction: Optional[spanner_transaction.Transaction], + *conditions: condition.Condition, + ) -> List['ModelObject']: """Retrieves objects from Spanner based on the provided conditions. Args: @@ -279,9 +292,11 @@ def where(cls, transaction: Optional[spanner_transaction.Transaction], return builder.process_results(results) @classmethod - def where_equal(cls, - transaction: Optional[spanner_transaction.Transaction] = None, - **constraints: Any) -> List['ModelObject']: + def where_equal( + cls, + transaction: Optional[spanner_transaction.Transaction] = None, + **constraints: Any, + ) -> List['ModelObject']: """Retrieves objects from Spanner based on the provided constraints. Args: @@ -304,16 +319,21 @@ def where_equal(cls, return cls.where(transaction, *conditions) @classmethod - def _results_to_models(cls, - results: Iterable[Iterable[Any]]) -> List['ModelObject']: + def _results_to_models( + cls, + results: Iterable[Iterable[Any]], + ) -> List['ModelObject']: items = [dict(zip(cls.columns, result)) for result in results] # TODO(#93): Remove pytype disable below. return [cls(item, persisted=True) for item in items] # type: ignore @classmethod - def _execute_read(cls, db_api: Callable[..., CallableReturn], - transaction: Optional[spanner_transaction.Transaction], - args: List[Any]) -> CallableReturn: + def _execute_read( + cls, + db_api: Callable[..., CallableReturn], + transaction: Optional[spanner_transaction.Transaction], + args: List[Any], + ) -> CallableReturn: if transaction is not None: return db_api(transaction, *args) else: @@ -321,9 +341,11 @@ def _execute_read(cls, db_api: Callable[..., CallableReturn], # Table write methods @classmethod - def create(cls, - transaction: Optional[spanner_transaction.Transaction] = None, - **kwargs: Any) -> None: + def create( + cls, + transaction: Optional[spanner_transaction.Transaction] = None, + **kwargs: Any, + ) -> None: """Creates a row in Spanner based on the provided data. Note: may throw an exception if bad values are provided. @@ -338,15 +360,19 @@ def create(cls, cls._execute_write(table_apis.insert, transaction, [kwargs]) @classmethod - def create_or_update(cls, - transaction: Optional[ - spanner_transaction.Transaction] = None, - **kwargs: Any) -> None: + def create_or_update( + cls, + transaction: Optional[spanner_transaction.Transaction] = None, + **kwargs: Any, + ) -> None: cls._execute_write(table_apis.upsert, transaction, [kwargs]) @classmethod - def delete_batch(cls, transaction: Optional[spanner_transaction.Transaction], - models: List['ModelObject']) -> None: + def delete_batch( + cls, + transaction: Optional[spanner_transaction.Transaction], + models: List['ModelObject'], + ) -> None: """Deletes rows from Spanner based on the provided models' primary keys. Args: @@ -367,10 +393,12 @@ def delete_batch(cls, transaction: Optional[spanner_transaction.Transaction], cls.spanner_api().run_write(db_api, *args) @classmethod - def save_batch(cls, - transaction: Optional[spanner_transaction.Transaction], - models: List['ModelObject'], - force_write: bool = False) -> None: + def save_batch( + cls, + transaction: Optional[spanner_transaction.Transaction], + models: List['ModelObject'], + force_write: bool = False, + ) -> None: """Writes rows to Spanner based on the provided model data. Args: @@ -400,9 +428,11 @@ def save_batch(cls, cls._execute_write(api_method, transaction, values) @classmethod - def update(cls, - transaction: Optional[spanner_transaction.Transaction] = None, - **kwargs: Any) -> None: + def update( + cls, + transaction: Optional[spanner_transaction.Transaction] = None, + **kwargs: Any, + ) -> None: """Updates a row in Spanner based on the provided data. Args: @@ -416,9 +446,12 @@ def update(cls, cls._execute_write(table_apis.update, transaction, [kwargs]) @classmethod - def _execute_write(cls, db_api: Callable[..., Any], - transaction: Optional[spanner_transaction.Transaction], - dictionaries: Iterable[Dict[str, Any]]) -> None: + def _execute_write( + cls, + db_api: Callable[..., Any], + transaction: Optional[spanner_transaction.Transaction], + dictionaries: Iterable[Dict[str, Any]], + ) -> None: """Validates all write value types and commits write to Spanner.""" columns, values = None, [] for dictionary in dictionaries: From d955f0b3576411bc5f19e89c6a50371268ef6f51 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Thu, 8 Oct 2020 16:37:37 -0400 Subject: [PATCH 029/131] Make pytype infer the correct instance attribute type. Fields have class attributes of type Field, but instance attributes of the specific field type (int, str, etc.) This commit is a hack to get instance attributes to play nicely with pytype. See #97 for a better long-term solution. Addresses #93. --- spanner_orm/admin/column.py | 15 ++++++++------- spanner_orm/admin/index.py | 22 +++++++++++++--------- spanner_orm/admin/index_column.py | 23 ++++++++++++++--------- spanner_orm/admin/migration_status.py | 9 ++++++--- spanner_orm/admin/table.py | 15 ++++++++++----- 5 files changed, 51 insertions(+), 33 deletions(-) diff --git a/spanner_orm/admin/column.py b/spanner_orm/admin/column.py index b559d4e..167be6c 100644 --- a/spanner_orm/admin/column.py +++ b/spanner_orm/admin/column.py @@ -14,6 +14,7 @@ # limitations under the License. """Model for interacting with Spanner column schema table.""" +import typing from typing import Type from spanner_orm import error @@ -25,13 +26,13 @@ class ColumnSchema(schema.InformationSchema): """Model for interacting with Spanner column schema table.""" __table__ = 'information_schema.columns' - table_catalog = field.Field(field.String, primary_key=True) - table_schema = field.Field(field.String, primary_key=True) - table_name = field.Field(field.String, primary_key=True) - column_name = field.Field(field.String, primary_key=True) - ordinal_position = field.Field(field.Integer) - is_nullable = field.Field(field.String) - spanner_type = field.Field(field.String) + table_catalog = typing.cast(str, field.Field(field.String, primary_key=True)) + table_schema = typing.cast(str, field.Field(field.String, primary_key=True)) + table_name = typing.cast(str, field.Field(field.String, primary_key=True)) + column_name = typing.cast(str, field.Field(field.String, primary_key=True)) + ordinal_position = typing.cast(int, field.Field(field.Integer)) + is_nullable = typing.cast(str, field.Field(field.String)) + spanner_type = typing.cast(str, field.Field(field.String)) def nullable(self) -> bool: return self.is_nullable == 'YES' diff --git a/spanner_orm/admin/index.py b/spanner_orm/admin/index.py index 3ce6166..95f4804 100644 --- a/spanner_orm/admin/index.py +++ b/spanner_orm/admin/index.py @@ -14,6 +14,9 @@ # limitations under the License. """Model for interacting with Spanner index schema table.""" +import typing +from typing import Optional + from spanner_orm import field from spanner_orm.admin import schema @@ -22,12 +25,13 @@ class IndexSchema(schema.InformationSchema): """Model for interacting with Spanner index schema table.""" __table__ = 'information_schema.indexes' - table_catalog = field.Field(field.String, primary_key=True) - table_schema = field.Field(field.String, primary_key=True) - table_name = field.Field(field.String, primary_key=True) - index_name = field.Field(field.String, primary_key=True) - index_type = field.Field(field.String) - parent_table_name = field.Field(field.String, nullable=True) - is_unique = field.Field(field.Boolean) - is_null_filtered = field.Field(field.Boolean) - index_state = field.Field(field.String) + table_catalog = typing.cast(str, field.Field(field.String, primary_key=True)) + table_schema = typing.cast(str, field.Field(field.String, primary_key=True)) + table_name = typing.cast(str, field.Field(field.String, primary_key=True)) + index_name = typing.cast(str, field.Field(field.String, primary_key=True)) + index_type = typing.cast(str, field.Field(field.String)) + parent_table_name = typing.cast(Optional[str], + field.Field(field.String, nullable=True)) + is_unique = typing.cast(bool, field.Field(field.Boolean)) + is_null_filtered = typing.cast(bool, field.Field(field.Boolean)) + index_state = typing.cast(str, field.Field(field.String)) diff --git a/spanner_orm/admin/index_column.py b/spanner_orm/admin/index_column.py index d148639..210be13 100644 --- a/spanner_orm/admin/index_column.py +++ b/spanner_orm/admin/index_column.py @@ -14,6 +14,9 @@ # limitations under the License. """Model for interacting with Spanner index column schema table.""" +import typing +from typing import Optional + from spanner_orm import field from spanner_orm.admin import schema @@ -22,12 +25,14 @@ class IndexColumnSchema(schema.InformationSchema): """Model for interacting with Spanner index column schema table.""" __table__ = 'information_schema.index_columns' - table_catalog = field.Field(field.String, primary_key=True) - table_schema = field.Field(field.String, primary_key=True) - table_name = field.Field(field.String, primary_key=True) - index_name = field.Field(field.String, primary_key=True) - column_name = field.Field(field.String, primary_key=True) - ordinal_position = field.Field(field.Integer, nullable=True) - column_ordering = field.Field(field.String, nullable=True) - is_nullable = field.Field(field.String) - spanner_type = field.Field(field.String) + table_catalog = typing.cast(str, field.Field(field.String, primary_key=True)) + table_schema = typing.cast(str, field.Field(field.String, primary_key=True)) + table_name = typing.cast(str, field.Field(field.String, primary_key=True)) + index_name = typing.cast(str, field.Field(field.String, primary_key=True)) + column_name = typing.cast(str, field.Field(field.String, primary_key=True)) + ordinal_position = typing.cast(Optional[int], + field.Field(field.Integer, nullable=True)) + column_ordering = typing.cast(Optional[str], + field.Field(field.String, nullable=True)) + is_nullable = typing.cast(str, field.Field(field.String)) + spanner_type = typing.cast(str, field.Field(field.String)) diff --git a/spanner_orm/admin/migration_status.py b/spanner_orm/admin/migration_status.py index 576593c..cc9eb7b 100644 --- a/spanner_orm/admin/migration_status.py +++ b/spanner_orm/admin/migration_status.py @@ -14,6 +14,9 @@ # limitations under the License. """Indicates whether a migration has been applied to the current database.""" +import datetime +import typing + from spanner_orm import field from spanner_orm import model from spanner_orm.admin import api @@ -26,6 +29,6 @@ def spanner_api(cls) -> api.SpannerAdminApi: return api.spanner_admin_api() __table__ = 'spanner_orm_migrations' - id = field.Field(field.String, primary_key=True) - migrated = field.Field(field.Boolean) - update_time = field.Field(field.Timestamp) + id = typing.cast(str, field.Field(field.String, primary_key=True)) + migrated = typing.cast(bool, field.Field(field.Boolean)) + update_time = typing.cast(datetime.datetime, field.Field(field.Timestamp)) diff --git a/spanner_orm/admin/table.py b/spanner_orm/admin/table.py index d718e33..c07a1df 100644 --- a/spanner_orm/admin/table.py +++ b/spanner_orm/admin/table.py @@ -14,6 +14,9 @@ # limitations under the License. """Model for interacting with Spanner column schema table.""" +import typing +from typing import Optional + from spanner_orm import field from spanner_orm.admin import schema @@ -22,8 +25,10 @@ class TableSchema(schema.InformationSchema): """Model for interacting with Spanner column schema table.""" __table__ = 'information_schema.tables' - table_catalog = field.Field(field.String, primary_key=True) - table_schema = field.Field(field.String, primary_key=True) - table_name = field.Field(field.String, primary_key=True) - parent_table_name = field.Field(field.String, nullable=True) - on_delete_action = field.Field(field.String, nullable=True) + table_catalog = typing.cast(str, field.Field(field.String, primary_key=True)) + table_schema = typing.cast(str, field.Field(field.String, primary_key=True)) + table_name = typing.cast(str, field.Field(field.String, primary_key=True)) + parent_table_name = typing.cast(Optional[str], + field.Field(field.String, nullable=True)) + on_delete_action = typing.cast(Optional[str], + field.Field(field.String, nullable=True)) From c9897d0d074fe94f8c6d63890f2e95d90b24b258 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Thu, 8 Oct 2020 17:37:10 -0400 Subject: [PATCH 030/131] Give all ModelMetaclass classes a `meta` attribute. This fixes bugs where pytype got the type for `meta` wrong because it only existed on some classes. I'm not 100% sure adding a default (empty) metadata to classes that didn't previously have it makes sense, but tests pass and it seems likely to be safe to me. This change also has the side effect of parsing Model-specific attributes from the Model class itself (as opposed to just subclasses of Model), but it doesn't look like Model has any attributes that would be affected. Addresses #93. Tested: Ran tests and pytype. --- spanner_orm/model.py | 3 +-- spanner_orm/tests/metadata_test.py | 3 --- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/spanner_orm/model.py b/spanner_orm/model.py index 1758863..7b04800 100644 --- a/spanner_orm/model.py +++ b/spanner_orm/model.py @@ -35,11 +35,10 @@ class ModelMetaclass(type): """Populates ModelMetadata based on class attributes.""" + meta: metadata.ModelMetadata def __new__(mcs, name: str, bases: Any, attrs: Dict[str, Any], **kwargs: Any): parents = [base for base in bases if isinstance(base, ModelMetaclass)] - if not parents: - return super().__new__(mcs, name, bases, attrs, **kwargs) model_metadata = metadata.ModelMetadata() for parent in parents: diff --git a/spanner_orm/tests/metadata_test.py b/spanner_orm/tests/metadata_test.py index 308789d..16b43d7 100644 --- a/spanner_orm/tests/metadata_test.py +++ b/spanner_orm/tests/metadata_test.py @@ -13,9 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# TODO(#93): Remove pytype disable below. -# type: ignore - import logging import unittest From 6b819d880552165ca1f948fe3d1182a99d059fc0 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Thu, 8 Oct 2020 13:52:11 -0400 Subject: [PATCH 031/131] Merge ModelApi into Model. ModelApi was mostly a bunch of named constructors for Model, but there isn't a particularly good way to handle type annotations for named constructors that apply only to a subclass rather than the class they're defined in. Addresses #93. --- spanner_orm/model.py | 105 ++++++++++---------- spanner_orm/tests/model_api_test.py | 144 ---------------------------- spanner_orm/tests/model_test.py | 133 +++++++++++++++++++++++-- 3 files changed, 175 insertions(+), 207 deletions(-) delete mode 100644 spanner_orm/tests/model_api_test.py diff --git a/spanner_orm/model.py b/spanner_orm/model.py index 7b04800..dfe789c 100644 --- a/spanner_orm/model.py +++ b/spanner_orm/model.py @@ -32,6 +32,8 @@ from google.cloud import spanner from google.cloud.spanner_v1 import transaction as spanner_transaction +T = TypeVar('T') + class ModelMetaclass(type): """Populates ModelMetadata based on class attributes.""" @@ -129,14 +131,39 @@ def validate_value(cls, field_name, value, error_type=error.SpannerError): CallableReturn = TypeVar('CallableReturn') -class ModelApi(metaclass=ModelMetaclass): - """Implements class-level Spanner queries on top of ModelMetaclass. +class Model(metaclass=ModelMetaclass): + """Maps to a table in spanner and has basic functions for querying tables. - Note: all methods in this class should only be called on subclasses of Model - that have associated tables. Violating this will cause an exception to be - raised. + Note: all methods in this class should only be called on subclasses that have + associated tables. Violating this will cause an exception to be raised. """ + def __init__(self, values: Dict[str, Any], persisted: bool = False): + start_values = {} + self.__dict__['start_values'] = start_values + self.__dict__['_persisted'] = persisted + + # If the values came from Spanner, trust them and skip validation + if not persisted: + # An object is invalid if primary key values are missing + missing_keys = set(self._primary_keys) - set(values.keys()) + if missing_keys: + raise error.SpannerError( + 'All primary keys must be specified. Missing: {keys}'.format( + keys=missing_keys)) + + for column in self._columns: + self._metaclass.validate_value(column, values.get(column), ValueError) + + for column in self._columns: + value = values.get(column) + start_values[column] = copy.copy(value) + self.__dict__[column] = value + + for relation in self._relations: + if relation in values: + self.__dict__[relation] = values[relation] + @classmethod def spanner_api(cls) -> api.SpannerApi: if not cls.table: @@ -146,9 +173,9 @@ def spanner_api(cls) -> api.SpannerApi: # Table read methods @classmethod def all( - cls, + cls: Type[T], transaction: Optional[spanner_transaction.Transaction] = None, - ) -> List['ModelObject']: + ) -> List[T]: """Returns all objects of this type stored in Spanner. Note: this method should only be called on subclasses of Model that have @@ -220,10 +247,10 @@ def count_equal( @classmethod def find( - cls, + cls: Type[T], transaction: Optional[spanner_transaction.Transaction] = None, **keys: Any, - ) -> Optional['ModelObject']: + ) -> Optional[T]: """Retrieves an object from Spanner based on the provided key. Args: @@ -241,10 +268,10 @@ def find( @classmethod def find_multi( - cls, + cls: Type[T], transaction: Optional[spanner_transaction.Transaction], keys: Iterable[Dict[str, Any]], - ) -> List['ModelObject']: + ) -> List[T]: """Retrieves objects from Spanner based on the provided keys. Args: @@ -269,10 +296,10 @@ def find_multi( @classmethod def where( - cls, + cls: Type[T], transaction: Optional[spanner_transaction.Transaction], *conditions: condition.Condition, - ) -> List['ModelObject']: + ) -> List[T]: """Retrieves objects from Spanner based on the provided conditions. Args: @@ -292,10 +319,10 @@ def where( @classmethod def where_equal( - cls, + cls: Type[T], transaction: Optional[spanner_transaction.Transaction] = None, **constraints: Any, - ) -> List['ModelObject']: + ) -> List[T]: """Retrieves objects from Spanner based on the provided constraints. Args: @@ -319,12 +346,11 @@ def where_equal( @classmethod def _results_to_models( - cls, + cls: Type[T], results: Iterable[Iterable[Any]], - ) -> List['ModelObject']: + ) -> List[T]: items = [dict(zip(cls.columns, result)) for result in results] - # TODO(#93): Remove pytype disable below. - return [cls(item, persisted=True) for item in items] # type: ignore + return [cls(item, persisted=True) for item in items] @classmethod def _execute_read( @@ -368,9 +394,9 @@ def create_or_update( @classmethod def delete_batch( - cls, + cls: Type[T], transaction: Optional[spanner_transaction.Transaction], - models: List['ModelObject'], + models: List[T], ) -> None: """Deletes rows from Spanner based on the provided models' primary keys. @@ -393,9 +419,9 @@ def delete_batch( @classmethod def save_batch( - cls, + cls: Type[T], transaction: Optional[spanner_transaction.Transaction], - models: List['ModelObject'], + models: List[T], force_write: bool = False, ) -> None: """Writes rows to Spanner based on the provided model data. @@ -475,36 +501,6 @@ def _execute_write( else: return cls.spanner_api().run_write(db_api, *args) - -class Model(ModelApi): - """Maps to a table in spanner and has basic functions for querying tables.""" - - def __init__(self, values: Dict[str, Any], persisted: bool = False): - start_values = {} - self.__dict__['start_values'] = start_values - self.__dict__['_persisted'] = persisted - - # If the values came from Spanner, trust them and skip validation - if not persisted: - # An object is invalid if primary key values are missing - missing_keys = set(self._primary_keys) - set(values.keys()) - if missing_keys: - raise error.SpannerError( - 'All primary keys must be specified. Missing: {keys}'.format( - keys=missing_keys)) - - for column in self._columns: - self._metaclass.validate_value(column, values.get(column), ValueError) - - for column in self._columns: - value = values.get(column) - start_values[column] = copy.copy(value) - self.__dict__[column] = value - - for relation in self._relations: - if relation in values: - self.__dict__[relation] = values[relation] - def __setattr__(self, name: str, value: Any) -> None: if name in self._relations: raise AttributeError(name) @@ -640,6 +636,3 @@ def save(self, self._metaclass.create(transaction, **self.values) self._persisted = True return self - - -ModelObject = TypeVar('ModelObject', bound=Model) diff --git a/spanner_orm/tests/model_api_test.py b/spanner_orm/tests/model_api_test.py deleted file mode 100644 index ff98a2e..0000000 --- a/spanner_orm/tests/model_api_test.py +++ /dev/null @@ -1,144 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -import unittest -from unittest import mock - -from spanner_orm import error -from spanner_orm.tests import models - - -class ModelApiTest(unittest.TestCase): - - @mock.patch('spanner_orm.table_apis.find') - def test_find_calls_api(self, find): - mock_transaction = mock.Mock() - models.UnittestModel.find(mock_transaction, string='string', int_=1, float_=2.3) - - find.assert_called_once() - (transaction, table, columns, keyset), _ = find.call_args - self.assertEqual(transaction, mock_transaction) - self.assertEqual(table, models.UnittestModel.table) - self.assertEqual(columns, models.UnittestModel.columns) - self.assertEqual(keyset.keys, [[1, 2.3, 'string']]) - - @mock.patch('spanner_orm.table_apis.find') - def test_find_result(self, find): - mock_transaction = mock.Mock() - - find.return_value = [['key', 'value_1', None]] - result = models.SmallTestModel.find(mock_transaction, key='key') - if result: - self.assertEqual(result.key, 'key') - self.assertEqual(result.value_1, 'value_1') - self.assertIsNone(result.value_2) - else: - self.fail('Failed to find result') - - @mock.patch('spanner_orm.table_apis.find') - def test_find_multi_calls_api(self, find): - mock_transaction = mock.Mock() - models.UnittestModel.find_multi(mock_transaction, [{ - 'string': 'string', - 'int_': 1, - 'float_': 2.3 - }]) - - find.assert_called_once() - (transaction, table, columns, keyset), _ = find.call_args - self.assertEqual(transaction, mock_transaction) - self.assertEqual(table, models.UnittestModel.table) - self.assertEqual(columns, models.UnittestModel.columns) - self.assertEqual(keyset.keys, [[1, 2.3, 'string']]) - - @mock.patch('spanner_orm.table_apis.find') - def test_find_multi_result(self, find): - mock_transaction = mock.Mock() - find.return_value = [['key', 'value_1', None]] - results = models.SmallTestModel.find_multi(mock_transaction, [{ - 'key': 'key' - }]) - - self.assertEqual(results[0].key, 'key') - self.assertEqual(results[0].value_1, 'value_1') - self.assertIsNone(results[0].value_2) - - @mock.patch('spanner_orm.table_apis.insert') - def test_create_calls_api(self, insert): - mock_transaction = mock.Mock() - models.SmallTestModel.create(mock_transaction, key='key', value_1='value') - - insert.assert_called_once() - (transaction, table, columns, values), _ = insert.call_args - self.assertEqual(transaction, mock_transaction) - self.assertEqual(table, models.SmallTestModel.table) - self.assertEqual(list(columns), ['key', 'value_1']) - self.assertEqual(list(values), [['key', 'value']]) - - def test_create_error_on_invalid_keys(self): - with self.assertRaises(error.SpannerError): - models.SmallTestModel.create(key_2='key') - - def assert_api_called(self, mock_api, mock_transaction): - mock_api.assert_called_once() - (transaction, table, columns, values), _ = mock_api.call_args - self.assertEqual(transaction, mock_transaction) - self.assertEqual(table, models.SmallTestModel.table) - self.assertEqual(list(columns), ['key', 'value_1', 'value_2']) - self.assertEqual(list(values), [['key', 'value', None]]) - - @mock.patch('spanner_orm.table_apis.insert') - def test_save_batch_inserts(self, insert): - mock_transaction = mock.Mock() - values = {'key': 'key', 'value_1': 'value'} - not_persisted = models.SmallTestModel(values) - models.SmallTestModel.save_batch(mock_transaction, [not_persisted]) - self.assert_api_called(insert, mock_transaction) - - @mock.patch('spanner_orm.table_apis.update') - def test_save_batch_updates(self, update): - mock_transaction = mock.Mock() - values = {'key': 'key', 'value_1': 'value'} - persisted = models.SmallTestModel(values, persisted=True) - models.SmallTestModel.save_batch(mock_transaction, [persisted]) - - self.assert_api_called(update, mock_transaction) - - @mock.patch('spanner_orm.table_apis.upsert') - def test_save_batch_force_write_upserts(self, upsert): - mock_transaction = mock.Mock() - values = {'key': 'key', 'value_1': 'value'} - not_persisted = models.SmallTestModel(values) - models.SmallTestModel.save_batch( - mock_transaction, [not_persisted], force_write=True) - self.assert_api_called(upsert, mock_transaction) - - @mock.patch('spanner_orm.table_apis.delete') - def test_delete_batch_deletes(self, delete): - mock_transaction = mock.Mock() - values = {'key': 'key', 'value_1': 'value'} - model = models.SmallTestModel(values) - models.SmallTestModel.delete_batch(mock_transaction, [model]) - - delete.assert_called_once() - (transaction, table, keyset), _ = delete.call_args - self.assertEqual(transaction, mock_transaction) - self.assertEqual(table, models.SmallTestModel.table) - self.assertEqual(keyset.keys, [[model.key]]) - - -if __name__ == '__main__': - logging.basicConfig() - unittest.main() diff --git a/spanner_orm/tests/model_test.py b/spanner_orm/tests/model_test.py index 0cddd56..02a80ee 100644 --- a/spanner_orm/tests/model_test.py +++ b/spanner_orm/tests/model_test.py @@ -19,12 +19,130 @@ from unittest import mock from absl.testing import parameterized +from spanner_orm import error from spanner_orm import field from spanner_orm.tests import models class ModelTest(parameterized.TestCase): + @mock.patch('spanner_orm.table_apis.find') + def test_find_calls_api(self, find): + mock_transaction = mock.Mock() + models.UnittestModel.find( + mock_transaction, string='string', int_=1, float_=2.3) + + find.assert_called_once() + (transaction, table, columns, keyset), _ = find.call_args + self.assertEqual(transaction, mock_transaction) + self.assertEqual(table, models.UnittestModel.table) + self.assertEqual(columns, models.UnittestModel.columns) + self.assertEqual(keyset.keys, [[1, 2.3, 'string']]) + + @mock.patch('spanner_orm.table_apis.find') + def test_find_result(self, find): + mock_transaction = mock.Mock() + + find.return_value = [['key', 'value_1', None]] + result = models.SmallTestModel.find(mock_transaction, key='key') + if result: + self.assertEqual(result.key, 'key') + self.assertEqual(result.value_1, 'value_1') + self.assertIsNone(result.value_2) + else: + self.fail('Failed to find result') + + @mock.patch('spanner_orm.table_apis.find') + def test_find_multi_calls_api(self, find): + mock_transaction = mock.Mock() + models.UnittestModel.find_multi(mock_transaction, [{ + 'string': 'string', + 'int_': 1, + 'float_': 2.3 + }]) + + find.assert_called_once() + (transaction, table, columns, keyset), _ = find.call_args + self.assertEqual(transaction, mock_transaction) + self.assertEqual(table, models.UnittestModel.table) + self.assertEqual(columns, models.UnittestModel.columns) + self.assertEqual(keyset.keys, [[1, 2.3, 'string']]) + + @mock.patch('spanner_orm.table_apis.find') + def test_find_multi_result(self, find): + mock_transaction = mock.Mock() + find.return_value = [['key', 'value_1', None]] + results = models.SmallTestModel.find_multi(mock_transaction, [{ + 'key': 'key' + }]) + + self.assertEqual(results[0].key, 'key') + self.assertEqual(results[0].value_1, 'value_1') + self.assertIsNone(results[0].value_2) + + @mock.patch('spanner_orm.table_apis.insert') + def test_create_calls_api(self, insert): + mock_transaction = mock.Mock() + models.SmallTestModel.create(mock_transaction, key='key', value_1='value') + + insert.assert_called_once() + (transaction, table, columns, values), _ = insert.call_args + self.assertEqual(transaction, mock_transaction) + self.assertEqual(table, models.SmallTestModel.table) + self.assertEqual(list(columns), ['key', 'value_1']) + self.assertEqual(list(values), [['key', 'value']]) + + def test_create_error_on_invalid_keys(self): + with self.assertRaises(error.SpannerError): + models.SmallTestModel.create(key_2='key') + + def assert_api_called(self, mock_api, mock_transaction): + mock_api.assert_called_once() + (transaction, table, columns, values), _ = mock_api.call_args + self.assertEqual(transaction, mock_transaction) + self.assertEqual(table, models.SmallTestModel.table) + self.assertEqual(list(columns), ['key', 'value_1', 'value_2']) + self.assertEqual(list(values), [['key', 'value', None]]) + + @mock.patch('spanner_orm.table_apis.insert') + def test_save_batch_inserts(self, insert): + mock_transaction = mock.Mock() + values = {'key': 'key', 'value_1': 'value'} + not_persisted = models.SmallTestModel(values) + models.SmallTestModel.save_batch(mock_transaction, [not_persisted]) + self.assert_api_called(insert, mock_transaction) + + @mock.patch('spanner_orm.table_apis.update') + def test_save_batch_updates(self, update): + mock_transaction = mock.Mock() + values = {'key': 'key', 'value_1': 'value'} + persisted = models.SmallTestModel(values, persisted=True) + models.SmallTestModel.save_batch(mock_transaction, [persisted]) + + self.assert_api_called(update, mock_transaction) + + @mock.patch('spanner_orm.table_apis.upsert') + def test_save_batch_force_write_upserts(self, upsert): + mock_transaction = mock.Mock() + values = {'key': 'key', 'value_1': 'value'} + not_persisted = models.SmallTestModel(values) + models.SmallTestModel.save_batch( + mock_transaction, [not_persisted], force_write=True) + self.assert_api_called(upsert, mock_transaction) + + @mock.patch('spanner_orm.table_apis.delete') + def test_delete_batch_deletes(self, delete): + mock_transaction = mock.Mock() + values = {'key': 'key', 'value_1': 'value'} + model = models.SmallTestModel(values) + models.SmallTestModel.delete_batch(mock_transaction, [model]) + + delete.assert_called_once() + (transaction, table, keyset), _ = delete.call_args + self.assertEqual(transaction, mock_transaction) + self.assertEqual(table, models.SmallTestModel.table) + self.assertEqual(keyset.keys, [[model.key]]) + def test_set_attr(self): test_model = models.SmallTestModel({'key': 'key', 'value_1': 'value'}) test_model.value_1 = 'value_1' @@ -40,8 +158,9 @@ def test_set_error_on_primary_key(self): with self.assertRaises(AttributeError): test_model.key = 'error' - @parameterized.parameters(('int_2', 'foo'), ('float_2', 'bar'), ('string_2', 5), - ('string_array', 'foo'), ('timestamp', 5)) + @parameterized.parameters(('int_2', 'foo'), ('float_2', 'bar'), + ('string_2', 5), ('string_array', 'foo'), + ('timestamp', 5)) def test_set_error_on_invalid_type(self, attribute, value): string_array = ['foo', 'bar'] timestamp = datetime.datetime.now(tz=datetime.timezone.utc) @@ -128,7 +247,7 @@ def test_relation_get_error_on_unretrieved(self): def test_interleaved(self): self.assertEqual(models.ChildTestModel.interleaved, models.SmallTestModel) - @mock.patch('spanner_orm.model.ModelApi.find') + @mock.patch('spanner_orm.model.Model.find') def test_reload(self, find): values = {'key': 'key', 'value_1': 'value_1'} model = models.SmallTestModel(values, persisted=False) @@ -140,7 +259,7 @@ def test_reload(self, find): self.assertIsNone(transaction) self.assertEqual(kwargs, model.id()) - @mock.patch('spanner_orm.model.ModelApi.find') + @mock.patch('spanner_orm.model.Model.find') def test_reload_reloads(self, find): values = {'key': 'key', 'value_1': 'value_1'} model = models.SmallTestModel(values, persisted=False) @@ -151,7 +270,7 @@ def test_reload_reloads(self, find): self.assertEqual(model.value_1, updated_values['value_1']) self.assertEqual(model.changes(), {}) - @mock.patch('spanner_orm.model.ModelApi.create') + @mock.patch('spanner_orm.model.Model.create') def test_save_creates(self, create): values = {'key': 'key', 'value_1': 'value_1'} model = models.SmallTestModel(values, persisted=False) @@ -162,7 +281,7 @@ def test_save_creates(self, create): self.assertIsNone(transaction) self.assertEqual(kwargs, {**values, 'value_2': None}) - @mock.patch('spanner_orm.model.ModelApi.update') + @mock.patch('spanner_orm.model.Model.update') def test_save_updates(self, update): values = {'key': 'key', 'value_1': 'value_1'} model = models.SmallTestModel(values, persisted=True) @@ -176,7 +295,7 @@ def test_save_updates(self, update): self.assertIsNone(transaction) self.assertEqual(kwargs, values) - @mock.patch('spanner_orm.model.ModelApi.update') + @mock.patch('spanner_orm.model.Model.update') def test_save_no_changes(self, update): values = {'key': 'key', 'value_1': 'value_1'} model = models.SmallTestModel(values, persisted=True) From eb0f7a0d6a914dfab6450fdebb25f1c8f67687f5 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Wed, 14 Oct 2020 10:58:27 -0400 Subject: [PATCH 032/131] Add required disclaimer to the README. --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index bed79bd..7195449 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # Google Cloud Spanner ORM This is a lightweight ORM written in Python and built on top of Cloud Spanner. +This is not an officially supported Google product. ## Getting started From e87c0dc866b78c6c0ebed6906ca0785ffb886551 Mon Sep 17 00:00:00 2001 From: Sam Marder Date: Fri, 30 Oct 2020 15:50:11 -0400 Subject: [PATCH 033/131] Move retry logic to api.py --- spanner_orm/api.py | 26 +++++++++++---- spanner_orm/decorator.py | 12 +------ spanner_orm/tests/api_test.py | 50 +++++++++++++++++++++++++++-- spanner_orm/tests/decorator_test.py | 46 -------------------------- 4 files changed, 69 insertions(+), 65 deletions(-) diff --git a/spanner_orm/api.py b/spanner_orm/api.py index 437a268..4bf5f88 100644 --- a/spanner_orm/api.py +++ b/spanner_orm/api.py @@ -17,17 +17,30 @@ import abc from typing import Any, Callable, Iterable, Optional, TypeVar -from spanner_orm import error - +from google.api_core import exceptions from google.auth import credentials as auth_credentials from google.cloud import spanner from google.cloud.spanner_v1 import database as spanner_database from google.cloud.spanner_v1 import pool as spanner_pool +from spanner_orm import error CallableReturn = TypeVar('CallableReturn') +class SpannerRetryableApi(abc.ABC): + def _ensure_session(self, api_method, *args, **kwargs): + try: + return api_method(*args, **kwargs) + except exceptions.NotFound as e: + # https://cloud.google.com/spanner/docs/sessions#handle_deleted_sessions + # states that Spanner may delete existing sessions for various reasons. + if not 'Session not found' in e.message: + raise + + spanner_api().connect() + return api_method(*args, **kwargs) + -class SpannerReadApi(abc.ABC): +class SpannerReadApi(SpannerRetryableApi): """Handles sending read requests to Spanner.""" @property @@ -51,10 +64,10 @@ def run_read_only(self, method: Callable[..., CallableReturn], *args: Any, The return value from `method` will be returned from this method """ with self._connection.snapshot(multi_use=True) as snapshot: - return method(snapshot, *args, **kwargs) + return self._ensure_session(method, snapshot, *args, **kwargs) -class SpannerWriteApi(abc.ABC): +class SpannerWriteApi(SpannerRetryableApi): """Handles sending write requests to Spanner.""" @property @@ -80,7 +93,8 @@ def run_write(self, method: Callable[..., CallableReturn], *args: Any, Returns: The return value from `method` will be returned from this method """ - return self._connection.run_in_transaction(method, *args, **kwargs) + return self._ensure_session( + self._connection.run_in_transaction, method, *args, **kwargs) class SpannerConnection: diff --git a/spanner_orm/decorator.py b/spanner_orm/decorator.py index e0ed499..bd2a352 100644 --- a/spanner_orm/decorator.py +++ b/spanner_orm/decorator.py @@ -16,7 +16,6 @@ from typing import Callable, TypeVar -from google.api_core import exceptions from spanner_orm import api T = TypeVar('T') @@ -104,15 +103,6 @@ def wrapper(*args, **kwargs) -> T: return func(*args, **kwargs) spanner_api_method = spanner_api_method_lambda() - try: - return spanner_api_method(spanner_wrapper, *args, **kwargs) - except exceptions.NotFound as e: - # https://cloud.google.com/spanner/docs/sessions#handle_deleted_sessions - # states that Spanner may delete existing sessions for various reasons. - if not 'Session not found' in e.message: - raise - - api.spanner_api().connect() - return spanner_api_method(spanner_wrapper, *args, **kwargs) + return spanner_api_method(spanner_wrapper, *args, **kwargs) return wrapper diff --git a/spanner_orm/tests/api_test.py b/spanner_orm/tests/api_test.py index 86bdc91..6caf3a0 100644 --- a/spanner_orm/tests/api_test.py +++ b/spanner_orm/tests/api_test.py @@ -16,12 +16,27 @@ import unittest from unittest import mock +from absl.testing import parameterized +from google.api_core import exceptions from spanner_orm import api from spanner_orm import error from spanner_orm.admin import api as admin_api -class ApiTest(unittest.TestCase): +def _mock_run_in_transaction(method, *args, **kwargs): + return method(*args, **kwargs) + + +class MockSpannerApi(api.SpannerReadApi, api.SpannerWriteApi): + def __init__(self): + self.connection_mock = mock.MagicMock() + self.connection_mock.run_in_transaction.side_effect = _mock_run_in_transaction + + @property + def _connection(self): + return self.connection_mock + +class ApiTest(parameterized.TestCase): @mock.patch('google.cloud.spanner.Client') def test_api_connection(self, client): @@ -53,12 +68,43 @@ def test_admin_api_create_ddl_connection(self, client): admin_api.connect('', '', '', create_ddl=['create ddl']) self.assertEqual(admin_api.spanner_admin_api()._connection, connection) + @parameterized.parameters('run_read_only', 'run_write') + @mock.patch('spanner_orm.api.spanner_api') + def test_reconnect_on_expected_error(self, api_method, + mock_spanner_api): + mock_api = MockSpannerApi() + + mock_method = mock.Mock() + mock_method.side_effect = [ + exceptions.NotFound('Session not found'), + 'Anything other than an exception' + ] + mock_connect = mock_spanner_api.return_value.connect + + getattr(mock_api, api_method)(mock_method) + + mock_connect.assert_called_once() + mock_method.assert_called() + + @parameterized.parameters('run_read_only', 'run_write') + @mock.patch('spanner_orm.api.spanner_api') + def test_raise_on_expected_error(self, api_method, + mock_spanner_api): + mock_api = MockSpannerApi() + + mock_method = mock.Mock() + mock_method.side_effect = exceptions.NotFound('Database not found') + + with self.assertRaises(exceptions.NotFound): + getattr(mock_api, api_method)(mock_method) + + mock_method.assert_called() + def mock_connection(self, client): connection = mock.Mock() client().instance().database.return_value = connection return connection - if __name__ == '__main__': logging.basicConfig() unittest.main() diff --git a/spanner_orm/tests/decorator_test.py b/spanner_orm/tests/decorator_test.py index 448c6f1..ba30d25 100644 --- a/spanner_orm/tests/decorator_test.py +++ b/spanner_orm/tests/decorator_test.py @@ -17,7 +17,6 @@ from unittest import mock from absl.testing import parameterized -from google.api_core import exceptions from spanner_orm import decorator @@ -65,51 +64,6 @@ def get_book(book_id, genre=None, transaction=None): self.assertEqual(200, result) - @parameterized.parameters( - (decorator.transactional_read, 'run_read_only'), - (decorator.transactional_write, 'run_write'), - ) - @mock.patch('spanner_orm.api.spanner_api') - def test_reconnect_on_expected_error(self, decorator_in_test, - method_name_to_mock, - mock_spanner_api): - mock_api_method = getattr(mock_spanner_api.return_value, - method_name_to_mock) - mock_api_method.side_effect = [ - exceptions.NotFound('404 Session not found'), - 'Anything other than an exception' - ] - mock_connect = mock_spanner_api.return_value.connect - - @decorator_in_test - def get_book(book_id, genre=None, transaction=None): - pass - - get_book(123, genre='horror') - - mock_connect.assert_called_once() - mock_api_method.assert_has_calls( - [mock.call(mock.ANY, 123, genre='horror')] * 2) - - @parameterized.parameters( - (decorator.transactional_read, 'run_read_only'), - (decorator.transactional_write, 'run_write'), - ) - @mock.patch('spanner_orm.api.spanner_api') - def test_raise_on_expected_error(self, decorator_in_test, - method_name_to_mock, - mock_spanner_api): - mock_api_method = getattr(mock_spanner_api.return_value, - method_name_to_mock) - mock_api_method.side_effect = exceptions.NotFound('404 Database not found') - - @decorator_in_test - def get_book(book_id, genre=None, transaction=None): - pass - - with self.assertRaises(exceptions.NotFound): - get_book(123, genre='horror') - def mock_spanner_method(mock_transaction): From c276008d781cbdb17a8c3401224145d24bcca69f Mon Sep 17 00:00:00 2001 From: Sam Marder Date: Wed, 11 Nov 2020 12:28:23 -0500 Subject: [PATCH 034/131] Re-establish connection if snapshot fails --- spanner_orm/api.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/spanner_orm/api.py b/spanner_orm/api.py index 4bf5f88..af738da 100644 --- a/spanner_orm/api.py +++ b/spanner_orm/api.py @@ -63,9 +63,11 @@ def run_read_only(self, method: Callable[..., CallableReturn], *args: Any, Returns: The return value from `method` will be returned from this method """ - with self._connection.snapshot(multi_use=True) as snapshot: - return self._ensure_session(method, snapshot, *args, **kwargs) + return self._ensure_session(self._run_read_only, method, *args, **kwargs) + def _run_read_only(self, method, *args, **kwargs): + with self._connection.snapshot(multi_use=True) as snapshot: + return method(snapshot, *args, **kwargs) class SpannerWriteApi(SpannerRetryableApi): """Handles sending write requests to Spanner.""" From d3ff4f98b66c430f8e63727181fe7819489195da Mon Sep 17 00:00:00 2001 From: Sam Marder Date: Wed, 11 Nov 2020 12:41:04 -0500 Subject: [PATCH 035/131] Allow tuples with in_list --- spanner_orm/condition.py | 4 ++-- spanner_orm/tests/query_test.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/spanner_orm/condition.py b/spanner_orm/condition.py index 0174930..2666acc 100644 --- a/spanner_orm/condition.py +++ b/spanner_orm/condition.py @@ -485,8 +485,8 @@ def _types(self) -> type_pb2.Type: return {self._column_key: list_type} def _validate(self, model_class: Type[Any]) -> None: - if not isinstance(self.value, list): - raise error.ValidationError('{} is not a list'.format(self.value)) + if not isinstance(self.value, Iterable): + raise error.ValidationError('{} is not iterable'.format(self.value)) if self.column not in model_class.fields: raise error.ValidationError('{} is not a column on {}'.format( self.column, model_class.table)) diff --git a/spanner_orm/tests/query_test.py b/spanner_orm/tests/query_test.py index 02ffa28..a2ff64d 100644 --- a/spanner_orm/tests/query_test.py +++ b/spanner_orm/tests/query_test.py @@ -165,6 +165,7 @@ def test_query_where_comparison_with_object(self, column, value, grpc_type): @parameterized.parameters( ('int_', [1, 2, 3], field.Integer.grpc_type()), + ('int_', (4, 5, 6), field.Integer.grpc_type()), ('string', ['a', 'b', 'c'], field.String.grpc_type()), ('timestamp', [now()], field.Timestamp.grpc_type())) def test_query_where_list_comparison(self, column, values, grpc_type): From 3ba47f82bff1a186d25a918dd1f31ad58b8a147a Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Thu, 12 Nov 2020 18:42:58 -0500 Subject: [PATCH 036/131] Expose SchemaUpdate. I think this should be the return type annotation for upgrade() and downgrade() functions in migrations. --- spanner_orm/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/spanner_orm/__init__.py b/spanner_orm/__init__.py index 1a77d85..78c450e 100644 --- a/spanner_orm/__init__.py +++ b/spanner_orm/__init__.py @@ -88,6 +88,7 @@ transactional_read = decorator.transactional_read transactional_write = decorator.transactional_write +SchemaUpdate = update_module.SchemaUpdate CreateTable = update_module.CreateTable DropTable = update_module.DropTable AddColumn = update_module.AddColumn From fa0536b06e3bf093a57cda1e86c1032f658aada3 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Thu, 12 Nov 2020 18:52:21 -0500 Subject: [PATCH 037/131] Make the migrations template easier to use. 1. Change the return type annotations so that they don't need to be modified. 2. Various changes for the Google python style guide. --- spanner_orm/admin/migration.skel | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/spanner_orm/admin/migration.skel b/spanner_orm/admin/migration.skel index de43076..afa0eb9 100644 --- a/spanner_orm/admin/migration.skel +++ b/spanner_orm/admin/migration.skel @@ -1,4 +1,4 @@ -"""$migration_name +"""Spanner ORM migration: $migration_name Migration ID: $migration_id Created: $current_date @@ -9,10 +9,12 @@ import spanner_orm migration_id = $migration_id prev_migration_id = $prev_migration_id -# Returns a SchemaUpdate object that tells what should be changed -def upgrade() -> spanner_orm.NoUpdate: + +def upgrade() -> spanner_orm.SchemaUpdate: + """See spanner_orm migrations interface.""" return spanner_orm.NoUpdate() -# Returns a SchemaUpdate object that tells how to roll back the changes -def downgrade() -> spanner_orm.NoUpdate: + +def downgrade() -> spanner_orm.SchemaUpdate: + """See spanner_orm migrations interface.""" return spanner_orm.NoUpdate() From ae3ea8bca09e91719544a34abfc918c2853a0d02 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Wed, 9 Dec 2020 17:05:08 -0500 Subject: [PATCH 038/131] Fix error when decorated function has a `method` parameter. The first parameter to both run_read_only() and run_write() is called `method`, so if `**kwargs` has a parameter of the same name, there's an error. --- spanner_orm/decorator.py | 14 +++++++------- spanner_orm/tests/decorator_test.py | 14 +++++++------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/spanner_orm/decorator.py b/spanner_orm/decorator.py index bd2a352..a2dd7bc 100644 --- a/spanner_orm/decorator.py +++ b/spanner_orm/decorator.py @@ -91,18 +91,18 @@ def _transactional(spanner_api_method_lambda: Callable[[], Callable[..., T]], func: Callable[..., T]) -> Callable[..., T]: """Returns decorated function.""" - # Spanner library calls given function with transaction as first argument. - # It will call 'spanner_wrapper', and we will move transaction from first - # argument to 'transaction' kwarg and call actual 'func' + def wrapper(*args, **kwargs) -> T: - def spanner_wrapper(transaction, *args, **kwargs) -> T: - return func(*args, transaction=transaction, **kwargs) + def spanner_wrapper(transaction) -> T: + # Spanner library calls given function with transaction as first argument. + # It will call 'spanner_wrapper', and we will move transaction from first + # argument to 'transaction' kwarg and call actual 'func' + return func(*args, transaction=transaction, **kwargs) - def wrapper(*args, **kwargs) -> T: if 'transaction' in kwargs: return func(*args, **kwargs) spanner_api_method = spanner_api_method_lambda() - return spanner_api_method(spanner_wrapper, *args, **kwargs) + return spanner_api_method(spanner_wrapper) return wrapper diff --git a/spanner_orm/tests/decorator_test.py b/spanner_orm/tests/decorator_test.py index ba30d25..93d84dc 100644 --- a/spanner_orm/tests/decorator_test.py +++ b/spanner_orm/tests/decorator_test.py @@ -36,16 +36,16 @@ def test_transactional_injects_new_transaction(self, decorator_in_test, mock_api_method.side_effect = mock_spanner_method(mock_tx) @decorator_in_test - def get_book(book_id, genre=None, transaction=None): + def get_book(book_id, method=None, transaction=None): self.assertEqual(mock_tx, transaction) self.assertEqual(123, book_id) - self.assertEqual('horror', genre) + self.assertEqual('library', method) return 200 - result = get_book(123, genre='horror') + result = get_book(123, method='library') self.assertEqual(200, result) - mock_api_method.assert_called_once_with(mock.ANY, 123, genre='horror') + mock_api_method.assert_called_once() @parameterized.parameters(decorator.transactional_read, decorator.transactional_write) @@ -53,14 +53,14 @@ def test_transactional_uses_given_transaction(self, decorator_in_test): mock_tx = mock.Mock() @decorator_in_test - def get_book(book_id, genre=None, transaction=None): + def get_book(book_id, method=None, transaction=None): self.assertEqual(mock_tx, transaction) self.assertEqual(123, book_id) - self.assertEqual('horror', genre) + self.assertEqual('library', method) return 200 - result = get_book(123, genre='horror', transaction=mock_tx) + result = get_book(123, method='library', transaction=mock_tx) self.assertEqual(200, result) From 9b01c9fafe5d41a6c0335d09497ceb4d324aabe4 Mon Sep 17 00:00:00 2001 From: Sabrina Gutierrez Date: Tue, 15 Dec 2020 12:28:29 -0800 Subject: [PATCH 039/131] Skip validation (#106) * Adding a param to skip validation --- spanner_orm/model.py | 10 +++++++--- spanner_orm/tests/model_test.py | 20 ++++++++++++++++++++ 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/spanner_orm/model.py b/spanner_orm/model.py index dfe789c..8fc948f 100644 --- a/spanner_orm/model.py +++ b/spanner_orm/model.py @@ -138,13 +138,17 @@ class Model(metaclass=ModelMetaclass): associated tables. Violating this will cause an exception to be raised. """ - def __init__(self, values: Dict[str, Any], persisted: bool = False): + def __init__(self, + values: Dict[str, Any], + persisted: bool = False, + skip_validation: bool = False): start_values = {} self.__dict__['start_values'] = start_values self.__dict__['_persisted'] = persisted - # If the values came from Spanner, trust them and skip validation - if not persisted: + # If the values came from Spanner or validation is explicitly skipped, trust + # them and skip validation + if not persisted and not skip_validation: # An object is invalid if primary key values are missing missing_keys = set(self._primary_keys) - set(values.keys()) if missing_keys: diff --git a/spanner_orm/tests/model_test.py b/spanner_orm/tests/model_test.py index 02a80ee..add2d32 100644 --- a/spanner_orm/tests/model_test.py +++ b/spanner_orm/tests/model_test.py @@ -180,6 +180,26 @@ def test_get_attr(self): self.assertEqual(test_model.value_1, 'value') self.assertEqual(test_model.value_2, None) + @parameterized.parameters( + (True, True), + (True, False), + (False, True), + ) + def test_skip_validation(self, persisted, skip_validation): + models.SmallTestModel( + {'value_1': 'value'}, + persisted=persisted, + skip_validation=skip_validation, + ) + + def test_validation(self): + with self.assertRaises(error.SpannerError): + models.SmallTestModel( + {'value_1': 'value'}, + persisted=False, + skip_validation=False, + ) + def test_id(self): primary_key = {'string': 'foo', 'int_': 5, 'float_': 2.3} all_data = primary_key.copy() From 28647f15432ee91c2f721814a4285b1b63e0561b Mon Sep 17 00:00:00 2001 From: Sabrina Gutierrez Date: Tue, 15 Dec 2020 16:30:30 -0800 Subject: [PATCH 040/131] Create equals (#107) * Adding the equals method --- spanner_orm/model.py | 6 ++++ spanner_orm/tests/model_test.py | 58 +++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+) diff --git a/spanner_orm/model.py b/spanner_orm/model.py index 8fc948f..e778c76 100644 --- a/spanner_orm/model.py +++ b/spanner_orm/model.py @@ -168,6 +168,12 @@ def __init__(self, if relation in values: self.__dict__[relation] = values[relation] + def __eq__(self, other: Any) -> bool: + """Compares objects by their type and attributes.""" + if type(self) != type(other): + return NotImplemented + return self.values == other.values + @classmethod def spanner_api(cls) -> api.SpannerApi: if not cls.table: diff --git a/spanner_orm/tests/model_test.py b/spanner_orm/tests/model_test.py index add2d32..d53a158 100644 --- a/spanner_orm/tests/model_test.py +++ b/spanner_orm/tests/model_test.py @@ -23,6 +23,8 @@ from spanner_orm import field from spanner_orm.tests import models +_TIMESTAMP = datetime.datetime.now(tz=datetime.timezone.utc) + class ModelTest(parameterized.TestCase): @@ -200,6 +202,62 @@ def test_validation(self): skip_validation=False, ) + def test_model_equates(self): + timestamp = datetime.datetime.now(tz=datetime.timezone.utc) + test_model1 = models.UnittestModel({ + 'int_': 0, + 'float_': 0, + 'string': '', + 'string_array': ['foo', 'bar'], + 'timestamp': timestamp, + }) + test_model2 = models.UnittestModel({ + 'int_': 0, + 'float_': 0.0, + 'string': '', + 'string_array': ['foo', 'bar'], + 'timestamp': timestamp, + }) + self.assertEqual(test_model1, test_model2) + + @parameterized.parameters( + (models.UnittestModel({ + 'int_': 0, + 'float_': 0, + 'string': '1', + 'timestamp': _TIMESTAMP, + }), + models.UnittestModel({ + 'int_': 0, + 'float_': 0, + 'string': 'a', + 'timestamp': _TIMESTAMP, + })), + (models.UnittestModel({ + 'int_': 0, + 'float_': 0, + 'string': '', + 'string_array': ['foo', 'bar'], + 'timestamp': _TIMESTAMP, + }), + models.UnittestModel({ + 'int_': 0, + 'float_': 0, + 'string': '', + 'string_array': ['bar', 'foo'], + 'timestamp': _TIMESTAMP, + })), + (models.SmallTestModel({ + 'key': 'key', + 'value_1': 'value' + }), models.InheritanceTestModel({ + 'key': 'key', + 'value_1': 'value' + })), + ) + def test_model_are_different(self, test_model1, test_model2): + self.assertNotEqual(test_model1, test_model2) + def test_id(self): primary_key = {'string': 'foo', 'int_': 5, 'float_': 2.3} all_data = primary_key.copy() From 62c7a4ea0135e161143d78ff4ab9c984678f2728 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Wed, 16 Dec 2020 18:51:17 -0500 Subject: [PATCH 041/131] Add method to delete a row by its primary key. I'd rather not create a transaction to find() and delete(), when I don't need the result of find() for anything. --- spanner_orm/model.py | 44 +++++++++++++++++++++++++++------ spanner_orm/tests/model_test.py | 11 +++++++++ 2 files changed, 48 insertions(+), 7 deletions(-) diff --git a/spanner_orm/model.py b/spanner_orm/model.py index e778c76..3789a01 100644 --- a/spanner_orm/model.py +++ b/spanner_orm/model.py @@ -402,6 +402,19 @@ def create_or_update( ) -> None: cls._execute_write(table_apis.upsert, transaction, [kwargs]) + @classmethod + def _delete_by_keyset( + cls: Type[T], + transaction: Optional[spanner_transaction.Transaction], + keyset: spanner.KeySet, + ) -> None: + db_api = table_apis.delete + args = [cls.table, keyset] + if transaction is not None: + db_api(transaction, *args) + else: + cls.spanner_api().run_write(db_api, *args) + @classmethod def delete_batch( cls: Type[T], @@ -418,14 +431,31 @@ def delete_batch( key_list = [] for model in models: key_list.append([getattr(model, column) for column in cls.primary_keys]) - keyset = spanner.KeySet(keys=key_list) + cls._delete_by_keyset( + transaction=transaction, + keyset=spanner.KeySet(keys=key_list), + ) - db_api = table_apis.delete - args = [cls.table, keyset] - if transaction is not None: - db_api(transaction, *args) - else: - cls.spanner_api().run_write(db_api, *args) + @classmethod + def delete_by_key( + cls, + transaction: Optional[spanner_transaction.Transaction] = None, + **keys: Any, + ) -> None: + """Deletes rows from Spanner based on the provided primary key. + + Args: + transaction: The existing transaction to use, or None to start a new + transaction. + **keys: The keys provided are the complete set of primary keys for this + table and the corresponding values make up the unique identifier of the + object being deleted. + """ + cls._delete_by_keyset( + transaction=transaction, + keyset=spanner.KeySet( + keys=[[keys[column] for column in cls.primary_keys]]), + ) @classmethod def save_batch( diff --git a/spanner_orm/tests/model_test.py b/spanner_orm/tests/model_test.py index d53a158..efbbef6 100644 --- a/spanner_orm/tests/model_test.py +++ b/spanner_orm/tests/model_test.py @@ -19,6 +19,7 @@ from unittest import mock from absl.testing import parameterized +from google.cloud import spanner from spanner_orm import error from spanner_orm import field from spanner_orm.tests import models @@ -145,6 +146,16 @@ def test_delete_batch_deletes(self, delete): self.assertEqual(table, models.SmallTestModel.table) self.assertEqual(keyset.keys, [[model.key]]) + @mock.patch('spanner_orm.table_apis.delete') + def test_delete_by_key_deletes(self, delete): + mock_transaction = mock.Mock() + models.SmallTestModel.delete_by_key(mock_transaction, key='some-key') + delete.assert_called_once_with( + mock_transaction, + models.SmallTestModel.table, + spanner.KeySet(keys=[['some-key']]), + ) + def test_set_attr(self): test_model = models.SmallTestModel({'key': 'key', 'value_1': 'value'}) test_model.value_1 = 'value_1' From fd84d2a79b4651c50c0aa3ac527f728782e4b52d Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Thu, 17 Dec 2020 17:39:04 -0500 Subject: [PATCH 042/131] End the module docstring one liner with punctuation. One of our internal lints requires that. Tested: manually --- spanner_orm/admin/migration.skel | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spanner_orm/admin/migration.skel b/spanner_orm/admin/migration.skel index afa0eb9..0f3ac29 100644 --- a/spanner_orm/admin/migration.skel +++ b/spanner_orm/admin/migration.skel @@ -1,4 +1,4 @@ -"""Spanner ORM migration: $migration_name +"""Spanner ORM migration: $migration_name. Migration ID: $migration_id Created: $current_date From f8a2697ed1132b37caa576323626037b2edaf058 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Mon, 21 Dec 2020 17:05:21 -0500 Subject: [PATCH 043/131] Fix pytype errors. --- spanner_orm/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spanner_orm/model.py b/spanner_orm/model.py index 3789a01..6393383 100644 --- a/spanner_orm/model.py +++ b/spanner_orm/model.py @@ -168,7 +168,7 @@ def __init__(self, if relation in values: self.__dict__[relation] = values[relation] - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: Any) -> Union[bool, type(NotImplemented)]: """Compares objects by their type and attributes.""" if type(self) != type(other): return NotImplemented @@ -404,7 +404,7 @@ def create_or_update( @classmethod def _delete_by_keyset( - cls: Type[T], + cls, transaction: Optional[spanner_transaction.Transaction], keyset: spanner.KeySet, ) -> None: From c8502fc02d184b9e22e09b886ef324aedf2fe6f0 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Mon, 21 Dec 2020 18:59:13 -0500 Subject: [PATCH 044/131] Add condition for substring matching. --- spanner_orm/__init__.py | 1 + spanner_orm/condition.py | 28 ++++++++++++++++++++++++ spanner_orm/tests/condition_test.py | 34 +++++++++++++++++++++++++++++ 3 files changed, 63 insertions(+) create mode 100644 spanner_orm/tests/condition_test.py diff --git a/spanner_orm/__init__.py b/spanner_orm/__init__.py index 78c450e..9b256c6 100644 --- a/spanner_orm/__init__.py +++ b/spanner_orm/__init__.py @@ -68,6 +68,7 @@ StringArray = field.StringArray Timestamp = field.Timestamp +contains = condition.contains equal_to = condition.equal_to force_index = condition.force_index greater_than = condition.greater_than diff --git a/spanner_orm/condition.py b/spanner_orm/condition.py index 2666acc..2b38d06 100644 --- a/spanner_orm/condition.py +++ b/spanner_orm/condition.py @@ -571,6 +571,34 @@ def columns_equal(origin_column: str, dest_model_class: Type[Any], return ColumnsEqualCondition(origin_column, dest_model_class, dest_column) +def contains( + column: Union[field.Field, str], + value: str, +) -> ComparisonCondition: + """Condition where the specified column contains the given substring. + + Args: + column: Name of the column on the origin model or the Field on the origin + model class to compare from + value: The value to compare against + + Returns: + A Condition subclass that will be used in the query + """ + value_escaped = value.translate( + str.maketrans({ + # https://cloud.google.com/spanner/docs/functions-and-operators#comparison_operators + '%': r'\%', + '_': r'\_', + '\\': '\\\\', + })) + return ComparisonCondition( + operator='LIKE', + field_or_name=column, + value=f'%{value_escaped}%', + ) + + def equal_to(column: Union[field.Field, str], value: Any) -> EqualityCondition: """Condition where the specified column is equal to the given value. diff --git a/spanner_orm/tests/condition_test.py b/spanner_orm/tests/condition_test.py new file mode 100644 index 0000000..74ca4d0 --- /dev/null +++ b/spanner_orm/tests/condition_test.py @@ -0,0 +1,34 @@ +# Lint as: python3 +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for spanner_orm.condition.""" + +import logging +import unittest + +import spanner_orm + + +class ConditionTest(unittest.TestCase): + + def test_contains(self): + contains = spanner_orm.contains('some_column', r'a%b_c\d') + self.assertEqual('some_column', contains.column) + self.assertEqual('LIKE', contains.operator) + self.assertEqual(r'%a\%b\_c\\d%', contains.value) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main() From 51cee6980edff542c0c866841156df6a384b1bd8 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Mon, 21 Dec 2020 15:51:50 -0500 Subject: [PATCH 045/131] Make option-like parameters (mostly `transaction`) keyword-only. Note that this change breaks the API, but I think it's worth it. In some cases, e.g., where(), this avoids needing to pass an unnamed `None` argument which I think is confusing. In the case of save_batch()'s `force_write` parameter, I think this can avoid potentially bad bugs. And in all cases, I think this makes the call sites more clear. Additionally: 1. Make the `transaction` parameter optional for all methods. 2. Fix some type annotations on those parameters. --- README.md | 23 ++++---- spanner_orm/admin/metadata.py | 24 +++++---- spanner_orm/admin/migration_executor.py | 3 +- spanner_orm/admin/update.py | 5 +- spanner_orm/model.py | 68 +++++++++++++++--------- spanner_orm/tests/model_test.py | 70 +++++++++++++++---------- spanner_orm/tests/query_test.py | 4 +- 7 files changed, 117 insertions(+), 80 deletions(-) diff --git a/README.md b/README.md index 7195449..acbaaff 100644 --- a/README.md +++ b/README.md @@ -84,23 +84,24 @@ The two main ways of retrieving data through the ORM are ```where()``` and ```find()```/```find_multi()```: ``` python -# where() is invokes on a model class to retrieve models of that tyep. it takes a -# transaction and then a sequence of conditions. -# Most conditions that specify a Field, Index, Relationship, or Model can take -# either the name of the object or the object itself -test_objects = TestModel.where(None, spanner_orm.greater_than('value', '50')) +# where() is invokes on a model class to retrieve models of that type. it takes +# a sequence of conditions. Most conditions that specify a Field, Index, +# Relationship, or Model can take either the name of the object or the object +# itself +test_objects = TestModel.where(spanner_orm.greater_than('value', '50')) # To also retrieve related objects, the includes() condition should be used: -test_and_other_objects = TestModel.where(None, - spanner_orm.greater_than(TestModel.value, '50'), - spanner_orm.includes(TestModel.fake_relationship)) +test_and_other_objects = TestModel.where( + spanner_orm.greater_than(TestModel.value, '50'), + spanner_orm.includes(TestModel.fake_relationship), +) # To create a transaction, run_read_only() or run_write() are used with the # method to be run inside the transaction and any arguments to passs to the method. # The method is invoked with the transaction as the first argument and then the # rest of the provided arguments: def callback_1(transaction, argument): - return TestModel.find(transaction, id=argument) + return TestModel.find(id=argument, transaction=transaction) specific_object = spanner_orm.spanner_api().run_read_only(callback, 1) @@ -108,7 +109,7 @@ specific_object = spanner_orm.spanner_api().run_read_only(callback, 1) # call a bit: @transactional_read def finder(argument, transaction=None): - return TestModel.find(transaction, id=argument) + return TestModel.find(id=argument, transaction=transaction) specific_object = finder(1) ``` @@ -131,7 +132,7 @@ models = [] for i in range(10): key = 'test_{}'.format(i) models.append(TestModel({'key': key, 'value': value})) -TestModel.save_batch(None, models) +TestModel.save_batch(models) ``` ```spanner_orm.spanner_api().run_write()``` can be used for executing read-write diff --git a/spanner_orm/admin/metadata.py b/spanner_orm/admin/metadata.py index 4156878..988fe2b 100644 --- a/spanner_orm/admin/metadata.py +++ b/spanner_orm/admin/metadata.py @@ -71,9 +71,10 @@ def model(cls, table_name) -> Optional[Type[model.Model]]: def tables(cls) -> Dict[str, Dict[str, Any]]: """Compiles table information from column schema.""" column_data = collections.defaultdict(dict) - columns = column.ColumnSchema.where(None, - condition.equal_to('table_catalog', ''), - condition.equal_to('table_schema', '')) + columns = column.ColumnSchema.where( + condition.equal_to('table_catalog', ''), + condition.equal_to('table_schema', ''), + ) for column_row in columns: new_field = field.Field( column_row.field_type(), nullable=column_row.nullable()) @@ -82,9 +83,10 @@ def tables(cls) -> Dict[str, Dict[str, Any]]: column_data[column_row.table_name][column_row.column_name] = new_field table_data = collections.defaultdict(dict) - tables = table.TableSchema.where(None, - condition.equal_to('table_catalog', ''), - condition.equal_to('table_schema', '')) + tables = table.TableSchema.where( + condition.equal_to('table_catalog', ''), + condition.equal_to('table_schema', ''), + ) for table_row in tables: name = table_row.table_name table_data[name]['parent_table'] = table_row.parent_table_name @@ -98,9 +100,10 @@ def indexes(cls) -> Dict[str, Dict[str, Any]]: # Results are ordered by that so the index columns are added in the # correct order. index_column_schemas = index_column.IndexColumnSchema.where( - None, condition.equal_to('table_catalog', ''), + condition.equal_to('table_catalog', ''), condition.equal_to('table_schema', ''), - condition.order_by(('ordinal_position', condition.OrderType.ASC))) + condition.order_by(('ordinal_position', condition.OrderType.ASC)), + ) index_columns = collections.defaultdict(list) storing_columns = collections.defaultdict(list) @@ -112,8 +115,9 @@ def indexes(cls) -> Dict[str, Dict[str, Any]]: storing_columns[key].append(schema.column_name) index_schemas = index_schema.IndexSchema.where( - None, condition.equal_to('table_catalog', ''), - condition.equal_to('table_schema', '')) + condition.equal_to('table_catalog', ''), + condition.equal_to('table_schema', ''), + ) indexes = collections.defaultdict(dict) for schema in index_schemas: key = (schema.table_name, schema.index_name) diff --git a/spanner_orm/admin/migration_executor.py b/spanner_orm/admin/migration_executor.py index 1d9e9d5..3b3b8e2 100644 --- a/spanner_orm/admin/migration_executor.py +++ b/spanner_orm/admin/migration_executor.py @@ -165,8 +165,7 @@ def _update_status(self, migration_id: str, new_status: bool) -> None: 'migrated': new_status, 'update_time': datetime.datetime.utcnow(), }) - migration_status.MigrationStatus.save_batch( - None, [new_model], force_write=True) + migration_status.MigrationStatus.save_batch([new_model], force_write=True) self._migration_status()[migration_id] = new_status def _validate_migrations(self) -> None: diff --git a/spanner_orm/admin/update.py b/spanner_orm/admin/update.py index 4277e10..2347bdc 100644 --- a/spanner_orm/admin/update.py +++ b/spanner_orm/admin/update.py @@ -182,8 +182,9 @@ def validate(self) -> None: # Verify no indices exist on the column we're trying to drop num_indexed_columns = index_column.IndexColumnSchema.count( - None, condition.equal_to('column_name', self._column), - condition.equal_to('table_name', self._table)) + condition.equal_to('column_name', self._column), + condition.equal_to('table_name', self._table), + ) if num_indexed_columns > 0: raise error.SpannerError('Column {} is indexed'.format(self._column)) diff --git a/spanner_orm/model.py b/spanner_orm/model.py index 6393383..c4a9821 100644 --- a/spanner_orm/model.py +++ b/spanner_orm/model.py @@ -184,6 +184,7 @@ def spanner_api(cls) -> api.SpannerApi: @classmethod def all( cls: Type[T], + *, transaction: Optional[spanner_transaction.Transaction] = None, ) -> List[T]: """Returns all objects of this type stored in Spanner. @@ -206,17 +207,17 @@ def all( @classmethod def count( cls, - transaction: Optional[spanner_transaction.Transaction], *conditions: condition.Condition, + transaction: Optional[spanner_transaction.Transaction] = None, ) -> int: """Returns the number of objects in Spanner that match the given conditions. Args: - transaction: The existing transaction to use, or None to start a new - transaction *conditions: Instances of subclasses of Condition that help specify which rows should be included in the count. The includes condition is not allowed here + transaction: The existing transaction to use, or None to start a new + transaction Returns: The integer result of the COUNT query @@ -229,6 +230,7 @@ def count( @classmethod def count_equal( cls, + *, transaction: Optional[spanner_transaction.Transaction] = None, **constraints: Any, ) -> int: @@ -253,11 +255,12 @@ def count_equal( conditions.append(condition.in_list(column, value)) else: conditions.append(condition.equal_to(column, value)) - return cls.count(transaction, *conditions) + return cls.count(*conditions, transaction=transaction) @classmethod def find( cls: Type[T], + *, transaction: Optional[spanner_transaction.Transaction] = None, **keys: Any, ) -> Optional[T]: @@ -273,23 +276,24 @@ def find( Returns: The requested object or None if no such object exists """ - resources = cls.find_multi(transaction, [keys]) + resources = cls.find_multi([keys], transaction=transaction) return resources[0] if resources else None @classmethod def find_multi( cls: Type[T], - transaction: Optional[spanner_transaction.Transaction], keys: Iterable[Dict[str, Any]], + *, + transaction: Optional[spanner_transaction.Transaction] = None, ) -> List[T]: """Retrieves objects from Spanner based on the provided keys. Args: - transaction: The existing transaction to use, or None to start a new - transaction keys: An iterable of dictionaries, each dictionary representing the set of primary key values necessary to uniquely identify an object in this table. + transaction: The existing transaction to use, or None to start a new + transaction Returns: A list containing all requested objects that exist in the table (can be @@ -307,16 +311,16 @@ def find_multi( @classmethod def where( cls: Type[T], - transaction: Optional[spanner_transaction.Transaction], *conditions: condition.Condition, + transaction: Optional[spanner_transaction.Transaction] = None, ) -> List[T]: """Retrieves objects from Spanner based on the provided conditions. Args: - transaction: The existing transaction to use, or None to start a new - transaction *conditions: Instances of subclasses of Condition that help specify which objects should be retrieved + transaction: The existing transaction to use, or None to start a new + transaction Returns: A list containing all requested objects that exist in the table (can be @@ -330,6 +334,7 @@ def where( @classmethod def where_equal( cls: Type[T], + *, transaction: Optional[spanner_transaction.Transaction] = None, **constraints: Any, ) -> List[T]: @@ -352,7 +357,7 @@ def where_equal( conditions.append(condition.in_list(column, value)) else: conditions.append(condition.equal_to(column, value)) - return cls.where(transaction, *conditions) + return cls.where(*conditions, transaction=transaction) @classmethod def _results_to_models( @@ -378,6 +383,7 @@ def _execute_read( @classmethod def create( cls, + *, transaction: Optional[spanner_transaction.Transaction] = None, **kwargs: Any, ) -> None: @@ -397,6 +403,7 @@ def create( @classmethod def create_or_update( cls, + *, transaction: Optional[spanner_transaction.Transaction] = None, **kwargs: Any, ) -> None: @@ -418,15 +425,16 @@ def _delete_by_keyset( @classmethod def delete_batch( cls: Type[T], - transaction: Optional[spanner_transaction.Transaction], models: List[T], + *, + transaction: Optional[spanner_transaction.Transaction] = None, ) -> None: """Deletes rows from Spanner based on the provided models' primary keys. Args: + models: A list of models to be deleted from Spanner. transaction: The existing transaction to use, or None to start a new transaction - models: A list of models to be deleted from Spanner. """ key_list = [] for model in models: @@ -439,6 +447,7 @@ def delete_batch( @classmethod def delete_by_key( cls, + *, transaction: Optional[spanner_transaction.Transaction] = None, **keys: Any, ) -> None: @@ -460,20 +469,21 @@ def delete_by_key( @classmethod def save_batch( cls: Type[T], - transaction: Optional[spanner_transaction.Transaction], models: List[T], + *, + transaction: Optional[spanner_transaction.Transaction] = None, force_write: bool = False, ) -> None: """Writes rows to Spanner based on the provided model data. Args: - transaction: The existing transaction to use, or None to start a new - transaction models: A list of models to be written to Spanner. If the _persisted flag is set, by default we try to issue an UPDATE with values set for all columns in the table. Otherwise, we try to issue an INSERT for all columns in the table. If we try to INSERTa row that already exists (or update one that is missing), an exception will be thrown. + transaction: The existing transaction to use, or None to start a new + transaction force_write: If true, we use UPSERT instead of UPDATE/INSERT, so no exceptions are thrown based on the presence or absence of data in Spanner @@ -495,6 +505,7 @@ def save_batch( @classmethod def update( cls, + *, transaction: Optional[spanner_transaction.Transaction] = None, **kwargs: Any, ) -> None: @@ -596,7 +607,11 @@ def changes(self) -> Dict[str, Any]: if values[key] != self.start_values.get(key) } - def delete(self, transaction: spanner_transaction.Transaction = None) -> None: + def delete( + self, + *, + transaction: Optional[spanner_transaction.Transaction] = None, + ) -> None: """Deletes this object from the Spanner database. Args: @@ -625,7 +640,9 @@ def id(self) -> Dict[str, Any]: def reload( self, - transaction: spanner_transaction.Transaction = None) -> Optional['Model']: + *, + transaction: Optional[spanner_transaction.Transaction] = None, + ) -> Optional['Model']: """Refreshes this object with information from Spanner. Args: @@ -637,7 +654,7 @@ def reload( in Spanner, or None if no information was found (object was deleted or never was persisted) """ - updated_object = self._metaclass.find(transaction, **self.id()) + updated_object = self._metaclass.find(transaction=transaction, **self.id()) if updated_object is None: return None start_values = {} @@ -652,8 +669,11 @@ def reload( self._persisted = True return self - def save(self, - transaction: spanner_transaction.Transaction = None) -> 'Model': + def save( + self, + *, + transaction: Optional[spanner_transaction.Transaction] = None, + ) -> 'Model': """Persists this object to Spanner. Note: if the _persisted flag doesn't match whether this object is actually @@ -671,8 +691,8 @@ def save(self, changed_values = self.changes() if changed_values: changed_values.update(self.id()) - self._metaclass.update(transaction, **changed_values) + self._metaclass.update(transaction=transaction, **changed_values) else: - self._metaclass.create(transaction, **self.values) + self._metaclass.create(transaction=transaction, **self.values) self._persisted = True return self diff --git a/spanner_orm/tests/model_test.py b/spanner_orm/tests/model_test.py index efbbef6..36ecb53 100644 --- a/spanner_orm/tests/model_test.py +++ b/spanner_orm/tests/model_test.py @@ -33,7 +33,11 @@ class ModelTest(parameterized.TestCase): def test_find_calls_api(self, find): mock_transaction = mock.Mock() models.UnittestModel.find( - mock_transaction, string='string', int_=1, float_=2.3) + string='string', + int_=1, + float_=2.3, + transaction=mock_transaction, + ) find.assert_called_once() (transaction, table, columns, keyset), _ = find.call_args @@ -47,7 +51,7 @@ def test_find_result(self, find): mock_transaction = mock.Mock() find.return_value = [['key', 'value_1', None]] - result = models.SmallTestModel.find(mock_transaction, key='key') + result = models.SmallTestModel.find(key='key', transaction=mock_transaction) if result: self.assertEqual(result.key, 'key') self.assertEqual(result.value_1, 'value_1') @@ -58,11 +62,14 @@ def test_find_result(self, find): @mock.patch('spanner_orm.table_apis.find') def test_find_multi_calls_api(self, find): mock_transaction = mock.Mock() - models.UnittestModel.find_multi(mock_transaction, [{ - 'string': 'string', - 'int_': 1, - 'float_': 2.3 - }]) + models.UnittestModel.find_multi( + [{ + 'string': 'string', + 'int_': 1, + 'float_': 2.3 + }], + transaction=mock_transaction, + ) find.assert_called_once() (transaction, table, columns, keyset), _ = find.call_args @@ -75,9 +82,12 @@ def test_find_multi_calls_api(self, find): def test_find_multi_result(self, find): mock_transaction = mock.Mock() find.return_value = [['key', 'value_1', None]] - results = models.SmallTestModel.find_multi(mock_transaction, [{ - 'key': 'key' - }]) + results = models.SmallTestModel.find_multi( + [{ + 'key': 'key' + }], + transaction=mock_transaction, + ) self.assertEqual(results[0].key, 'key') self.assertEqual(results[0].value_1, 'value_1') @@ -86,7 +96,11 @@ def test_find_multi_result(self, find): @mock.patch('spanner_orm.table_apis.insert') def test_create_calls_api(self, insert): mock_transaction = mock.Mock() - models.SmallTestModel.create(mock_transaction, key='key', value_1='value') + models.SmallTestModel.create( + key='key', + value_1='value', + transaction=mock_transaction, + ) insert.assert_called_once() (transaction, table, columns, values), _ = insert.call_args @@ -112,7 +126,8 @@ def test_save_batch_inserts(self, insert): mock_transaction = mock.Mock() values = {'key': 'key', 'value_1': 'value'} not_persisted = models.SmallTestModel(values) - models.SmallTestModel.save_batch(mock_transaction, [not_persisted]) + models.SmallTestModel.save_batch([not_persisted], + transaction=mock_transaction) self.assert_api_called(insert, mock_transaction) @mock.patch('spanner_orm.table_apis.update') @@ -120,7 +135,7 @@ def test_save_batch_updates(self, update): mock_transaction = mock.Mock() values = {'key': 'key', 'value_1': 'value'} persisted = models.SmallTestModel(values, persisted=True) - models.SmallTestModel.save_batch(mock_transaction, [persisted]) + models.SmallTestModel.save_batch([persisted], transaction=mock_transaction) self.assert_api_called(update, mock_transaction) @@ -130,7 +145,10 @@ def test_save_batch_force_write_upserts(self, upsert): values = {'key': 'key', 'value_1': 'value'} not_persisted = models.SmallTestModel(values) models.SmallTestModel.save_batch( - mock_transaction, [not_persisted], force_write=True) + [not_persisted], + force_write=True, + transaction=mock_transaction, + ) self.assert_api_called(upsert, mock_transaction) @mock.patch('spanner_orm.table_apis.delete') @@ -138,7 +156,7 @@ def test_delete_batch_deletes(self, delete): mock_transaction = mock.Mock() values = {'key': 'key', 'value_1': 'value'} model = models.SmallTestModel(values) - models.SmallTestModel.delete_batch(mock_transaction, [model]) + models.SmallTestModel.delete_batch([model], transaction=mock_transaction) delete.assert_called_once() (transaction, table, keyset), _ = delete.call_args @@ -149,7 +167,10 @@ def test_delete_batch_deletes(self, delete): @mock.patch('spanner_orm.table_apis.delete') def test_delete_by_key_deletes(self, delete): mock_transaction = mock.Mock() - models.SmallTestModel.delete_by_key(mock_transaction, key='some-key') + models.SmallTestModel.delete_by_key( + key='some-key', + transaction=mock_transaction, + ) delete.assert_called_once_with( mock_transaction, models.SmallTestModel.table, @@ -343,10 +364,7 @@ def test_reload(self, find): find.return_value = None self.assertIsNone(model.reload()) - find.assert_called_once() - (transaction,), kwargs = find.call_args - self.assertIsNone(transaction) - self.assertEqual(kwargs, model.id()) + find.assert_called_once_with(**model.id(), transaction=None) @mock.patch('spanner_orm.model.Model.find') def test_reload_reloads(self, find): @@ -365,10 +383,7 @@ def test_save_creates(self, create): model = models.SmallTestModel(values, persisted=False) model.save() - create.assert_called_once() - (transaction,), kwargs = create.call_args - self.assertIsNone(transaction) - self.assertEqual(kwargs, {**values, 'value_2': None}) + create.assert_called_once_with(**values, value_2=None, transaction=None) @mock.patch('spanner_orm.model.Model.update') def test_save_updates(self, update): @@ -379,10 +394,7 @@ def test_save_updates(self, update): model.value_1 = values['value_1'] model.save() - update.assert_called_once() - (transaction,), kwargs = update.call_args - self.assertIsNone(transaction) - self.assertEqual(kwargs, values) + update.assert_called_once_with(**values, transaction=None) @mock.patch('spanner_orm.model.Model.update') def test_save_no_changes(self, update): @@ -396,7 +408,7 @@ def test_delete_deletes(self, delete): mock_transaction = mock.Mock() values = {'key': 'key', 'value_1': 'value_1'} model = models.SmallTestModel(values) - model.delete(mock_transaction) + model.delete(transaction=mock_transaction) delete.assert_called_once() (transaction, table, keyset), _ = delete.call_args diff --git a/spanner_orm/tests/query_test.py b/spanner_orm/tests/query_test.py index a2ff64d..c445eb7 100644 --- a/spanner_orm/tests/query_test.py +++ b/spanner_orm/tests/query_test.py @@ -37,7 +37,7 @@ class QueryTest(parameterized.TestCase): def test_where(self, sql_query): sql_query.return_value = [] - models.UnittestModel.where_equal(True, int_=3) + models.UnittestModel.where_equal(int_=3, transaction=True) (_, sql, parameters, types), _ = sql_query.call_args expected_sql = 'SELECT .* FROM table WHERE table.int_ = @int_0' @@ -49,7 +49,7 @@ def test_where(self, sql_query): def test_count(self, sql_query): sql_query.return_value = [[0]] column, value = 'int_', 3 - models.UnittestModel.count_equal(True, int_=3) + models.UnittestModel.count_equal(int_=3, transaction=True) (_, sql, parameters, types), _ = sql_query.call_args column_key = '{}0'.format(column) From e8834c1055133b5d1858605005c1cd7d431a632c Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Wed, 23 Dec 2020 18:21:36 -0500 Subject: [PATCH 046/131] Raise an error if a CreateTable update has secondary indexes. The current CreateTable DDL doesn't support them, and I don't see any way to create a table and its secondary indexes in a single DDL statement: https://cloud.google.com/spanner/docs/data-definition-language#create_table --- spanner_orm/admin/update.py | 12 +++++++++--- spanner_orm/tests/models.py | 13 +++++++++++++ spanner_orm/tests/update_test.py | 13 ++++++++++++- 3 files changed, 34 insertions(+), 4 deletions(-) diff --git a/spanner_orm/admin/update.py b/spanner_orm/admin/update.py index 2347bdc..89074c5 100644 --- a/spanner_orm/admin/update.py +++ b/spanner_orm/admin/update.py @@ -20,6 +20,7 @@ from spanner_orm import condition from spanner_orm import error from spanner_orm import field +from spanner_orm import index from spanner_orm import model from spanner_orm.admin import api from spanner_orm.admin import index_column @@ -76,6 +77,11 @@ def validate(self) -> None: self._validate_primary_keys() + if self._model.indexes.keys() - {index.Index.PRIMARY_INDEX}: + raise error.SpannerError( + 'Secondary indexes cannot be created by CreateTable; use CreateIndex ' + 'in a separate migration.') + def _validate_parent(self) -> None: """Verifies that the parent table information is valid.""" parent_primary_keys = self._model.interleaved.primary_keys @@ -129,10 +135,10 @@ def _validate_not_interleaved(self, if model_.interleaved == existing_model: raise error.SpannerError('Table {} has interleaved table {}'.format( self._table, model_.table)) - for index in model_.indexes.values(): - if index.parent == self._table: + for index_ in model_.indexes.values(): + if index_.parent == self._table: raise error.SpannerError('Table {} has interleaved index {}'.format( - self._table, index.name)) + self._table, index_.name)) class AddColumn(SchemaUpdate): diff --git a/spanner_orm/tests/models.py b/spanner_orm/tests/models.py index 9aef1ed..a8dbd87 100644 --- a/spanner_orm/tests/models.py +++ b/spanner_orm/tests/models.py @@ -81,3 +81,16 @@ class UnittestModel(model.Model): string_array = field.Field(field.StringArray, nullable=True) test_index = index.Index(['string_2']) + +class UnittestModelWithoutSecondaryIndexes(model.Model): + """Same as UnittestModel, but with no secondary indexes.""" + + __table__ = 'table' + int_ = field.Field(field.Integer, primary_key=True) + int_2 = field.Field(field.Integer, nullable=True) + float_ = field.Field(field.Float, primary_key=True) + float_2 = field.Field(field.Float, nullable=True) + string = field.Field(field.String, primary_key=True) + string_2 = field.Field(field.String, nullable=True) + timestamp = field.Field(field.Timestamp) + string_array = field.Field(field.StringArray, nullable=True) diff --git a/spanner_orm/tests/update_test.py b/spanner_orm/tests/update_test.py index 633b0c3..0b0fb07 100644 --- a/spanner_orm/tests/update_test.py +++ b/spanner_orm/tests/update_test.py @@ -61,7 +61,7 @@ def test_drop_column_error_on_primary_key(self, get_model, index_count): @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') def test_create_table(self, get_model): get_model.return_value = None - new_model = models.UnittestModel + new_model = models.UnittestModelWithoutSecondaryIndexes test_update = update.CreateTable(new_model) test_update.validate() @@ -94,6 +94,17 @@ def test_create_table_error_on_existing_table(self, get_model): with self.assertRaisesRegex(error.SpannerError, 'already exists'): test_update.validate() + @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') + def test_create_table_error_on_table_with_index(self, get_model): + get_model.return_value = None + new_model = models.IndexTestModel + test_update = update.CreateTable(new_model) + with self.assertRaisesRegex( + error.SpannerError, + 'indexes cannot be created', + ): + test_update.validate() + @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.indexes') @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.tables') @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') From cf5a9ab1a33a1c5f382a23274cf423dd8a2a62dc Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Wed, 23 Dec 2020 18:48:27 -0500 Subject: [PATCH 047/131] Add support for NULL_FILTERED and UNIQUE in CreateIndex. --- spanner_orm/admin/update.py | 13 ++++++++++-- spanner_orm/tests/update_test.py | 34 ++++++++++++++++++++++++++------ 2 files changed, 39 insertions(+), 8 deletions(-) diff --git a/spanner_orm/admin/update.py b/spanner_orm/admin/update.py index 89074c5..659fbc9 100644 --- a/spanner_orm/admin/update.py +++ b/spanner_orm/admin/update.py @@ -240,16 +240,25 @@ def __init__(self, index_name: str, columns: Iterable[str], interleaved: Optional[str] = None, + null_filtered: bool = False, + unique: bool = False, storing_columns: Optional[Iterable[str]] = None): self._table = table_name self._index = index_name self._columns = columns self._parent_table = interleaved + self._null_filtered = null_filtered + self._unique = unique self._storing_columns = storing_columns or [] def ddl(self) -> str: - statement = 'CREATE INDEX {} ON {} ({})'.format(self._index, self._table, - ', '.join(self._columns)) + statement = 'CREATE' + if self._unique: + statement += ' UNIQUE' + if self._null_filtered: + statement += ' NULL_FILTERED' + statement += (f' INDEX {self._index} ' + f'ON {self._table} ({", ".join(self._columns)})') if self._storing_columns: statement += 'STORING ({})'.format(', '.join(self._storing_columns)) if self._parent_table: diff --git a/spanner_orm/tests/update_test.py b/spanner_orm/tests/update_test.py index 0b0fb07..855e449 100644 --- a/spanner_orm/tests/update_test.py +++ b/spanner_orm/tests/update_test.py @@ -16,13 +16,15 @@ import unittest from unittest import mock +from absl.testing import parameterized + from spanner_orm import error from spanner_orm import field from spanner_orm.admin import update from spanner_orm.tests import models -class UpdateTest(unittest.TestCase): +class UpdateTest(parameterized.TestCase): @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') def test_add_column(self, get_model): @@ -118,15 +120,35 @@ def test_drop_table(self, get_model, tables, indexes): test_update.validate() self.assertEqual(test_update.ddl(), 'DROP TABLE {}'.format(table_name)) + @parameterized.named_parameters( + ( + 'basic', + update.CreateIndex( + table_name=models.SmallTestModel.table, + index_name='foo', + columns=['value_1'], + ), + f'CREATE INDEX foo ON {models.SmallTestModel.table} (value_1)', + ), + ( + 'with_options', + update.CreateIndex( + table_name=models.SmallTestModel.table, + index_name='foo', + columns=['value_1'], + null_filtered=True, + unique=True, + ), + (f'CREATE UNIQUE NULL_FILTERED INDEX foo ' + f'ON {models.SmallTestModel.table} (value_1)'), + ), + ) @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') - def test_add_index(self, get_model): - table_name = models.SmallTestModel.table + def test_add_index(self, test_update, expected_ddl, get_model): get_model.return_value = models.SmallTestModel - test_update = update.CreateIndex(table_name, 'foo', ['value_1']) test_update.validate() - self.assertEqual(test_update.ddl(), - 'CREATE INDEX foo ON {} (value_1)'.format(table_name)) + self.assertEqual(test_update.ddl(), expected_ddl) if __name__ == '__main__': From f04b2cf3259020829e26cc8cdfb2d74745aba1ad Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Wed, 30 Dec 2020 11:45:18 -0500 Subject: [PATCH 048/131] Export more of condition.py in the top-level __init__.py. --- spanner_orm/__init__.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/spanner_orm/__init__.py b/spanner_orm/__init__.py index 9b256c6..122f7aa 100644 --- a/spanner_orm/__init__.py +++ b/spanner_orm/__init__.py @@ -68,13 +68,17 @@ StringArray = field.StringArray Timestamp = field.Timestamp +Condition = condition.Condition +ORDER_ASC = condition.OrderType.ASC +ORDER_DESC = condition.OrderType.DESC +columns_equal = condition.columns_equal contains = condition.contains equal_to = condition.equal_to force_index = condition.force_index greater_than = condition.greater_than greater_than_or_equal_to = condition.greater_than_or_equal_to -includes = condition.includes in_list = condition.in_list +includes = condition.includes less_than = condition.less_than less_than_or_equal_to = condition.less_than_or_equal_to limit = condition.limit @@ -82,9 +86,8 @@ not_greater_than = condition.not_greater_than not_in_list = condition.not_in_list not_less_than = condition.not_less_than +or_ = condition.or_ order_by = condition.order_by -ORDER_ASC = condition.OrderType.ASC -ORDER_DESC = condition.OrderType.DESC transactional_read = decorator.transactional_read transactional_write = decorator.transactional_write From f0941d4abb5d381e4829f53908e56b4bb1722dc9 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Wed, 30 Dec 2020 13:10:52 -0500 Subject: [PATCH 049/131] Add Condition for aritrary SQL. I think this will make it significantly easier to write new conditions, including a STARTS_WITH() condition I want to add. I think it could also simplify a bunch of the code in the existing conditions. --- setup.py | 2 +- spanner_orm/condition.py | 96 +++++++++++++++++++++- spanner_orm/tests/condition_test.py | 122 +++++++++++++++++++++++++++- 3 files changed, 217 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 566b38b..5df4db8 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ packages=['spanner_orm', 'spanner_orm.admin'], include_package_data=True, python_requires='~=3.7', - install_requires=['google-cloud-spanner >= 1.6, <2.0.0dev'], + install_requires=['google-cloud-spanner >= 1.6, <2.0.0dev', 'frozendict'], tests_require=['absl-py', 'portpicker'], entry_points={ 'console_scripts': ['spanner-orm = spanner_orm.admin.scripts:main'] diff --git a/spanner_orm/condition.py b/spanner_orm/condition.py index 2b38d06..30a4827 100644 --- a/spanner_orm/condition.py +++ b/spanner_orm/condition.py @@ -15,14 +15,17 @@ """Used with Model#where and Model#count to help create Spanner queries.""" import abc +import dataclasses import enum -from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union +import string +from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type, Union from spanner_orm import error from spanner_orm import field from spanner_orm import index from spanner_orm import relationship +import frozendict from google.cloud.spanner_v1.proto import type_pb2 @@ -115,6 +118,97 @@ def _validate(self, model_class: Type[Any]) -> None: raise NotImplementedError +@dataclasses.dataclass +class Param: + """Parameter for substitution into a SQL query.""" + value: Any + type: type_pb2.Type + + +@dataclasses.dataclass +class Column: + """Named column; consider using field.Field instead.""" + name: str + + +# Something that can be substituted into a SQL query. +Substitution = Union[Param, field.Field, Column] + + +class ArbitraryCondition(Condition): + """Condition with support for arbitrary SQL.""" + + def __init__( + self, + sql_template: str, + substitutions: Mapping[str, Substitution] = frozendict.frozendict(), + *, + segment: Segment, + ): + """Initializer. + + Args: + sql_template: string.Template-compatible template string for the SQL. + substitutions: Substitutions to make in sql_template. + segment: Segment for this Condition. + """ + super().__init__() + self._sql_template = string.Template(sql_template) + self._substitutions = substitutions + self._segment = segment + + # This validates the template. + self._sql_template.substitute({k: '' for k in self._substitutions}) + + def segment(self) -> Segment: + """See base class.""" + return self._segment + + def _validate(self, model_class: Type[Any]) -> None: + """See base class.""" + for substitution in self._substitutions.values(): + if isinstance(substitution, field.Field): + if substitution not in model_class.fields.values(): + raise error.ValidationError( + f'Field {substitution.name!r} does not belong to the Model for ' + f'table {model_class.table!r}.') + elif isinstance(substitution, Column): + if substitution.name not in model_class.fields: + raise error.ValidationError( + f'Column {substitution.name!r} does not exist in the Model for ' + f'table {model_class.table!r}.') + + def _params(self) -> Dict[str, Any]: + """See base class.""" + return { + self.key(k): v.value + for k, v in self._substitutions.items() + if isinstance(v, Param) + } + + def _types(self) -> Dict[str, type_pb2.Type]: + """See base class.""" + return { + self.key(k): v.type + for k, v in self._substitutions.items() + if isinstance(v, Param) + } + + def _sql_for_substitution(self, key: str, substitution: Substitution) -> str: + if isinstance(substitution, Param): + return f'@{self.key(key)}' + else: + assert isinstance(substitution, (field.Field, Column)) + return f'{self.model_class.column_prefix}.{substitution.name}' + + def _sql(self) -> str: + """See base class.""" + return self._sql_template.substitute({ + k: self._sql_for_substitution(k, v) + for k, v in self._substitutions.items() + }) + + class ColumnsEqualCondition(Condition): """Used to join records by matching column values.""" diff --git a/spanner_orm/tests/condition_test.py b/spanner_orm/tests/condition_test.py index 74ca4d0..b298d31 100644 --- a/spanner_orm/tests/condition_test.py +++ b/spanner_orm/tests/condition_test.py @@ -15,12 +15,132 @@ """Tests for spanner_orm.condition.""" import logging +import os import unittest +from absl.testing import parameterized +from google.cloud.spanner_v1.proto import type_pb2 + import spanner_orm +from spanner_orm import condition +from spanner_orm import error +from spanner_orm.testlib.spanner_emulator import testlib as spanner_emulator_testlib +from spanner_orm.tests import models + + +class ConditionTest( + spanner_emulator_testlib.TestCase, + parameterized.TestCase, +): + + def setUp(self): + super().setUp() + self.run_orm_migrations( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + 'migrations_for_emulator_test', + )) + + @parameterized.named_parameters( + ( + 'minimal', + condition.ArbitraryCondition( + 'FALSE', + segment=condition.Segment.WHERE, + ), + {}, + {}, + 'FALSE', + (), + ), + ( + 'full', + condition.ArbitraryCondition( + '$key = IF($true_param, ${key_param}, $value_1)', + dict( + key=models.SmallTestModel.key, + true_param=condition.Param( + True, + type=type_pb2.Type(code=type_pb2.BOOL), + ), + key_param=condition.Param( + 'some-key', + type=type_pb2.Type(code=type_pb2.STRING), + ), + value_1=condition.Column('value_1'), + ), + segment=condition.Segment.WHERE, + ), + dict( + true_param0=True, + key_param0='some-key', + ), + dict( + true_param0=type_pb2.Type(code=type_pb2.BOOL), + key_param0=type_pb2.Type(code=type_pb2.STRING), + ), + ('SmallTestModel.key = ' + 'IF(@true_param0, @key_param0, SmallTestModel.value_1)'), + ('some-key',), + ), + ) + def test_arbitrary_condition( + self, + condition_, + expected_params, + expected_types, + expected_sql, + expected_row_keys, + ): + models.SmallTestModel( + dict( + key='some-key', + value_1='some-value', + value_2='other-value', + )).save() + rows = models.SmallTestModel.where(condition_) + self.assertEqual(expected_params, condition_.params()) + self.assertEqual(expected_types, condition_.types()) + self.assertEqual(expected_sql, condition_.sql()) + self.assertCountEqual(expected_row_keys, tuple(row.key for row in rows)) + @parameterized.named_parameters( + ('key_not_found', '$not_found', KeyError, 'not_found'), + ('invalid_template', '$', ValueError, 'Invalid placeholder'), + ) + def test_arbitrary_condition_template_error( + self, + template, + error_class, + error_regex, + ): + with self.assertRaisesRegex(error_class, error_regex): + condition.ArbitraryCondition(template, segment=condition.Segment.WHERE) -class ConditionTest(unittest.TestCase): + @parameterized.named_parameters( + ( + 'field_from_wrong_model', + models.ChildTestModel.key, + 'does not belong to the Model', + ), + ( + 'column_not_found', + condition.Column('not_found'), + 'does not exist in the Model', + ), + ) + def test_arbitrary_condition_validation_error( + self, + substitution, + error_regex, + ): + condition_ = condition.ArbitraryCondition( + '$substitution', + dict(substitution=substitution), + segment=condition.Segment.WHERE, + ) + with self.assertRaisesRegex(error.ValidationError, error_regex): + models.SmallTestModel.where(condition_) def test_contains(self): contains = spanner_orm.contains('some_column', r'a%b_c\d') From a20c0caa0477088a4b26e7fb850672e0d62445bb Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Tue, 5 Jan 2021 22:26:13 -0500 Subject: [PATCH 050/131] Add a function to find an object or raise NotFound. This should save some uninteresting code at call sites where they can already safely assume the object exists. --- spanner_orm/model.py | 29 ++++++++++++++++++++++++++++ spanner_orm/tests/model_test.py | 34 ++++++++++++++++++++++++++++++++- 2 files changed, 62 insertions(+), 1 deletion(-) diff --git a/spanner_orm/model.py b/spanner_orm/model.py index c4a9821..676daf6 100644 --- a/spanner_orm/model.py +++ b/spanner_orm/model.py @@ -29,6 +29,7 @@ from spanner_orm import relationship from spanner_orm import table_apis +from google.api_core import exceptions from google.cloud import spanner from google.cloud.spanner_v1 import transaction as spanner_transaction @@ -279,6 +280,34 @@ def find( resources = cls.find_multi([keys], transaction=transaction) return resources[0] if resources else None + @classmethod + def find_required( + cls: Type[T], + *, + transaction: Optional[spanner_transaction.Transaction] = None, + **keys: Any, + ) -> T: + """Retrieves an object from Spanner based on the provided key. + + Args: + transaction: The existing transaction to use, or None to start a new + transaction + **keys: The keys provided are the complete set of primary keys for this + table and the corresponding values make up the unique identifier of the + object being retrieved + + Returns: + The requested object. + + Raises: + exceptions.NotFound: The object wasn't found. + """ + result = cls.find(**keys, transaction=transaction) + if result is None: + raise exceptions.NotFound( + f'{cls.__qualname__} has no object with primary key {keys}') + return result + @classmethod def find_multi( cls: Type[T], diff --git a/spanner_orm/tests/model_test.py b/spanner_orm/tests/model_test.py index 36ecb53..d2c6254 100644 --- a/spanner_orm/tests/model_test.py +++ b/spanner_orm/tests/model_test.py @@ -14,20 +14,34 @@ # limitations under the License. import datetime import logging +import os from typing import List import unittest from unittest import mock from absl.testing import parameterized +from google.api_core import exceptions from google.cloud import spanner from spanner_orm import error from spanner_orm import field +from spanner_orm.testlib.spanner_emulator import testlib as spanner_emulator_testlib from spanner_orm.tests import models _TIMESTAMP = datetime.datetime.now(tz=datetime.timezone.utc) -class ModelTest(parameterized.TestCase): +class ModelTest( + spanner_emulator_testlib.TestCase, + parameterized.TestCase, +): + + def setUp(self): + super().setUp() + self.run_orm_migrations( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + 'migrations_for_emulator_test', + )) @mock.patch('spanner_orm.table_apis.find') def test_find_calls_api(self, find): @@ -59,6 +73,24 @@ def test_find_result(self, find): else: self.fail('Failed to find result') + def test_find_required(self): + test_model = models.SmallTestModel( + dict( + key='some-key', + value_1='foo', + value_2='bar', + )) + test_model.save() + self.assertEqual( + test_model, + models.SmallTestModel.find_required(key='some-key'), + ) + + def test_find_required_not_found(self): + with self.assertRaisesRegex(exceptions.NotFound, + 'SmallTestModel has no object'): + models.SmallTestModel.find_required(key='some-key') + @mock.patch('spanner_orm.table_apis.find') def test_find_multi_calls_api(self, find): mock_transaction = mock.Mock() From 281aebdd066fb6e57d6ae48ed57c87d0f3cc12ee Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Tue, 5 Jan 2021 23:06:41 -0500 Subject: [PATCH 051/131] Infer parameter type if possible. This should save some boilerplate. --- spanner_orm/condition.py | 86 +++++++++++++++++++++- spanner_orm/tests/condition_test.py | 107 +++++++++++++++++++++++++--- 2 files changed, 183 insertions(+), 10 deletions(-) diff --git a/spanner_orm/condition.py b/spanner_orm/condition.py index 30a4827..cbdeeba 100644 --- a/spanner_orm/condition.py +++ b/spanner_orm/condition.py @@ -15,7 +15,10 @@ """Used with Model#where and Model#count to help create Spanner queries.""" import abc +import base64 import dataclasses +import datetime +import decimal import enum import string from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type, Union @@ -26,6 +29,7 @@ from spanner_orm import relationship import frozendict +from google.api_core import datetime_helpers from google.cloud.spanner_v1.proto import type_pb2 @@ -118,11 +122,89 @@ def _validate(self, model_class: Type[Any]) -> None: raise NotImplementedError +def _spanner_type_of_python_object(value: Any) -> type_pb2.Type: + """Returns the Cloud Spanner type of the given object. + + Args: + value: Object to guess the type of. + + Raises: + TypeError: The value either doesn't correspond to a valid Cloud Spanner + type, or this function was unable to guess the type. + """ + # See + # https://github.com/googleapis/python-spanner/blob/master/google/cloud/spanner_v1/proto/type.proto + # for the Cloud Spanner types, and + # https://github.com/googleapis/python-spanner/blob/e981adb3157bb06e4cb466ca81d74d85da976754/google/cloud/spanner_v1/_helpers.py#L91-L133 + # for Python types. + if value is None: + raise TypeError( + 'Cannot infer type of None, because any SQL type can be NULL.') + simple_type_code = { + bool: type_pb2.BOOL, + int: type_pb2.INT64, + float: type_pb2.FLOAT64, + datetime_helpers.DatetimeWithNanoseconds: type_pb2.TIMESTAMP, + datetime.datetime: type_pb2.TIMESTAMP, + datetime.date: type_pb2.DATE, + bytes: type_pb2.BYTES, + str: type_pb2.STRING, + decimal.Decimal: type_pb2.NUMERIC, + }.get(type(value)) + if simple_type_code is not None: + return type_pb2.Type(code=simple_type_code) + elif isinstance(value, (list, tuple)): + element_types = tuple( + _spanner_type_of_python_object(item) + for item in value + if item is not None) + unique_element_type_count = len({ + # Protos aren't hashable, so use their serializations. + element_type.SerializeToString(deterministic=True) + for element_type in element_types + }) + if unique_element_type_count == 1: + return type_pb2.Type( + code=type_pb2.ARRAY, + array_element_type=element_types[0], + ) + else: + raise TypeError( + f'Array does not have elements of exactly one type: {value!r}') + else: + raise TypeError('Unknown type: {value!r}') + + @dataclasses.dataclass class Param: - """Parameter for substitution into a SQL query.""" + """Parameter for substitution into a SQL query. + + Attributes: + value: Value of the parameter. + type: Type of the parameter. If unspecified, the type will be guessed from + the value. + """ value: Any - type: type_pb2.Type + type: type_pb2.Type = dataclasses.field(default_factory=type_pb2.Type) + + def __post_init__(self): + if not self.type.code: + self.type = _spanner_type_of_python_object(self.value) + + # BYTES must be base64-encoded, see + # https://github.com/googleapis/python-spanner/blob/87789c939990794bfd91f5300bedc449fd74bd7e/google/cloud/spanner_v1/proto/type.proto#L108-L110 + if (isinstance(self.value, bytes) and + self.type == type_pb2.Type(code=type_pb2.BYTES)): + self.value = base64.b64encode(self.value).decode() + elif (isinstance(self.value, (list, tuple)) and + all(isinstance(x, bytes) for x in self.value if x is not None) and + self.type == type_pb2.Type( + code=type_pb2.ARRAY, + array_element_type=type_pb2.Type(code=type_pb2.BYTES), + )): + self.value = tuple( + None if item is None else base64.b64encode(item).decode() + for item in self.value) @dataclasses.dataclass diff --git a/spanner_orm/tests/condition_test.py b/spanner_orm/tests/condition_test.py index b298d31..14eea1d 100644 --- a/spanner_orm/tests/condition_test.py +++ b/spanner_orm/tests/condition_test.py @@ -14,11 +14,14 @@ # limitations under the License. """Tests for spanner_orm.condition.""" +import datetime +import decimal import logging import os import unittest from absl.testing import parameterized +from google.api_core import datetime_helpers from google.cloud.spanner_v1.proto import type_pb2 import spanner_orm @@ -41,6 +44,100 @@ def setUp(self): 'migrations_for_emulator_test', )) + @parameterized.parameters( + (True, type_pb2.Type(code=type_pb2.BOOL)), + (0, type_pb2.Type(code=type_pb2.INT64)), + (0.0, type_pb2.Type(code=type_pb2.FLOAT64)), + ( + datetime_helpers.DatetimeWithNanoseconds(2021, 1, 5), + type_pb2.Type(code=type_pb2.TIMESTAMP), + ), + (datetime.datetime(2021, 1, 5), type_pb2.Type(code=type_pb2.TIMESTAMP)), + (datetime.date(2021, 1, 5), type_pb2.Type(code=type_pb2.DATE)), + (b'\x01', type_pb2.Type(code=type_pb2.BYTES)), + ('foo', type_pb2.Type(code=type_pb2.STRING)), + (decimal.Decimal('1.23'), type_pb2.Type(code=type_pb2.NUMERIC)), + ( + (0, 1), + type_pb2.Type( + code=type_pb2.ARRAY, + array_element_type=type_pb2.Type(code=type_pb2.INT64), + ), + ), + ( + ['a', None, 'b'], + type_pb2.Type( + code=type_pb2.ARRAY, + array_element_type=type_pb2.Type(code=type_pb2.STRING), + ), + ), + ) + def test_param_infers_type(self, value, expected_type): + param = condition.Param(value) + self.assertEqual(expected_type, param.type) + # Test that the value and inferred type are compatible. This will raise an + # exception if they're not. + self.assertEmpty( + models.SmallTestModel.where( + condition.ArbitraryCondition( + '$param IS NULL', + dict(param=param), + segment=condition.Segment.WHERE, + ))) + + @parameterized.parameters( + (None, 'Cannot infer type of None'), + ((0, 'some-string'), 'elements of exactly one type'), + ((0, 'some-string', None), 'elements of exactly one type'), + (object(), 'Unknown type'), + ) + def test_param_infer_type_error(self, value, error_regex): + with self.assertRaisesRegex(TypeError, error_regex): + condition.Param(value) + + def test_param_explicit_type(self): + explicit_type = type_pb2.Type(code=type_pb2.STRING) + self.assertEqual( + explicit_type, + condition.Param(None, type=explicit_type).type, + ) + + @parameterized.named_parameters( + ( + 'bytes', + condition.ArbitraryCondition( + '$param = b"\x01\x02"', + dict(param=condition.Param(b'\x01\x02')), + segment=condition.Segment.WHERE, + ), + ), + ( + 'array_of_bytes', + condition.ArbitraryCondition( + '${param}[OFFSET(0)] = b"\x01\x02"', + dict(param=condition.Param([b'\x01\x02'])), + segment=condition.Segment.WHERE, + ), + ), + ( + 'array_of_bytes_and_null', + condition.ArbitraryCondition( + '${param}[OFFSET(0)] IS NULL', + dict(param=condition.Param((None, b'\x01\x02'))), + segment=condition.Segment.WHERE, + ), + ), + ) + def test_param_correctly_encodes(self, tautology): + test_model = models.SmallTestModel( + dict( + key='some-key', + value_1='some-value', + value_2='other-value', + )) + test_model.save() + self.assertCountEqual((test_model,), models.SmallTestModel.where(tautology)) + @parameterized.named_parameters( ( 'minimal', @@ -59,14 +156,8 @@ def setUp(self): '$key = IF($true_param, ${key_param}, $value_1)', dict( key=models.SmallTestModel.key, - true_param=condition.Param( - True, - type=type_pb2.Type(code=type_pb2.BOOL), - ), - key_param=condition.Param( - 'some-key', - type=type_pb2.Type(code=type_pb2.STRING), - ), + true_param=condition.Param(True), + key_param=condition.Param('some-key'), value_1=condition.Column('value_1'), ), segment=condition.Segment.WHERE, From 66cddbcedb29a96173c9543087b7a7a3b1f64d31 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Wed, 6 Jan 2021 17:22:59 -0500 Subject: [PATCH 052/131] Make Param type inference explicit. This makes it possible to use a more restrictive type annotation than Any, with the main benefit of excluding Optional[primitive] types. With the old Any, I think there was a risk of code passing Optional[primitive], only testing when the value isn't None, and getting errors in production when the value is None. With this change, the type checker should complain about that code. --- spanner_orm/condition.py | 80 ++++++++++++++++++++--------- spanner_orm/tests/condition_test.py | 27 ++++------ 2 files changed, 65 insertions(+), 42 deletions(-) diff --git a/spanner_orm/condition.py b/spanner_orm/condition.py index cbdeeba..7dc951d 100644 --- a/spanner_orm/condition.py +++ b/spanner_orm/condition.py @@ -21,7 +21,7 @@ import decimal import enum import string -from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type, Union +from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type, TypeVar, Union from spanner_orm import error from spanner_orm import field @@ -32,6 +32,8 @@ from google.api_core import datetime_helpers from google.cloud.spanner_v1.proto import type_pb2 +T = TypeVar('T') + class Segment(enum.Enum): """The segment of the SQL query that a Condition belongs to.""" @@ -122,15 +124,44 @@ def _validate(self, model_class: Type[Any]) -> None: raise NotImplementedError -def _spanner_type_of_python_object(value: Any) -> type_pb2.Type: +GuessableParamType = Union[ + bool, # + int, # + float, # + datetime_helpers.DatetimeWithNanoseconds, # + datetime.datetime, # + datetime.date, # + bytes, # + str, # + decimal.Decimal, # + # These types technically include List[None] and Tuple[None, ...], but + # those can't be guessed. + List[Optional[bool]], # + List[Optional[int]], # + List[Optional[float]], # + List[Optional[datetime_helpers.DatetimeWithNanoseconds]], # + List[Optional[datetime.datetime]], # + List[Optional[datetime.date]], # + List[Optional[bytes]], # + List[Optional[str]], # + List[Optional[decimal.Decimal]], # + Tuple[Optional[bool], ...], # + Tuple[Optional[int], ...], # + Tuple[Optional[float], ...], # + Tuple[Optional[datetime_helpers.DatetimeWithNanoseconds], ...], # + Tuple[Optional[datetime.datetime], ...], # + Tuple[Optional[datetime.date], ...], # + Tuple[Optional[bytes], ...], # + Tuple[Optional[str], ...], # + Tuple[Optional[decimal.Decimal], ...], # +] + + +def _spanner_type_of_python_object(value: GuessableParamType) -> type_pb2.Type: """Returns the Cloud Spanner type of the given object. Args: value: Object to guess the type of. - - Raises: - TypeError: The value either doesn't correspond to a valid Cloud Spanner - type, or this function was unable to guess the type. """ # See # https://github.com/googleapis/python-spanner/blob/master/google/cloud/spanner_v1/proto/type.proto @@ -177,34 +208,33 @@ def _spanner_type_of_python_object(value: Any) -> type_pb2.Type: @dataclasses.dataclass class Param: - """Parameter for substitution into a SQL query. - - Attributes: - value: Value of the parameter. - type: Type of the parameter. If unspecified, the type will be guessed from - the value. - """ + """Parameter for substitution into a SQL query.""" value: Any - type: type_pb2.Type = dataclasses.field(default_factory=type_pb2.Type) + type: type_pb2.Type - def __post_init__(self): - if not self.type.code: - self.type = _spanner_type_of_python_object(self.value) + @classmethod + def from_value(cls: Type[T], value: GuessableParamType) -> T: + """Returns a Param with the type guessed from a Python value.""" + guessed_type = _spanner_type_of_python_object(value) # BYTES must be base64-encoded, see # https://github.com/googleapis/python-spanner/blob/87789c939990794bfd91f5300bedc449fd74bd7e/google/cloud/spanner_v1/proto/type.proto#L108-L110 - if (isinstance(self.value, bytes) and - self.type == type_pb2.Type(code=type_pb2.BYTES)): - self.value = base64.b64encode(self.value).decode() - elif (isinstance(self.value, (list, tuple)) and - all(isinstance(x, bytes) for x in self.value if x is not None) and - self.type == type_pb2.Type( + if (isinstance(value, bytes) and + guessed_type == type_pb2.Type(code=type_pb2.BYTES)): + encoded_value = base64.b64encode(value).decode() + elif (isinstance(value, (list, tuple)) and + all(isinstance(x, bytes) for x in value if x is not None) and + guessed_type == type_pb2.Type( code=type_pb2.ARRAY, array_element_type=type_pb2.Type(code=type_pb2.BYTES), )): - self.value = tuple( + encoded_value = tuple( None if item is None else base64.b64encode(item).decode() - for item in self.value) + for item in value) + else: + encoded_value = value + + return cls(value=encoded_value, type=guessed_type) @dataclasses.dataclass diff --git a/spanner_orm/tests/condition_test.py b/spanner_orm/tests/condition_test.py index 14eea1d..a3f93bc 100644 --- a/spanner_orm/tests/condition_test.py +++ b/spanner_orm/tests/condition_test.py @@ -72,8 +72,8 @@ def setUp(self): ), ), ) - def test_param_infers_type(self, value, expected_type): - param = condition.Param(value) + def test_param_from_value(self, value, expected_type): + param = condition.Param.from_value(value) self.assertEqual(expected_type, param.type) # Test that the value and inferred type are compatible. This will raise an # exception if they're not. @@ -91,23 +91,16 @@ def test_param_infers_type(self, value, expected_type): ((0, 'some-string', None), 'elements of exactly one type'), (object(), 'Unknown type'), ) - def test_param_infer_type_error(self, value, error_regex): + def test_param_from_value_error(self, value, error_regex): with self.assertRaisesRegex(TypeError, error_regex): - condition.Param(value) - - def test_param_explicit_type(self): - explicit_type = type_pb2.Type(code=type_pb2.STRING) - self.assertEqual( - explicit_type, - condition.Param(None, type=explicit_type).type, - ) + condition.Param.from_value(value) @parameterized.named_parameters( ( 'bytes', condition.ArbitraryCondition( '$param = b"\x01\x02"', - dict(param=condition.Param(b'\x01\x02')), + dict(param=condition.Param.from_value(b'\x01\x02')), segment=condition.Segment.WHERE, ), ), @@ -115,7 +108,7 @@ def test_param_explicit_type(self): 'array_of_bytes', condition.ArbitraryCondition( '${param}[OFFSET(0)] = b"\x01\x02"', - dict(param=condition.Param([b'\x01\x02'])), + dict(param=condition.Param.from_value([b'\x01\x02'])), segment=condition.Segment.WHERE, ), ), @@ -123,12 +116,12 @@ def test_param_explicit_type(self): 'array_of_bytes_and_null', condition.ArbitraryCondition( '${param}[OFFSET(0)] IS NULL', - dict(param=condition.Param((None, b'\x01\x02'))), + dict(param=condition.Param.from_value((None, b'\x01\x02'))), segment=condition.Segment.WHERE, ), ), ) - def test_param_correctly_encodes(self, tautology): + def test_param_from_value_correctly_encodes(self, tautology): test_model = models.SmallTestModel( dict( key='some-key', @@ -156,8 +149,8 @@ def test_param_correctly_encodes(self, tautology): '$key = IF($true_param, ${key_param}, $value_1)', dict( key=models.SmallTestModel.key, - true_param=condition.Param(True), - key_param=condition.Param('some-key'), + true_param=condition.Param.from_value(True), + key_param=condition.Param.from_value('some-key'), value_1=condition.Column('value_1'), ), segment=condition.Segment.WHERE, From afa41c4c4608b077ff3cb74df5c1c93eb489f19e Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Wed, 6 Jan 2021 17:57:51 -0500 Subject: [PATCH 053/131] Export things needed to use ArbitraryCondition at the top level. --- spanner_orm/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/spanner_orm/__init__.py b/spanner_orm/__init__.py index 122f7aa..e84ebab 100644 --- a/spanner_orm/__init__.py +++ b/spanner_orm/__init__.py @@ -68,9 +68,13 @@ StringArray = field.StringArray Timestamp = field.Timestamp +ArbitraryCondition = condition.ArbitraryCondition +Column = condition.Column Condition = condition.Condition ORDER_ASC = condition.OrderType.ASC ORDER_DESC = condition.OrderType.DESC +Param = condition.Param +Segment = condition.Segment columns_equal = condition.columns_equal contains = condition.contains equal_to = condition.equal_to From 5d6ba9d743f42d5c7a89ee56b9fffa7ad2210909 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Thu, 7 Jan 2021 13:48:31 -0500 Subject: [PATCH 054/131] Use abspath() instead of realpath() for migrations directories. Resolving symlinks seems to break the tests in some environments, and I can't think of a reason it would be necessary. --- spanner_orm/tests/condition_test.py | 2 +- spanner_orm/tests/migrations_emulator_test.py | 2 +- spanner_orm/tests/model_test.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/spanner_orm/tests/condition_test.py b/spanner_orm/tests/condition_test.py index a3f93bc..f78098f 100644 --- a/spanner_orm/tests/condition_test.py +++ b/spanner_orm/tests/condition_test.py @@ -40,7 +40,7 @@ def setUp(self): super().setUp() self.run_orm_migrations( os.path.join( - os.path.dirname(os.path.realpath(__file__)), + os.path.dirname(os.path.abspath(__file__)), 'migrations_for_emulator_test', )) diff --git a/spanner_orm/tests/migrations_emulator_test.py b/spanner_orm/tests/migrations_emulator_test.py index 3e4f8a1..93ede5d 100644 --- a/spanner_orm/tests/migrations_emulator_test.py +++ b/spanner_orm/tests/migrations_emulator_test.py @@ -23,7 +23,7 @@ class MigrationsEmulatorTest(spanner_emulator_testlib.TestCase): TEST_MIGRATIONS_DIR = os.path.join( - os.path.dirname(os.path.realpath(__file__)), + os.path.dirname(os.path.abspath(__file__)), 'migrations_for_emulator_test', ) diff --git a/spanner_orm/tests/model_test.py b/spanner_orm/tests/model_test.py index d2c6254..b521b34 100644 --- a/spanner_orm/tests/model_test.py +++ b/spanner_orm/tests/model_test.py @@ -39,7 +39,7 @@ def setUp(self): super().setUp() self.run_orm_migrations( os.path.join( - os.path.dirname(os.path.realpath(__file__)), + os.path.dirname(os.path.abspath(__file__)), 'migrations_for_emulator_test', )) From c054814c89808b2242e476952e60dc58f7706030 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Mon, 11 Jan 2021 17:52:30 -0500 Subject: [PATCH 055/131] Minor formatting fixes. --- spanner_orm/tests/migrations_emulator_test.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/spanner_orm/tests/migrations_emulator_test.py b/spanner_orm/tests/migrations_emulator_test.py index 93ede5d..7e7a8e2 100644 --- a/spanner_orm/tests/migrations_emulator_test.py +++ b/spanner_orm/tests/migrations_emulator_test.py @@ -16,15 +16,15 @@ import os import unittest - import spanner_orm from spanner_orm.tests import models from spanner_orm.testlib.spanner_emulator import testlib as spanner_emulator_testlib + class MigrationsEmulatorTest(spanner_emulator_testlib.TestCase): TEST_MIGRATIONS_DIR = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - 'migrations_for_emulator_test', + os.path.dirname(os.path.abspath(__file__)), + 'migrations_for_emulator_test', ) def setUp(self): @@ -35,10 +35,15 @@ def test_basic(self): test_model = models.SmallTestModel({'key': 'key', 'value_1': 'value'}) test_model.save() self.assertEqual( - [x.values for x in models.SmallTestModel.all()], - [{'key': 'key', 'value_1': 'value', 'value_2': None}], + [x.values for x in models.SmallTestModel.all()], + [{ + 'key': 'key', + 'value_1': 'value', + 'value_2': None, + }], ) + if __name__ == '__main__': logging.basicConfig() unittest.main() From e923144afd0eacd1e961f9609c456f14af74545a Mon Sep 17 00:00:00 2001 From: Daniel Gorelik Date: Sun, 17 Jan 2021 11:49:38 -0500 Subject: [PATCH 056/131] Fix merge glitch --- spanner_orm/tests/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spanner_orm/tests/models.py b/spanner_orm/tests/models.py index 0e2027b..e4ccca7 100644 --- a/spanner_orm/tests/models.py +++ b/spanner_orm/tests/models.py @@ -113,4 +113,4 @@ class UnittestModelWithoutSecondaryIndexes(model.Model): string_2 = field.Field(field.String, nullable=True) timestamp = field.Field(field.Timestamp) string_array = field.Field(field.StringArray, nullable=True) ->>>>>>> foreign_key_data_class + From 92fd4bf706fec2f3b6d8263a6627bebf804790f0 Mon Sep 17 00:00:00 2001 From: Daniel Gorelik Date: Sun, 17 Jan 2021 19:23:35 -0500 Subject: [PATCH 057/131] Clarify why uniqueness is enforced --- spanner_orm/foreign_key_relationship.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/spanner_orm/foreign_key_relationship.py b/spanner_orm/foreign_key_relationship.py index 44b3244..3152343 100644 --- a/spanner_orm/foreign_key_relationship.py +++ b/spanner_orm/foreign_key_relationship.py @@ -80,5 +80,6 @@ def _parse_constraint(self) -> ForeignKeyRelationshipConstraint: @property def single(self) -> bool: # Spanner enforces uniqueness for values of fields referenced by - # foreign keys. + # foreign keys, because it creates a unique index on the referenced + # key. return True From 2ef6dd5629d79d390196b0dccde7e99629ae47a1 Mon Sep 17 00:00:00 2001 From: Daniel Gorelik Date: Sun, 17 Jan 2021 21:25:45 -0500 Subject: [PATCH 058/131] Begin clean-up refactoring --- spanner_orm/admin/update.py | 2 +- spanner_orm/condition.py | 15 ++++++++------- spanner_orm/foreign_key_relationship.py | 6 +----- spanner_orm/tests/query_test.py | 19 +++++++++++++++---- 4 files changed, 25 insertions(+), 17 deletions(-) diff --git a/spanner_orm/admin/update.py b/spanner_orm/admin/update.py index ad0c961..73f12a3 100644 --- a/spanner_orm/admin/update.py +++ b/spanner_orm/admin/update.py @@ -334,7 +334,7 @@ def validate(self) -> None: if db_index.primary: raise error.SpannerError('Index {} is the primary index'.format( self._index)) - + class NoUpdate(SchemaUpdate): """Update that does nothing, for migrations that don't update db schemas.""" diff --git a/spanner_orm/condition.py b/spanner_orm/condition.py index f949b7c..2ad0141 100644 --- a/spanner_orm/condition.py +++ b/spanner_orm/condition.py @@ -460,7 +460,14 @@ def conditions(self) -> List[Condition]: raise error.SpannerError( 'Condition must be bound before conditions is called') relation_conditions = [] - if not self.foreign_key_relation: + if self.foreign_key_relation: + for pair in self.relation.constraint.columns.items(): + referencing_column, referenced_column = pair + relation_conditions.append( + ColumnsEqualCondition(referenced_column, self.model_class, + referencing_column)) + + else: for constraint in self.relation.constraints: # This is backward from what you might imagine because the condition # will be processed from the context of the destination model. @@ -468,12 +475,6 @@ def conditions(self) -> List[Condition]: ColumnsEqualCondition(constraint.destination_column, constraint.origin_class, constraint.origin_column)) - else: - for pair in self.relation.constraint.columns.items(): - referencing_column, referenced_column = pair - relation_conditions.append( - ColumnsEqualCondition(referenced_column, self.model_class, - referencing_column)) return relation_conditions + self._conditions @property diff --git a/spanner_orm/foreign_key_relationship.py b/spanner_orm/foreign_key_relationship.py index 3152343..4ee5f45 100644 --- a/spanner_orm/foreign_key_relationship.py +++ b/spanner_orm/foreign_key_relationship.py @@ -32,8 +32,7 @@ class ForeignKeyRelationship(object): def __init__(self, referenced_table_name: str, - columns: Mapping[str, str], - single: bool = False): + columns: Mapping[str, str]): """Creates a ForeignKeyRelationship. Args: @@ -41,14 +40,11 @@ def __init__(self, columns: Dictionary where the keys are names of columns from the referencing table and the values are the names of the columns in the referenced table. - single: True if the referenced table should be treated as a single object - instead of a list of objects. """ self.origin = None self.name = None self._referenced_table_name = referenced_table_name self._columns = columns - self._single = single @property def constraint(self) -> ForeignKeyRelationshipConstraint: diff --git a/spanner_orm/tests/query_test.py b/spanner_orm/tests/query_test.py index c136220..d2270a0 100644 --- a/spanner_orm/tests/query_test.py +++ b/spanner_orm/tests/query_test.py @@ -317,17 +317,28 @@ def fk_includes_result(self, related=1): lambda x: x.fk_includes_result(related=1), ), ) - def test_includes_single_related_object_result(self, includes_kwargs, x, y): + def test_includes_single_related_object_result( + self, + includes_kwargs, + referenced_table_fn, + includes_result_fn, + ): select_query = self.includes(**includes_kwargs) - child_values, parent_values, rows = y(self) + child_values, parent_values, rows = includes_result_fn(self) result = select_query.process_results(rows)[0] - self.assertIsInstance(x(result), models.SmallTestModel) + self.assertIsInstance( + referenced_table_fn(result), + models.SmallTestModel, + ) for name, value in child_values.items(): self.assertEqual(getattr(result, name), value) for name, value in parent_values.items(): - self.assertEqual(getattr(x(result), name), value) + self.assertEqual( + getattr(referenced_table_fn(result), name), + value + ) @parameterized.named_parameters( ( From ef0dc61ed765df9a609500837f066df7e105663a Mon Sep 17 00:00:00 2001 From: Daniel Gorelik Date: Sun, 17 Jan 2021 21:32:07 -0500 Subject: [PATCH 059/131] More clean-up refactoring --- spanner_orm/tests/query_test.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/spanner_orm/tests/query_test.py b/spanner_orm/tests/query_test.py index d2270a0..99a9025 100644 --- a/spanner_orm/tests/query_test.py +++ b/spanner_orm/tests/query_test.py @@ -354,12 +354,17 @@ def test_includes_single_related_object_result( lambda x: x.fk_includes_result(related=0), ), ) - def test_includes_single_no_related_object_result(self, includes_kwargs, x, y): + def test_includes_single_no_related_object_result( + self, + includes_kwargs, + referenced_table_fn, + includes_result_fn + ): select_query = self.includes(**includes_kwargs) - child_values, _, rows = y(self) + child_values, _, rows = includes_result_fn(self) result = select_query.process_results(rows)[0] - self.assertIsNone(x(result)) + self.assertIsNone(referenced_table_fn(result)) for name, value in child_values.items(): self.assertEqual(getattr(result, name), value) @@ -389,9 +394,9 @@ def test_includes_subcondition_result(self): ), ) def test_includes_error_on_multiple_results_for_single( - self, includes_kwargs, x): + self, includes_kwargs, includes_result_fn): select_query = self.includes(**includes_kwargs) - _, _, rows = x(self) + _, _, rows = includes_result_fn(self) with self.assertRaises(error.SpannerError): _ = select_query.process_results(rows) From d7126ec815f6fb8b1a368effdb53f1bdb1a9e5c1 Mon Sep 17 00:00:00 2001 From: Daniel Gorelik Date: Sun, 17 Jan 2021 22:38:22 -0500 Subject: [PATCH 060/131] Add tests for foreign_key_relation arg --- spanner_orm/condition.py | 1 + spanner_orm/tests/query_test.py | 11 +++++++++++ 2 files changed, 12 insertions(+) diff --git a/spanner_orm/condition.py b/spanner_orm/condition.py index 2ad0141..704347f 100644 --- a/spanner_orm/condition.py +++ b/spanner_orm/condition.py @@ -413,6 +413,7 @@ def __init__( foreign_key_relationship.ForeignKeyRelationship, str], conditions: List[Condition] = None, + # Default argument is `False` for backwards-compatability. foreign_key_relation=False, ): """Initializer. diff --git a/spanner_orm/tests/query_test.py b/spanner_orm/tests/query_test.py index 99a9025..08e1d44 100644 --- a/spanner_orm/tests/query_test.py +++ b/spanner_orm/tests/query_test.py @@ -218,6 +218,17 @@ def includes( [include_condition], ) + @parameterized.parameters( + (models.RelationshipTestModel.parent, True), + (models.ForeignKeyTestModel.foreign_key_1, False) + ) + def test_bad_includes_args(self, relation_key, foreign_key_relation): + with self.assertRaisesRegex(ValueError, 'Must pass'): + self.includes( + relation_key, + foreign_key_relation=foreign_key_relation, + ) + @parameterized.named_parameters( ( 'legacy_relationship', From 64828a1951e5dd3d8ebe75efbe90b35fcc13dcef Mon Sep 17 00:00:00 2001 From: Daniel Gorelik Date: Sun, 17 Jan 2021 22:48:50 -0500 Subject: [PATCH 061/131] Condense validation code --- spanner_orm/condition.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/spanner_orm/condition.py b/spanner_orm/condition.py index 704347f..46dcbc3 100644 --- a/spanner_orm/condition.py +++ b/spanner_orm/condition.py @@ -513,23 +513,19 @@ def _types(self) -> Dict[str, type_pb2.Type]: def _validate(self, model_class: Type[Any]) -> None: if self.foreign_key_relation: - if self.name not in model_class.foreign_key_relations: - raise error.ValidationError('{} is not a relation on {}'.format( - self.name, model_class.table)) - if self.relation and self.relation != model_class.foreign_key_relations[ - self.name]: - raise error.ValidationError('{} does not belong to {}'.format( - self.relation.name, model_class.table)) - other_model_class = model_class.foreign_key_relations[ - self.name].constraint.referenced_table + model_class_relations = model_class.foreign_key_relations + referenced_table_fn = lambda x: x.constraint.referenced_table else: - if self.name not in model_class.relations: - raise error.ValidationError('{} is not a relation on {}'.format( - self.name, model_class.table)) - if self.relation and self.relation != model_class.relations[self.name]: - raise error.ValidationError('{} does not belong to {}'.format( - self.relation.name, model_class.table)) - other_model_class = model_class.relations[self.name].destination + model_class_relations = model_class.relations + referenced_table_fn = lambda x: x.destination + + if self.name not in model_class_relations: + raise error.ValidationError('{} is not a relation on {}'.format( + self.name, model_class.table)) + if self.relation and self.relation != model_class_relations[self.name]: + raise error.ValidationError('{} does not belong to {}'.format( + self.relation.name, model_class.table)) + other_model_class = referenced_table_fn(model_class_relations[self.name]) for condition in self._conditions: condition._validate(other_model_class) # pylint: disable=protected-access From 30cf2bee6776726752711a8977282172bc5d2a63 Mon Sep 17 00:00:00 2001 From: Daniel Gorelik Date: Sun, 17 Jan 2021 22:51:34 -0500 Subject: [PATCH 062/131] Fix minor typos --- spanner_orm/condition.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/spanner_orm/condition.py b/spanner_orm/condition.py index 46dcbc3..bf243b7 100644 --- a/spanner_orm/condition.py +++ b/spanner_orm/condition.py @@ -521,10 +521,10 @@ def _validate(self, model_class: Type[Any]) -> None: if self.name not in model_class_relations: raise error.ValidationError('{} is not a relation on {}'.format( - self.name, model_class.table)) + self.name, model_class.table)) if self.relation and self.relation != model_class_relations[self.name]: raise error.ValidationError('{} does not belong to {}'.format( - self.relation.name, model_class.table)) + self.relation.name, model_class.table)) other_model_class = referenced_table_fn(model_class_relations[self.name]) for condition in self._conditions: @@ -931,7 +931,7 @@ def includes(relation: Union[relationship.Relationship, associated objects conditions: Conditions to apply on the subquery foreign_key_relation: True if the relation is a foreign key relation, - False if it is a legacy relation (eg not enforced in Spanner) + False if it is a legacy relation (ie not enforced in Spanner) Returns: A Condition subclass that will be used in the query From 4f78dd4f60a9c9fcc042ae6e4fbb1e38c06d1373 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Mon, 25 Jan 2021 20:18:05 -0500 Subject: [PATCH 063/131] Handle degenerate cases in OrCondition. --- spanner_orm/condition.py | 16 ++++--- spanner_orm/tests/condition_test.py | 74 +++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+), 6 deletions(-) diff --git a/spanner_orm/condition.py b/spanner_orm/condition.py index bf243b7..3e8aa76 100644 --- a/spanner_orm/condition.py +++ b/spanner_orm/condition.py @@ -583,9 +583,6 @@ class OrCondition(Condition): def __init__(self, *condition_lists: List[Condition]): super().__init__() - if len(condition_lists) < 2: - raise error.SpannerError( - 'OrCondition requires at least two lists of conditions') self.condition_lists = condition_lists self.all_conditions = [] for conditions in condition_lists: @@ -612,9 +609,16 @@ def _sql(self) -> str: params += len(condition.params()) for conditions in self.condition_lists: - new_segment = ' AND '.join([condition.sql() for condition in conditions]) - segments.append('({new_segment})'.format(new_segment=new_segment)) - return '({segments})'.format(segments=' OR '.join(segments)) + if conditions: + new_segment = ' AND '.join( + [condition.sql() for condition in conditions]) + segments.append('({new_segment})'.format(new_segment=new_segment)) + else: + segments.append('TRUE') + if segments: + return '({segments})'.format(segments=' OR '.join(segments)) + else: + return 'FALSE' def segment(self) -> Segment: return Segment.WHERE diff --git a/spanner_orm/tests/condition_test.py b/spanner_orm/tests/condition_test.py index f78098f..3abc346 100644 --- a/spanner_orm/tests/condition_test.py +++ b/spanner_orm/tests/condition_test.py @@ -226,6 +226,80 @@ def test_arbitrary_condition_validation_error( with self.assertRaisesRegex(error.ValidationError, error_regex): models.SmallTestModel.where(condition_) + @parameterized.named_parameters( + ( + 'empty_or', + condition.OrCondition(), + {}, + {}, + 'FALSE', + '', + ), + ( + 'empty_and', + condition.OrCondition([]), + {}, + {}, + '(TRUE)', + 'ab', + ), + ( + 'single', + condition.OrCondition( + [condition.equal_to(models.SmallTestModel.key, 'a')]), + dict(key0='a'), + dict(key0=type_pb2.Type(code=type_pb2.STRING)), + '((SmallTestModel.key = @key0))', + 'a', + ), + ( + 'multiple', + condition.OrCondition( + [ + condition.equal_to(models.SmallTestModel.key, 'a'), + condition.equal_to(models.SmallTestModel.value_1, 'a'), + ], + [ + condition.equal_to(models.SmallTestModel.key, 'b'), + condition.equal_to(models.SmallTestModel.value_1, 'b'), + ], + ), + dict( + key0='a', + value_11='a', + key2='b', + value_13='b', + ), + dict( + key0=type_pb2.Type(code=type_pb2.STRING), + value_11=type_pb2.Type(code=type_pb2.STRING), + key2=type_pb2.Type(code=type_pb2.STRING), + value_13=type_pb2.Type(code=type_pb2.STRING), + ), + ('(' + '(SmallTestModel.key = @key0 AND SmallTestModel.value_1 = @value_11)' + ' OR ' + '(SmallTestModel.key = @key2 AND SmallTestModel.value_1 = @value_13)' + ')'), + 'ab', + ), + ) + def test_or_condition( + self, + condition_, + expected_params, + expected_types, + expected_sql, + expected_row_keys, + ): + models.SmallTestModel(dict(key='a', value_1='a', value_2='a')).save() + models.SmallTestModel(dict(key='b', value_1='b', value_2='b')).save() + rows = models.SmallTestModel.where(condition_) + self.assertEqual(expected_params, condition_.params()) + self.assertEqual(expected_types, condition_.types()) + self.assertEqual(expected_sql, condition_.sql()) + self.assertCountEqual(expected_row_keys, tuple(row.key for row in rows)) + def test_contains(self): contains = spanner_orm.contains('some_column', r'a%b_c\d') self.assertEqual('some_column', contains.column) From 16bfe2d7b49ed7250aefd765778c8ce5637756f6 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Fri, 29 Jan 2021 16:10:14 -0500 Subject: [PATCH 064/131] Use YAPF for formatting. --- README.md | 7 + setup.cfg | 2 + spanner_orm/__init__.py | 1 - spanner_orm/admin/scripts.py | 4 +- spanner_orm/api.py | 22 +- spanner_orm/condition.py | 13 +- spanner_orm/foreign_key_relationship.py | 4 +- spanner_orm/metadata.py | 6 +- spanner_orm/model.py | 11 +- .../testlib/spanner_emulator/emulator.py | 6 +- .../testlib/spanner_emulator/testlib.py | 1 - spanner_orm/tests/api_test.py | 13 +- spanner_orm/tests/migrations_emulator_test.py | 24 +- .../create_foreign_key_test_model.py | 11 +- .../create_unittest_model.py | 1 + spanner_orm/tests/migrations_test.py | 38 ++- spanner_orm/tests/models.py | 14 +- spanner_orm/tests/query_test.py | 242 +++++++++--------- 18 files changed, 219 insertions(+), 201 deletions(-) create mode 100644 setup.cfg diff --git a/README.md b/README.md index acbaaff..5eaaa3e 100644 --- a/README.md +++ b/README.md @@ -209,6 +209,13 @@ pip install pytype pytype -V 3.7 spanner_orm -d import-error ``` +To check formatting, run (change `--diff` to `--in-place` to fix formatting): + +``` +pip install yapf +yapf --diff --recursive --parallel . +``` + Then run tests with: ``` diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..db5b80a --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[yapf] +based_on_style = yapf diff --git a/spanner_orm/__init__.py b/spanner_orm/__init__.py index e84ebab..76e0623 100644 --- a/spanner_orm/__init__.py +++ b/spanner_orm/__init__.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Sets up shortcuts for imports from the library.""" import logging diff --git a/spanner_orm/admin/scripts.py b/spanner_orm/admin/scripts.py index dd445f9..6a1131a 100644 --- a/spanner_orm/admin/scripts.py +++ b/spanner_orm/admin/scripts.py @@ -45,9 +45,7 @@ def main(as_module: bool = False) -> None: # 'subcommand' is actually required, but required subparsers are not supported # for Python < 3.7. subparsers = parser.add_subparsers( - dest='subcommand', - title='subcommands', - description='valid subcommands') + dest='subcommand', title='subcommands', description='valid subcommands') generate_parser = subparsers.add_parser( 'generate', help='Generate a new migration') diff --git a/spanner_orm/api.py b/spanner_orm/api.py index af738da..36d7fa4 100644 --- a/spanner_orm/api.py +++ b/spanner_orm/api.py @@ -26,7 +26,9 @@ CallableReturn = TypeVar('CallableReturn') + class SpannerRetryableApi(abc.ABC): + def _ensure_session(self, api_method, *args, **kwargs): try: return api_method(*args, **kwargs) @@ -69,6 +71,7 @@ def _run_read_only(self, method, *args, **kwargs): with self._connection.snapshot(multi_use=True) as snapshot: return method(snapshot, *args, **kwargs) + class SpannerWriteApi(SpannerRetryableApi): """Handles sending write requests to Spanner.""" @@ -95,8 +98,8 @@ def run_write(self, method: Callable[..., CallableReturn], *args: Any, Returns: The return value from `method` will be returned from this method """ - return self._ensure_session( - self._connection.run_in_transaction, method, *args, **kwargs) + return self._ensure_session(self._connection.run_in_transaction, method, + *args, **kwargs) class SpannerConnection: @@ -120,7 +123,8 @@ def __init__(self, def connect(self): """Establish a new connection to the specified Spanner database.""" - client = spanner.Client(project=self._project, credentials=self._credentials) + client = spanner.Client( + project=self._project, credentials=self._credentials) instance = client.instance(self._instance) self.database = instance.database( self._database, pool=self._pool, ddl_statements=self._create_ddl or ()) @@ -140,12 +144,12 @@ def _connection(self): _api = None # type: Optional[SpannerApi] -def connect(instance: str, - database: str, - project: Optional[str] = None, - credentials: Optional[auth_credentials.Credentials] = None, - pool: Optional[spanner_pool.AbstractSessionPool] = None - ) -> SpannerApi: +def connect( + instance: str, + database: str, + project: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + pool: Optional[spanner_pool.AbstractSessionPool] = None) -> SpannerApi: """Connects to the Spanner database and sets the global spanner_api.""" connection = SpannerConnection( instance, database, project=project, credentials=credentials, pool=pool) diff --git a/spanner_orm/condition.py b/spanner_orm/condition.py index 3e8aa76..25661e3 100644 --- a/spanner_orm/condition.py +++ b/spanner_orm/condition.py @@ -431,8 +431,8 @@ def __init__( self.foreign_key_relation = foreign_key_relation if isinstance(relation_or_name, relationship.Relationship): if foreign_key_relation: - raise ValueError( - 'Must pass foreign key relation if ''`foreign_key_relation=True`.') + raise ValueError('Must pass foreign key relation if ' + '`foreign_key_relation=True`.') self.name = relation_or_name.name self.relation = relation_or_name elif isinstance(relation_or_name, @@ -465,9 +465,9 @@ def conditions(self) -> List[Condition]: for pair in self.relation.constraint.columns.items(): referencing_column, referenced_column = pair relation_conditions.append( - ColumnsEqualCondition(referenced_column, self.model_class, - referencing_column)) - + ColumnsEqualCondition(referenced_column, self.model_class, + referencing_column)) + else: for constraint in self.relation.constraints: # This is backward from what you might imagine because the condition @@ -940,8 +940,7 @@ def includes(relation: Union[relationship.Relationship, Returns: A Condition subclass that will be used in the query """ - return IncludesCondition( - relation, conditions, foreign_key_relation) + return IncludesCondition(relation, conditions, foreign_key_relation) def in_list(column: Union[field.Field, str], diff --git a/spanner_orm/foreign_key_relationship.py b/spanner_orm/foreign_key_relationship.py index 4ee5f45..8e1cb8d 100644 --- a/spanner_orm/foreign_key_relationship.py +++ b/spanner_orm/foreign_key_relationship.py @@ -30,9 +30,7 @@ class ForeignKeyRelationshipConstraint: class ForeignKeyRelationship(object): """Helps define a foreign key relationship between two models.""" - def __init__(self, - referenced_table_name: str, - columns: Mapping[str, str]): + def __init__(self, referenced_table_name: str, columns: Mapping[str, str]): """Creates a ForeignKeyRelationship. Args: diff --git a/spanner_orm/metadata.py b/spanner_orm/metadata.py index 251dfc2..52075ff 100644 --- a/spanner_orm/metadata.py +++ b/spanner_orm/metadata.py @@ -45,11 +45,9 @@ def __init__(self, table: Optional[str] = None, fields: Optional[Dict[str, field.Field]] = None, relations: Optional[Dict[str, relationship.Relationship]] = None, - foreign_key_relations: Optional[ - Dict[ + foreign_key_relations: Optional[Dict[ str, - foreign_key_relationship.ForeignKeyRelationship] - ] = None, + foreign_key_relationship.ForeignKeyRelationship]] = None, indexes: Optional[Dict[str, index.Index]] = None, interleaved: Optional[str] = None, model_class: Optional[Type[Any]] = None): diff --git a/spanner_orm/model.py b/spanner_orm/model.py index c235aa0..e4f2d48 100644 --- a/spanner_orm/model.py +++ b/spanner_orm/model.py @@ -80,12 +80,9 @@ def __new__(mcs, name: str, bases: Any, attrs: Dict[str, Any], **kwargs: Any): return cls def __getattr__( - cls, - name: str) -> Union[ - field.Field, - relationship.Relationship, - foreign_key_relationship.ForeignKeyRelationship, - index.Index]: + cls, name: str + ) -> Union[field.Field, relationship.Relationship, + foreign_key_relationship.ForeignKeyRelationship, index.Index]: # Unclear why pylint doesn't like this # pylint: disable=unsupported-membership-test if name in cls.fields: @@ -131,7 +128,6 @@ def foreign_key_relations( cls) -> Dict[str, foreign_key_relationship.ForeignKeyRelationship]: return cls.meta.foreign_key_relations - @property def fields(cls) -> Dict[str, field.Field]: return cls.meta.fields @@ -603,7 +599,6 @@ def _execute_write( else: return cls.spanner_api().run_write(db_api, *args) - def __setattr__(self, name: str, value: Any) -> None: if name in self._relations: raise AttributeError(name) diff --git a/spanner_orm/testlib/spanner_emulator/emulator.py b/spanner_orm/testlib/spanner_emulator/emulator.py index cd29669..f624920 100644 --- a/spanner_orm/testlib/spanner_emulator/emulator.py +++ b/spanner_orm/testlib/spanner_emulator/emulator.py @@ -95,9 +95,9 @@ def _start(self) -> None: emulator_binary_path = os.environ[_EMULATOR_BINARY_PATH_ENV_VAR] except KeyError as key_error: raise ValueError( - f'Please set the environment variable {_EMULATOR_BINARY_PATH_ENV_VAR} ' - 'to a binary with the Cloud Spanner Emulator. For more info, see ' - 'https://github.com/GoogleCloudPlatform/cloud-spanner-emulator.' + f'Please set the environment variable {_EMULATOR_BINARY_PATH_ENV_VAR} ' + 'to a binary with the Cloud Spanner Emulator. For more info, see ' + 'https://github.com/GoogleCloudPlatform/cloud-spanner-emulator.' ) from key_error self._process = subprocess.Popen([ diff --git a/spanner_orm/testlib/spanner_emulator/testlib.py b/spanner_orm/testlib/spanner_emulator/testlib.py index aac4049..1203ba1 100644 --- a/spanner_orm/testlib/spanner_emulator/testlib.py +++ b/spanner_orm/testlib/spanner_emulator/testlib.py @@ -127,4 +127,3 @@ def run_orm_migrations(self, migrations_folder: str) -> None: _make_emulator_spanner_orm_connection(self.spanner_emulator_database, self.spanner_emulator_instance, self.spanner_emulator_client)) - diff --git a/spanner_orm/tests/api_test.py b/spanner_orm/tests/api_test.py index 6caf3a0..a1abb25 100644 --- a/spanner_orm/tests/api_test.py +++ b/spanner_orm/tests/api_test.py @@ -28,6 +28,7 @@ def _mock_run_in_transaction(method, *args, **kwargs): class MockSpannerApi(api.SpannerReadApi, api.SpannerWriteApi): + def __init__(self): self.connection_mock = mock.MagicMock() self.connection_mock.run_in_transaction.side_effect = _mock_run_in_transaction @@ -36,6 +37,7 @@ def __init__(self): def _connection(self): return self.connection_mock + class ApiTest(parameterized.TestCase): @mock.patch('google.cloud.spanner.Client') @@ -70,14 +72,13 @@ def test_admin_api_create_ddl_connection(self, client): @parameterized.parameters('run_read_only', 'run_write') @mock.patch('spanner_orm.api.spanner_api') - def test_reconnect_on_expected_error(self, api_method, - mock_spanner_api): + def test_reconnect_on_expected_error(self, api_method, mock_spanner_api): mock_api = MockSpannerApi() mock_method = mock.Mock() mock_method.side_effect = [ - exceptions.NotFound('Session not found'), - 'Anything other than an exception' + exceptions.NotFound('Session not found'), + 'Anything other than an exception' ] mock_connect = mock_spanner_api.return_value.connect @@ -88,8 +89,7 @@ def test_reconnect_on_expected_error(self, api_method, @parameterized.parameters('run_read_only', 'run_write') @mock.patch('spanner_orm.api.spanner_api') - def test_raise_on_expected_error(self, api_method, - mock_spanner_api): + def test_raise_on_expected_error(self, api_method, mock_spanner_api): mock_api = MockSpannerApi() mock_method = mock.Mock() @@ -105,6 +105,7 @@ def mock_connection(self, client): client().instance().database.return_value = connection return connection + if __name__ == '__main__': logging.basicConfig() unittest.main() diff --git a/spanner_orm/tests/migrations_emulator_test.py b/spanner_orm/tests/migrations_emulator_test.py index c09a423..de26917 100644 --- a/spanner_orm/tests/migrations_emulator_test.py +++ b/spanner_orm/tests/migrations_emulator_test.py @@ -24,7 +24,6 @@ from google.api_core import exceptions as google_api_exceptions - class MigrationsEmulatorTest(spanner_emulator_testlib.TestCase): TEST_MIGRATIONS_DIR = os.path.join( os.path.dirname(os.path.abspath(__file__)), @@ -52,26 +51,27 @@ def test_error_with_missing_referencing_key(self): 'Cannot find referenced key', ): models.ForeignKeyTestModel({ - 'referencing_key_1': 'key', - 'referencing_key_2': 'key', - 'referencing_key_3': 42, - 'value': 'value' + 'referencing_key_1': 'key', + 'referencing_key_2': 'key', + 'referencing_key_3': 42, + 'value': 'value' }).save() def test_key(self): models.SmallTestModel({'key': 'key', 'value_1': 'value'}).save() - models.UnittestModel( - {'string': 'string', - 'int_': 42, - 'float_': 4.2, - 'timestamp': datetime.datetime.now(tz=datetime.timezone.utc), - }).save() + models.UnittestModel({ + 'string': 'string', + 'int_': 42, + 'float_': 4.2, + 'timestamp': datetime.datetime.now(tz=datetime.timezone.utc), + }).save() models.ForeignKeyTestModel({ 'referencing_key_1': 'key', 'referencing_key_2': 'string', 'referencing_key_3': 42, 'value': 'value' - }).save() + }).save() + if __name__ == '__main__': logging.basicConfig() diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py b/spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py index b8e776b..7e4c291 100644 --- a/spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py +++ b/spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py @@ -35,13 +35,16 @@ class OriginalForeignKeyTestModelTable(spanner_orm.model.Model): referencing_key_3 = field.Field(field.Integer, primary_key=True) self_referencing_key = field.Field(field.String, nullable=True) foreign_key_1 = foreign_key_relationship.ForeignKeyRelationship( - 'SmallTestModel', {'referencing_key_1': 'key'}) + 'SmallTestModel', {'referencing_key_1': 'key'}) foreign_key_2 = foreign_key_relationship.ForeignKeyRelationship( - 'UnittestModel', - {'referencing_key_2': 'string', 'referencing_key_3': 'int_'}, + 'UnittestModel', + { + 'referencing_key_2': 'string', + 'referencing_key_3': 'int_' + }, ) foreign_key_3 = foreign_key_relationship.ForeignKeyRelationship( - 'ForeignKeyTestModel', {'self_referencing_key': 'referencing_key_1'}) + 'ForeignKeyTestModel', {'self_referencing_key': 'referencing_key_1'}) def upgrade() -> spanner_orm.CreateTable: diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py b/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py index 91fc5fe..8be2c08 100644 --- a/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py +++ b/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py @@ -38,6 +38,7 @@ class OriginalUnittestModelTable(spanner_orm.model.Model): timestamp = field.Field(field.Timestamp) string_array = field.Field(field.StringArray, nullable=True) + def upgrade() -> spanner_orm.CreateTable: """See ORM migrations interface.""" return spanner_orm.CreateTable(OriginalUnittestModelTable) diff --git a/spanner_orm/tests/migrations_test.py b/spanner_orm/tests/migrations_test.py index 1324e07..446cce4 100644 --- a/spanner_orm/tests/migrations_test.py +++ b/spanner_orm/tests/migrations_test.py @@ -32,8 +32,7 @@ class MigrationsTest(unittest.TestCase): TEST_MIGRATIONS_DIR = os.path.join(TEST_DIR, 'migrations') def test_retrieve(self): - testdata_filename = os.path.join(os.path.dirname(__file__), - 'migrations') + testdata_filename = os.path.join(os.path.dirname(__file__), 'migrations') manager = migration_manager.MigrationManager(testdata_filename) migrations = manager.migrations self.assertEqual(len(migrations), 3) @@ -43,12 +42,10 @@ def test_retrieve(self): migrations[0].migration_id) def test_generate(self): - testdata_filename = os.path.join(os.path.dirname(__file__), - 'migrations') + testdata_filename = os.path.join(os.path.dirname(__file__), 'migrations') shutil.rmtree(self.TEST_MIGRATIONS_DIR) shutil.copytree(testdata_filename, self.TEST_MIGRATIONS_DIR) - os.chmod(self.TEST_MIGRATIONS_DIR, - stat.S_IRWXO | stat.S_IRWXU) + os.chmod(self.TEST_MIGRATIONS_DIR, stat.S_IRWXO | stat.S_IRWXU) for f in os.listdir(self.TEST_MIGRATIONS_DIR): file_path = os.path.join(self.TEST_MIGRATIONS_DIR, f) if not os.path.isdir(file_path): @@ -126,8 +123,8 @@ def test_order_migrations_error_on_no_successor(self): def test_filter_migrations(self): connection = mock.Mock() - executor = migration_executor.MigrationExecutor( - connection, self.TEST_MIGRATIONS_DIR) + executor = migration_executor.MigrationExecutor(connection, + self.TEST_MIGRATIONS_DIR) first = migration.Migration('1', None) second = migration.Migration('2', '1') @@ -147,8 +144,8 @@ def test_filter_migrations(self): def test_filter_migrations_error_on_bad_last_migration(self): connection = mock.Mock() - executor = migration_executor.MigrationExecutor( - connection, self.TEST_MIGRATIONS_DIR) + executor = migration_executor.MigrationExecutor(connection, + self.TEST_MIGRATIONS_DIR) first = migration.Migration('1', None) second = migration.Migration('2', '1') @@ -165,8 +162,8 @@ def test_filter_migrations_error_on_bad_last_migration(self): def test_validate_migrations(self): connection = mock.Mock() - executor = migration_executor.MigrationExecutor( - connection, self.TEST_MIGRATIONS_DIR) + executor = migration_executor.MigrationExecutor(connection, + self.TEST_MIGRATIONS_DIR) first = migration.Migration('1', None) second = migration.Migration('2', '1') @@ -184,8 +181,8 @@ def test_validate_migrations(self): def test_validate_migrations_error_on_unmigrated_after_migrated(self): connection = mock.Mock() - executor = migration_executor.MigrationExecutor( - connection, self.TEST_MIGRATIONS_DIR) + executor = migration_executor.MigrationExecutor(connection, + self.TEST_MIGRATIONS_DIR) first = migration.Migration('1', None) second = migration.Migration('2', '1') @@ -205,8 +202,8 @@ def test_validate_migrations_error_on_unmigrated_after_migrated(self): def test_validate_migrations_error_on_unmigrated_first(self): connection = mock.Mock() - executor = migration_executor.MigrationExecutor( - connection, self.TEST_MIGRATIONS_DIR) + executor = migration_executor.MigrationExecutor(connection, + self.TEST_MIGRATIONS_DIR) first = migration.Migration('2', '1') with mock.patch.object(executor, 'migrations') as migrations: @@ -224,8 +221,8 @@ def test_validate_migrations_error_on_unmigrated_first(self): def test_migrate(self): connection = mock.Mock() - executor = migration_executor.MigrationExecutor( - connection, self.TEST_MIGRATIONS_DIR) + executor = migration_executor.MigrationExecutor(connection, + self.TEST_MIGRATIONS_DIR) first = migration.Migration('1', None) second = migration.Migration('2', '1') @@ -239,8 +236,8 @@ def test_migrate(self): def test_rollback(self): connection = mock.Mock() - executor = migration_executor.MigrationExecutor( - connection, self.TEST_MIGRATIONS_DIR) + executor = migration_executor.MigrationExecutor(connection, + self.TEST_MIGRATIONS_DIR) first = migration.Migration('1', None) second = migration.Migration('2', '1') @@ -257,6 +254,7 @@ def tearDownClass(cls): super().tearDownClass() shutil.rmtree(MigrationsTest.TEST_DIR) + if __name__ == '__main__': logging.basicConfig() unittest.main() diff --git a/spanner_orm/tests/models.py b/spanner_orm/tests/models.py index e4ccca7..1b1d4a2 100644 --- a/spanner_orm/tests/models.py +++ b/spanner_orm/tests/models.py @@ -62,6 +62,7 @@ class RelationshipTestModel(model.Model): parents = relationship.Relationship('spanner_orm.tests.models.SmallTestModel', {'parent_key': 'key'}) + class ForeignKeyTestModel(model.Model): """Model class for testing foreign keys.""" @@ -71,13 +72,17 @@ class ForeignKeyTestModel(model.Model): referencing_key_3 = field.Field(field.Integer, primary_key=True) self_referencing_key = field.Field(field.String, nullable=True) foreign_key_1 = foreign_key_relationship.ForeignKeyRelationship( - 'spanner_orm.tests.models.SmallTestModel', {'referencing_key_1': 'key'}) + 'spanner_orm.tests.models.SmallTestModel', {'referencing_key_1': 'key'}) foreign_key_2 = foreign_key_relationship.ForeignKeyRelationship( - 'spanner_orm.tests.models.UnittestModel', - {'referencing_key_2': 'string', 'referencing_key_3': 'int_'}, + 'spanner_orm.tests.models.UnittestModel', + { + 'referencing_key_2': 'string', + 'referencing_key_3': 'int_' + }, ) foreign_key_3 = foreign_key_relationship.ForeignKeyRelationship( - 'spanner_orm.tests.models.ForeignKeyTestModel', {'self_referencing_key': 'referencing_key_1'}) + 'spanner_orm.tests.models.ForeignKeyTestModel', + {'self_referencing_key': 'referencing_key_1'}) class InheritanceTestModel(SmallTestModel): @@ -113,4 +118,3 @@ class UnittestModelWithoutSecondaryIndexes(model.Model): string_2 = field.Field(field.String, nullable=True) timestamp = field.Field(field.Timestamp) string_array = field.Field(field.StringArray, nullable=True) - diff --git a/spanner_orm/tests/query_test.py b/spanner_orm/tests/query_test.py index 08e1d44..b851d6b 100644 --- a/spanner_orm/tests/query_test.py +++ b/spanner_orm/tests/query_test.py @@ -67,7 +67,8 @@ def test_count_allows_force_index(self): self.assertEqual(expected_sql, sql) @parameterized.parameters( - condition.limit(1), condition.order_by(('int_', condition.OrderType.DESC))) + condition.limit(1), condition.order_by( + ('int_', condition.OrderType.DESC))) def test_count_only_allows_where_and_from_segment_conditions(self, condition): with self.assertRaises(error.SpannerError): query.CountQuery(models.UnittestModel, [condition]) @@ -208,60 +209,67 @@ def test_force_index_with_object(self): expected_sql = 'FROM table@{FORCE_INDEX=test_index}' self.assertEndsWith(select_query.sql(), expected_sql) - def includes( - self, relation, *conditions, foreign_key_relation=False): - include_condition = condition.includes( - relation, list(conditions), foreign_key_relation) + def includes(self, relation, *conditions, foreign_key_relation=False): + include_condition = condition.includes(relation, list(conditions), + foreign_key_relation) return query.SelectQuery( - models.ForeignKeyTestModel if foreign_key_relation - else models.RelationshipTestModel, - [include_condition], + models.ForeignKeyTestModel + if foreign_key_relation else models.RelationshipTestModel, + [include_condition], ) - @parameterized.parameters( - (models.RelationshipTestModel.parent, True), - (models.ForeignKeyTestModel.foreign_key_1, False) - ) + @parameterized.parameters((models.RelationshipTestModel.parent, True), + (models.ForeignKeyTestModel.foreign_key_1, False)) def test_bad_includes_args(self, relation_key, foreign_key_relation): with self.assertRaisesRegex(ValueError, 'Must pass'): self.includes( - relation_key, - foreign_key_relation=foreign_key_relation, + relation_key, + foreign_key_relation=foreign_key_relation, ) @parameterized.named_parameters( - ( - 'legacy_relationship', - {'relation': 'parent'}, - r'SELECT RelationshipTestModel\S* RelationshipTestModel\S* ' - r'ARRAY\(SELECT AS STRUCT SmallTestModel\S* SmallTestModel\S* ' - r'SmallTestModel\S* FROM SmallTestModel WHERE SmallTestModel.key = ' - r'RelationshipTestModel.parent_key\)', - ), - ( - 'legacy_relationship_with_object_arg', - {'relation': models.RelationshipTestModel.parent}, - r'SELECT RelationshipTestModel\S* RelationshipTestModel\S* ' - r'ARRAY\(SELECT AS STRUCT SmallTestModel\S* SmallTestModel\S* ' - r'SmallTestModel\S* FROM SmallTestModel WHERE SmallTestModel.key = ' - r'RelationshipTestModel.parent_key\)', - ), - ( - 'foreign_key_relationship', - {'relation': 'foreign_key_1', 'foreign_key_relation': True}, - r'SELECT ForeignKeyTestModel\S* ForeignKeyTestModel\S* ForeignKeyTestModel\S* ForeignKeyTestModel\S* ' - r'ARRAY\(SELECT AS STRUCT SmallTestModel\S* SmallTestModel\S* ' - r'SmallTestModel\S* FROM SmallTestModel WHERE SmallTestModel.key = ' - r'ForeignKeyTestModel.referencing_key_1\)', - ), - ( - 'foreign_key_relationship_with_object_arg', - {'relation': models.ForeignKeyTestModel.foreign_key_1, 'foreign_key_relation': True}, - r'SELECT ForeignKeyTestModel\S* ForeignKeyTestModel\S* ForeignKeyTestModel\S* ForeignKeyTestModel\S* ' - r'ARRAY\(SELECT AS STRUCT SmallTestModel\S* SmallTestModel\S* ' - r'SmallTestModel\S* FROM SmallTestModel WHERE SmallTestModel.key = ' - r'ForeignKeyTestModel.referencing_key_1\)', - ), + ( + 'legacy_relationship', + { + 'relation': 'parent' + }, + r'SELECT RelationshipTestModel\S* RelationshipTestModel\S* ' + r'ARRAY\(SELECT AS STRUCT SmallTestModel\S* SmallTestModel\S* ' + r'SmallTestModel\S* FROM SmallTestModel WHERE SmallTestModel.key = ' + r'RelationshipTestModel.parent_key\)', + ), + ( + 'legacy_relationship_with_object_arg', + { + 'relation': models.RelationshipTestModel.parent + }, + r'SELECT RelationshipTestModel\S* RelationshipTestModel\S* ' + r'ARRAY\(SELECT AS STRUCT SmallTestModel\S* SmallTestModel\S* ' + r'SmallTestModel\S* FROM SmallTestModel WHERE SmallTestModel.key = ' + r'RelationshipTestModel.parent_key\)', + ), + ( + 'foreign_key_relationship', + { + 'relation': 'foreign_key_1', + 'foreign_key_relation': True + }, + r'SELECT ForeignKeyTestModel\S* ForeignKeyTestModel\S* ForeignKeyTestModel\S* ForeignKeyTestModel\S* ' + r'ARRAY\(SELECT AS STRUCT SmallTestModel\S* SmallTestModel\S* ' + r'SmallTestModel\S* FROM SmallTestModel WHERE SmallTestModel.key = ' + r'ForeignKeyTestModel.referencing_key_1\)', + ), + ( + 'foreign_key_relationship_with_object_arg', + { + 'relation': models.ForeignKeyTestModel.foreign_key_1, + 'foreign_key_relation': True + }, + r'SELECT ForeignKeyTestModel\S* ForeignKeyTestModel\S* ForeignKeyTestModel\S* ForeignKeyTestModel\S* ' + r'ARRAY\(SELECT AS STRUCT SmallTestModel\S* SmallTestModel\S* ' + r'SmallTestModel\S* FROM SmallTestModel WHERE SmallTestModel.key = ' + r'ForeignKeyTestModel.referencing_key_1\)', + ), ) def test_includes(self, includes_kwargs, expected_sql): select_query = self.includes(**includes_kwargs) @@ -271,18 +279,17 @@ def test_includes(self, includes_kwargs, expected_sql): self.assertEmpty(select_query.parameters()) self.assertEmpty(select_query.types()) - @parameterized.parameters( - ( - {'relation': models.RelationshipTestModel.parent, 'foreign_key_relation': True}, - ), - ( - {'relation': models.ForeignKeyTestModel.foreign_key_1, 'foreign_key_relation': False}, - ) - ) + @parameterized.parameters(({ + 'relation': models.RelationshipTestModel.parent, + 'foreign_key_relation': True + },), ({ + 'relation': models.ForeignKeyTestModel.foreign_key_1, + 'foreign_key_relation': False + },)) def test_error_mismatched_params(self, includes_kwargs): with self.assertRaisesRegex(ValueError, 'Must pass'): self.includes(**includes_kwargs) - + def test_includes_subconditions_query(self): select_query = self.includes('parents', condition.equal_to('key', 'value')) expected_sql = ( @@ -301,10 +308,12 @@ def includes_result(self, related=1): return child, parent, [result] def fk_includes_result(self, related=1): - child = {'referencing_key_1': 'parent_key', - 'referencing_key_2': 'child', - 'referencing_key_3': 'child', - 'self_referencing_key': 'child'} + child = { + 'referencing_key_1': 'parent_key', + 'referencing_key_2': 'child', + 'referencing_key_3': 'child', + 'self_referencing_key': 'child' + } result = [child[name] for name in models.ForeignKeyTestModel.columns] parent = {'key': 'key', 'value_1': 'value_1', 'value_2': None} parents = [] @@ -313,20 +322,24 @@ def fk_includes_result(self, related=1): result.append(parents) return child, parent, [result] - @parameterized.named_parameters( - ( - 'legacy_relationship', - {'relation': 'parent'}, - lambda x: x.parent, - lambda x: x.includes_result(related=1), - ), - ( - 'foreign_key_relationship', - {'relation': 'foreign_key_1', 'foreign_key_relation': True}, - lambda x: x.foreign_key_1, - lambda x: x.fk_includes_result(related=1), - ), + ( + 'legacy_relationship', + { + 'relation': 'parent' + }, + lambda x: x.parent, + lambda x: x.includes_result(related=1), + ), + ( + 'foreign_key_relationship', + { + 'relation': 'foreign_key_1', + 'foreign_key_relation': True + }, + lambda x: x.foreign_key_1, + lambda x: x.fk_includes_result(related=1), + ), ) def test_includes_single_related_object_result( self, @@ -339,38 +352,37 @@ def test_includes_single_related_object_result( result = select_query.process_results(rows)[0] self.assertIsInstance( - referenced_table_fn(result), - models.SmallTestModel, + referenced_table_fn(result), + models.SmallTestModel, ) for name, value in child_values.items(): self.assertEqual(getattr(result, name), value) for name, value in parent_values.items(): - self.assertEqual( - getattr(referenced_table_fn(result), name), - value - ) + self.assertEqual(getattr(referenced_table_fn(result), name), value) @parameterized.named_parameters( - ( - 'legacy_relationship', - {'relation': 'parent'}, - lambda x: x.parent, - lambda x: x.includes_result(related=0), - ), - ( - 'foreign_key_relationship', - {'relation': 'foreign_key_1', 'foreign_key_relation': True}, - lambda x: x.foreign_key_1, - lambda x: x.fk_includes_result(related=0), - ), + ( + 'legacy_relationship', + { + 'relation': 'parent' + }, + lambda x: x.parent, + lambda x: x.includes_result(related=0), + ), + ( + 'foreign_key_relationship', + { + 'relation': 'foreign_key_1', + 'foreign_key_relation': True + }, + lambda x: x.foreign_key_1, + lambda x: x.fk_includes_result(related=0), + ), ) - def test_includes_single_no_related_object_result( - self, - includes_kwargs, - referenced_table_fn, - includes_result_fn - ): + def test_includes_single_no_related_object_result(self, includes_kwargs, + referenced_table_fn, + includes_result_fn): select_query = self.includes(**includes_kwargs) child_values, _, rows = includes_result_fn(self) result = select_query.process_results(rows)[0] @@ -393,16 +405,21 @@ def test_includes_subcondition_result(self): self.assertEqual(getattr(result.parents[0], name), value) @parameterized.named_parameters( - ( - 'legacy_relationship', - {'relation': 'parent'}, - lambda x: x.includes_result(related=2), - ), - ( - 'foreign_key_relationship', - {'relation': 'foreign_key_1', 'foreign_key_relation': True}, - lambda x: x.fk_includes_result(related=2), - ), + ( + 'legacy_relationship', + { + 'relation': 'parent' + }, + lambda x: x.includes_result(related=2), + ), + ( + 'foreign_key_relationship', + { + 'relation': 'foreign_key_1', + 'foreign_key_relation': True + }, + lambda x: x.fk_includes_result(related=2), + ), ) def test_includes_error_on_multiple_results_for_single( self, includes_kwargs, includes_result_fn): @@ -412,8 +429,7 @@ def test_includes_error_on_multiple_results_for_single( _ = select_query.process_results(rows) @parameterized.parameters(True, False) - def test_includes_error_on_invalid_relation( - self, foreign_key_relation): + def test_includes_error_on_invalid_relation(self, foreign_key_relation): with self.assertRaises(error.ValidationError): self.includes('bad_relation', foreign_key_relation=foreign_key_relation) @@ -425,13 +441,9 @@ def test_includes_error_on_invalid_relation( ('key', ['bad value'], 'parent', False), ('key', ['bad value'], 'foreign_key_1', False), ) - def test_includes_error_on_invalid_subconditions( - self, - column, - value, - relation, - foreign_key_relation - ): + def test_includes_error_on_invalid_subconditions(self, column, value, + relation, + foreign_key_relation): with self.assertRaises(error.ValidationError): self.includes( relation, From 6e9e68b70cbaaa1fc43d64a99368627a266fa20d Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Fri, 29 Jan 2021 15:34:44 -0500 Subject: [PATCH 065/131] Refactor contains() now that ArbitraryCondition exists. This does break the API, but I don't think there's any way for that breakage to result in subtle bugs, and I think the type checker should catch any issues. Also, it's a pretty new API that I'm guessing doesn't have a lot of use yet. 1. This removes escaping logic, which seems like the sort of thing that could be error-prone. I'd expect switching to a dedicated function to also help performance, but I haven't tested it. 2. This makes it easier to add more features later. Specifically, I want to support case-insensitive substring matching. 3. Make the API more powerful. E.g., it's now possible to test if a value contains a column. (I don't actually need this feature now, but switching to ArbitraryCondition makes it trivial to provide.) --- spanner_orm/condition.py | 31 ++++++++++++----------------- spanner_orm/tests/condition_test.py | 27 ++++++++++++++++++++----- 2 files changed, 35 insertions(+), 23 deletions(-) diff --git a/spanner_orm/condition.py b/spanner_orm/condition.py index 3e8aa76..e6c7344 100644 --- a/spanner_orm/condition.py +++ b/spanner_orm/condition.py @@ -832,30 +832,25 @@ def columns_equal(origin_column: str, dest_model_class: Type[Any], def contains( - column: Union[field.Field, str], - value: str, -) -> ComparisonCondition: - """Condition where the specified column contains the given substring. + haystack: Substitution, + needle: Substitution, +) -> Condition: + """Condition where the specified haystack contains the given needle. Args: - column: Name of the column on the origin model or the Field on the origin - model class to compare from - value: The value to compare against + haystack: String or bytes to search. + needle: String or bytes to search for. Must be the same type as haystack. Returns: A Condition subclass that will be used in the query """ - value_escaped = value.translate( - str.maketrans({ - # https://cloud.google.com/spanner/docs/functions-and-operators#comparison_operators - '%': r'\%', - '_': r'\_', - '\\': '\\\\', - })) - return ComparisonCondition( - operator='LIKE', - field_or_name=column, - value=f'%{value_escaped}%', + return ArbitraryCondition( + 'STRPOS($haystack, $needle) > 0', + dict( + haystack=haystack, + needle=needle, + ), + segment=Segment.WHERE, ) diff --git a/spanner_orm/tests/condition_test.py b/spanner_orm/tests/condition_test.py index 3abc346..ce91112 100644 --- a/spanner_orm/tests/condition_test.py +++ b/spanner_orm/tests/condition_test.py @@ -300,11 +300,28 @@ def test_or_condition( self.assertEqual(expected_sql, condition_.sql()) self.assertCountEqual(expected_row_keys, tuple(row.key for row in rows)) - def test_contains(self): - contains = spanner_orm.contains('some_column', r'a%b_c\d') - self.assertEqual('some_column', contains.column) - self.assertEqual('LIKE', contains.operator) - self.assertEqual(r'%a\%b\_c\\d%', contains.value) + @parameterized.parameters( + ('ABCD', 'BC', True), + ('ABCD', 'CB', False), + (b'ABCD', b'BC', True), + (b'ABCD', b'CB', False), + ) + def test_contains( + self, + haystack, + needle, + expect_results, + ): + test_model = models.SmallTestModel(dict(key='a', value_1='a', value_2='a')) + test_model.save() + self.assertCountEqual( + ((test_model,) if expect_results else ()), + models.SmallTestModel.where( + spanner_orm.contains( + condition.Param.from_value(haystack), + condition.Param.from_value(needle), + )), + ) if __name__ == '__main__': From fc2a393f76a704f1919651bd749f88c34525ad8e Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Fri, 29 Jan 2021 15:15:56 -0500 Subject: [PATCH 066/131] Add support for case-insensitive substring matching. --- spanner_orm/condition.py | 8 +++++++- spanner_orm/tests/condition_test.py | 10 ++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/spanner_orm/condition.py b/spanner_orm/condition.py index e6c7344..250d86e 100644 --- a/spanner_orm/condition.py +++ b/spanner_orm/condition.py @@ -834,18 +834,24 @@ def columns_equal(origin_column: str, dest_model_class: Type[Any], def contains( haystack: Substitution, needle: Substitution, + *, + case_sensitive: bool = True, ) -> Condition: """Condition where the specified haystack contains the given needle. Args: haystack: String or bytes to search. needle: String or bytes to search for. Must be the same type as haystack. + case_sensitive: Whether comparison should be case sensitive or not. See + https://cloud.google.com/spanner/docs/functions-and-operators#lower for + caveats on how the case conversion works. Returns: A Condition subclass that will be used in the query """ return ArbitraryCondition( - 'STRPOS($haystack, $needle) > 0', + ('STRPOS($haystack, $needle) > 0' + if case_sensitive else 'STRPOS(LOWER($haystack), LOWER($needle)) > 0'), dict( haystack=haystack, needle=needle, diff --git a/spanner_orm/tests/condition_test.py b/spanner_orm/tests/condition_test.py index ce91112..6d4a84c 100644 --- a/spanner_orm/tests/condition_test.py +++ b/spanner_orm/tests/condition_test.py @@ -302,15 +302,24 @@ def test_or_condition( @parameterized.parameters( ('ABCD', 'BC', True), + ('ABCD', 'bc', False), ('ABCD', 'CB', False), (b'ABCD', b'BC', True), + (b'ABCD', b'bc', False), (b'ABCD', b'CB', False), + ('ABCD', 'BC', True, dict(case_sensitive=False)), + ('ABCD', 'bc', True, dict(case_sensitive=False)), + ('ABCD', 'CB', False, dict(case_sensitive=False)), + (b'ABCD', b'BC', True, dict(case_sensitive=False)), + (b'ABCD', b'bc', True, dict(case_sensitive=False)), + (b'ABCD', b'CB', False, dict(case_sensitive=False)), ) def test_contains( self, haystack, needle, expect_results, + kwargs={}, ): test_model = models.SmallTestModel(dict(key='a', value_1='a', value_2='a')) test_model.save() @@ -320,6 +329,7 @@ def test_contains( spanner_orm.contains( condition.Param.from_value(haystack), condition.Param.from_value(needle), + **kwargs, )), ) From 19dad7f883a536855456d7d7f58eeff1a11acf9c Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Tue, 2 Feb 2021 18:38:08 -0500 Subject: [PATCH 067/131] Make it possible to add non-DDL migration updates. SchemaUpdate seems to be designed for DDL (schema) changes, but there currently doesn't seem to be any way to add a `NOT NULL` column using only DDL changes. This commit makes it easier to add non-DDL migration updates; a follow-up commit will add support for backfilling data in a migration update. --- README.md | 4 ++-- spanner_orm/__init__.py | 3 ++- spanner_orm/admin/migration.py | 18 +++++++------- spanner_orm/admin/migration.skel | 4 ++-- spanner_orm/admin/migration_executor.py | 16 ++++++------- spanner_orm/admin/update.py | 31 ++++++++++++++----------- 6 files changed, 41 insertions(+), 35 deletions(-) diff --git a/README.md b/README.md index 5eaaa3e..ee22ba4 100644 --- a/README.md +++ b/README.md @@ -154,8 +154,8 @@ Running ```spanner-orm generate ``` will generate a new migration file to be filled out in the directory specified (or 'migrations' by default). The ```upgrade``` function is executed when migrating, and the ```downgrade``` function is executed when rolling back the migration. Each of -these should return a single SchemaUpdate object (e.g., CreateTable, AddColumn, -etc.), as Spanner cannot execute multiple schema updates atomically. +these should return a single MigrationUpdate object (e.g., CreateTable, +AddColumn, etc.), as Spanner cannot execute multiple schema updates atomically. ### Executing migrations Running ```spanner-orm migrate ``` will diff --git a/spanner_orm/__init__.py b/spanner_orm/__init__.py index 76e0623..6734757 100644 --- a/spanner_orm/__init__.py +++ b/spanner_orm/__init__.py @@ -95,6 +95,8 @@ transactional_read = decorator.transactional_read transactional_write = decorator.transactional_write +MigrationUpdate = update_module.MigrationUpdate +NoUpdate = update_module.NoUpdate SchemaUpdate = update_module.SchemaUpdate CreateTable = update_module.CreateTable DropTable = update_module.DropTable @@ -103,7 +105,6 @@ AlterColumn = update_module.AlterColumn CreateIndex = update_module.CreateIndex DropIndex = update_module.DropIndex -NoUpdate = update_module.NoUpdate model_creation_ddl = update_module.model_creation_ddl MigrationExecutor = migration_executor.MigrationExecutor diff --git a/spanner_orm/admin/migration.py b/spanner_orm/admin/migration.py index 177e1a7..29d359e 100644 --- a/spanner_orm/admin/migration.py +++ b/spanner_orm/admin/migration.py @@ -19,18 +19,20 @@ from spanner_orm.admin import update -def no_update_callable() -> update.SchemaUpdate: +def no_update_callable() -> update.MigrationUpdate: return update.NoUpdate() class Migration: """Holds information about a specific migration.""" - def __init__(self, - migration_id: str, - prev_migration_id: Optional[str], - upgrade: Optional[Callable[[], update.SchemaUpdate]] = None, - downgrade: Optional[Callable[[], update.SchemaUpdate]] = None): + def __init__( + self, + migration_id: str, + prev_migration_id: Optional[str], + upgrade: Optional[Callable[[], update.MigrationUpdate]] = None, + downgrade: Optional[Callable[[], update.MigrationUpdate]] = None, + ): self._id = migration_id self._prev = prev_migration_id self._upgrade = upgrade or no_update_callable @@ -45,9 +47,9 @@ def prev_migration_id(self) -> Optional[str]: return self._prev @property - def upgrade(self) -> Optional[Callable[[], update.SchemaUpdate]]: + def upgrade(self) -> Optional[Callable[[], update.MigrationUpdate]]: return self._upgrade @property - def downgrade(self) -> Optional[Callable[[], update.SchemaUpdate]]: + def downgrade(self) -> Optional[Callable[[], update.MigrationUpdate]]: return self._downgrade diff --git a/spanner_orm/admin/migration.skel b/spanner_orm/admin/migration.skel index 0f3ac29..049022c 100644 --- a/spanner_orm/admin/migration.skel +++ b/spanner_orm/admin/migration.skel @@ -10,11 +10,11 @@ migration_id = $migration_id prev_migration_id = $prev_migration_id -def upgrade() -> spanner_orm.SchemaUpdate: +def upgrade() -> spanner_orm.MigrationUpdate: """See spanner_orm migrations interface.""" return spanner_orm.NoUpdate() -def downgrade() -> spanner_orm.SchemaUpdate: +def downgrade() -> spanner_orm.MigrationUpdate: """See spanner_orm migrations interface.""" return spanner_orm.NoUpdate() diff --git a/spanner_orm/admin/migration_executor.py b/spanner_orm/admin/migration_executor.py index 3b3b8e2..0d90a8c 100644 --- a/spanner_orm/admin/migration_executor.py +++ b/spanner_orm/admin/migration_executor.py @@ -66,12 +66,12 @@ def migrate(self, target_migration: Optional[str] = None) -> None: target_migration) for migration_ in migrations: _logger.info('Processing migration %s', migration_.migration_id) - schema_update = migration_.upgrade() - if not isinstance(schema_update, update.SchemaUpdate): + migration_update = migration_.upgrade() + if not isinstance(migration_update, update.MigrationUpdate): raise error.SpannerError( - 'Migration {} did not return a SchemaUpdate'.format( + 'Migration {} did not return a MigrationUpdate'.format( migration_.migration_id)) - schema_update.execute() + migration_update.execute() self._update_status(migration_.migration_id, True) self._hangup() @@ -97,12 +97,12 @@ def rollback(self, target_migration: str) -> None: reversed(self.migrations()), True, target_migration) for migration_ in migrations: _logger.info('Processing migration %s', migration_.migration_id) - schema_update = migration_.downgrade() - if not isinstance(schema_update, update.SchemaUpdate): + migration_update = migration_.downgrade() + if not isinstance(migration_update, update.MigrationUpdate): raise error.SpannerError( - 'Migration {} did not return a SchemaUpdate'.format( + 'Migration {} did not return a MigrationUpdate'.format( migration_.migration_id)) - schema_update.execute() + migration_update.execute() self._update_status(migration_.migration_id, False) self._hangup() diff --git a/spanner_orm/admin/update.py b/spanner_orm/admin/update.py index 73f12a3..0f7803e 100644 --- a/spanner_orm/admin/update.py +++ b/spanner_orm/admin/update.py @@ -28,7 +28,23 @@ from spanner_orm.admin import metadata -class SchemaUpdate(abc.ABC): +class MigrationUpdate(abc.ABC): + """Base class for all updates that can happen in a migration.""" + + @abc.abstractmethod + def execute(self) -> None: + """Executes the update.""" + raise NotImplementedError + + +class NoUpdate(MigrationUpdate): + """Update that does nothing, for migrations that don't update db schemas.""" + + def execute(self) -> None: + """See base class.""" + + +class SchemaUpdate(MigrationUpdate, abc.ABC): """Base class for specifying schema updates.""" @abc.abstractmethod @@ -336,19 +352,6 @@ def validate(self) -> None: self._index)) -class NoUpdate(SchemaUpdate): - """Update that does nothing, for migrations that don't update db schemas.""" - - def ddl(self) -> str: - return '' - - def execute(self) -> None: - pass - - def validate(self) -> None: - pass - - def model_creation_ddl(model_: Type[model.Model]) -> List[str]: """Returns the list of ddl statements needed to create the model's table.""" ddl_list = [CreateTable(model_).ddl()] From 5176f7983959609098a12356f409b99f74448c69 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Tue, 2 Feb 2021 19:20:08 -0500 Subject: [PATCH 068/131] Add ExecutePartitionedDml migrations update. This makes it possible to backfill data from migrations, e.g., changing NULL values to a non-NULL default before making a column `NOT NULL`. --- spanner_orm/__init__.py | 1 + spanner_orm/admin/api.py | 4 +++ spanner_orm/admin/update.py | 15 ++++++++++++ spanner_orm/tests/models.py | 8 ++++++ spanner_orm/tests/update_test.py | 42 +++++++++++++++++++++++++++++++- 5 files changed, 69 insertions(+), 1 deletion(-) diff --git a/spanner_orm/__init__.py b/spanner_orm/__init__.py index 6734757..19890b1 100644 --- a/spanner_orm/__init__.py +++ b/spanner_orm/__init__.py @@ -105,6 +105,7 @@ AlterColumn = update_module.AlterColumn CreateIndex = update_module.CreateIndex DropIndex = update_module.DropIndex +ExecutePartitionedDml = update_module.ExecutePartitionedDml model_creation_ddl = update_module.model_creation_ddl MigrationExecutor = migration_executor.MigrationExecutor diff --git a/spanner_orm/admin/api.py b/spanner_orm/admin/api.py index 8a28dfc..15ce1b4 100644 --- a/spanner_orm/admin/api.py +++ b/spanner_orm/admin/api.py @@ -44,6 +44,10 @@ def update_schema(self, change: str) -> None: operation = self._connection.update_ddl([change]) operation.result() + def execute_partitioned_dml(self, dml: str) -> None: + """See spanner_database.Database.execute_partitioned_dml().""" + self._connection.execute_partitioned_dml(dml) + _admin_api = None diff --git a/spanner_orm/admin/update.py b/spanner_orm/admin/update.py index 0f7803e..ec842da 100644 --- a/spanner_orm/admin/update.py +++ b/spanner_orm/admin/update.py @@ -352,6 +352,21 @@ def validate(self) -> None: self._index)) +class ExecutePartitionedDml(MigrationUpdate): + """Update for running arbitrary partitioned DML. + + See https://cloud.google.com/spanner/docs/dml-partitioned for more + information. + """ + + def __init__(self, dml: str): + self._dml = dml + + def execute(self) -> None: + """See base class.""" + api.spanner_admin_api().execute_partitioned_dml(self._dml) + + def model_creation_ddl(model_: Type[model.Model]) -> List[str]: """Returns the list of ddl statements needed to create the model's table.""" ddl_list = [CreateTable(model_).ddl()] diff --git a/spanner_orm/tests/models.py b/spanner_orm/tests/models.py index 1b1d4a2..9f05317 100644 --- a/spanner_orm/tests/models.py +++ b/spanner_orm/tests/models.py @@ -31,6 +31,14 @@ class SmallTestModel(model.Model): index_1 = index.Index(['value_1']) +class SmallTestModelWithoutSecondaryIndexes(model.Model): + """Same as SmallTestModel, but with no secondary indexes.""" + __table__ = 'SmallTestModel' + key = field.Field(field.String, primary_key=True) + value_1 = field.Field(field.String) + value_2 = field.Field(field.String, nullable=True) + + class ChildTestModel(model.Model): """Model class for testing interleaved tables.""" diff --git a/spanner_orm/tests/update_test.py b/spanner_orm/tests/update_test.py index 4bb4ac1..bc2d681 100644 --- a/spanner_orm/tests/update_test.py +++ b/spanner_orm/tests/update_test.py @@ -18,13 +18,30 @@ from absl.testing import parameterized +import spanner_orm from spanner_orm import error from spanner_orm import field from spanner_orm.admin import update +from spanner_orm.testlib.spanner_emulator import testlib as spanner_emulator_testlib from spanner_orm.tests import models -class UpdateTest(parameterized.TestCase): +class UpdateTest( + spanner_emulator_testlib.TestCase, + parameterized.TestCase, +): + + def setUp(self): + super().setUp() + _, project_id = self.spanner_emulator_client.project_name.split('/') + connection = spanner_orm.SpannerConnection( + instance=self.spanner_emulator_instance.instance_id, + database=self.spanner_emulator_database.database_id, + project=project_id, + credentials=self.spanner_emulator_client.credentials, + ) + spanner_orm.from_connection(connection) + spanner_orm.from_admin_connection(connection) @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') def test_add_column(self, get_model): @@ -175,6 +192,29 @@ def test_add_index(self, test_update, expected_ddl, get_model): test_update.validate() self.assertEqual(test_update.ddl(), expected_ddl) + def test_execute_partitioned_dml(self): + update.CreateTable(models.SmallTestModelWithoutSecondaryIndexes).execute() + test_model = models.SmallTestModel( + dict( + key='some-key', + value_1='foo', + value_2='bar', + )) + test_model.save() + update.ExecutePartitionedDml( + "UPDATE SmallTestModel SET value_2 = value_1 WHERE value_2 = 'bar'", + ).execute() + test_model.reload() + self.assertEqual( + models.SmallTestModel( + dict( + key='some-key', + value_1='foo', + value_2='foo', + )), + test_model, + ) + if __name__ == '__main__': logging.basicConfig() From e8291cd6f749d69c17ef327537dc0367598ab255 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Wed, 3 Feb 2021 15:54:18 -0500 Subject: [PATCH 069/131] Make it more prominent that partitioned DML should be idempotent. --- spanner_orm/admin/update.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/spanner_orm/admin/update.py b/spanner_orm/admin/update.py index ec842da..6000ba5 100644 --- a/spanner_orm/admin/update.py +++ b/spanner_orm/admin/update.py @@ -355,8 +355,9 @@ def validate(self) -> None: class ExecutePartitionedDml(MigrationUpdate): """Update for running arbitrary partitioned DML. - See https://cloud.google.com/spanner/docs/dml-partitioned for more - information. + NOTE: Partitioned DML queries should be idempotent. See + https://cloud.google.com/spanner/docs/dml-partitioned for details, and more + information about partitioned DML. """ def __init__(self, dml: str): From 0468211591fa1cc5e6781e05428d25b1b311685c Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Wed, 3 Feb 2021 18:23:34 -0500 Subject: [PATCH 070/131] Include seconds and timezone in migration timestamps. I think this could reduce confusion when there are multiple people writing migrations in different timezones. Tested: Manually created a migration, and it showed up as: 2021-02-03 18:23:19-05:00 --- spanner_orm/admin/migration_manager.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/spanner_orm/admin/migration_manager.py b/spanner_orm/admin/migration_manager.py index c7f9fde..d9da5c3 100644 --- a/spanner_orm/admin/migration_manager.py +++ b/spanner_orm/admin/migration_manager.py @@ -41,7 +41,8 @@ def generate(self, migration_name: str) -> str: """Creates a new migration that is the last migration to be executed.""" migration_id = uuid.uuid4().hex[-12:] prev_id = self.migrations[-1].migration_id if self.migrations else None - now = datetime.datetime.now().strftime('%Y-%m-%d %H:%M') + now = datetime.datetime.now().astimezone().isoformat( + sep=' ', timespec='seconds') skeleton_directory = os.path.dirname(os.path.abspath(__file__)) skeleton_file = os.path.join(skeleton_directory, 'migration.skel') From bf4c6209c7b34adfdb951c621dadfebc18a22107 Mon Sep 17 00:00:00 2001 From: Daniel Gorelik Date: Tue, 9 Feb 2021 11:22:59 -0500 Subject: [PATCH 071/131] Add FK relationship to __init__.py --- spanner_orm/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/spanner_orm/__init__.py b/spanner_orm/__init__.py index 19890b1..8ec8c35 100644 --- a/spanner_orm/__init__.py +++ b/spanner_orm/__init__.py @@ -20,6 +20,7 @@ from spanner_orm import decorator from spanner_orm import error from spanner_orm import field +from spanner_orm import foreign_key_relationship from spanner_orm import index from spanner_orm import model from spanner_orm import relationship @@ -61,6 +62,7 @@ Field = field.Field Integer = field.Integer Float = field.Float +ForeignKeyRelationship = foreign_key_relationship.ForeignKeyRelationship Index = index.Index Relationship = relationship.Relationship String = field.String From 178d248d2d5469b5086507a73c5270fce962c34d Mon Sep 17 00:00:00 2001 From: Daniel Gorelik Date: Tue, 9 Feb 2021 11:26:10 -0500 Subject: [PATCH 072/131] Update README to use FK relationship --- README.md | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index ee22ba4..5b397cf 100644 --- a/README.md +++ b/README.md @@ -50,11 +50,10 @@ class TestModel(spanner_orm.Model): value_index = spanner_orm.Index(['value']) # To indicate that there is a foreign key relationship from this table to - # another one, use a Relationship. This has no impact on the representation - # of the table inside Spanner - fake_relationship = spanner_orm.Relationship( + # another one, use a ForeignKeyRelationship. + foreign_key = spanner_orm.ForeignKeyRelationship( 'OtherModel', - {'value': 'other_model_column'}) + {'referencing_key': 'referenced_key'}) ``` If the model does not refer to an existing table on Spanner, we can create From d8e5c74fc25a2a467b44cc46b13e0a237334764e Mon Sep 17 00:00:00 2001 From: Daniel Gorelik Date: Wed, 10 Feb 2021 21:26:34 -0500 Subject: [PATCH 073/131] Bump Spanner emulator version @bpg130, I think you were having running tests because of an outdated Spanner emulator version here. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index ee22ba4..93fce8c 100644 --- a/README.md +++ b/README.md @@ -191,7 +191,7 @@ See https://github.com/GoogleCloudPlatform/cloud-spanner-emulator for several options. If you're on Linux, we recommend: ``` -VERSION=1.0.0 +VERSION=1.2.0 wget https://storage.googleapis.com/cloud-spanner-emulator/releases/${VERSION}/cloud-spanner-emulator_linux_amd64-${VERSION}.tar.gz tar zxvf cloud-spanner-emulator_linux_amd64-${VERSION}.tar.gz chmod u+x gateway_main emulator_main From 397c8b952ad258d2419c428f0cf9961b65bc41d2 Mon Sep 17 00:00:00 2001 From: Daniel Gorelik Date: Fri, 12 Feb 2021 11:54:03 -0500 Subject: [PATCH 074/131] Update maintainers to be a mailing list --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 1dcd488..3264329 100644 --- a/setup.py +++ b/setup.py @@ -18,8 +18,8 @@ name='spanner-orm', version='0.1.10', description='Basic ORM for Spanner', - maintainer='Derek Brandao', - maintainer_email='dbrandao@google.com', + maintainer='Python Spanner ORM developers', + maintainer_email='python-spanner-orm@google.com', url='https://github.com/google/python-spanner-orm', packages=['spanner_orm', 'spanner_orm.admin'], include_package_data=True, From f2b80fa88a6c3959117cfde147d38186cd73fc02 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Thu, 18 Feb 2021 13:21:12 -0500 Subject: [PATCH 075/131] Migrate from frozendict to immutabledict. See https://github.com/slezica/python-frozendict/issues/25. --- setup.py | 5 ++++- spanner_orm/condition.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 3264329..05e639c 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,10 @@ packages=['spanner_orm', 'spanner_orm.admin'], include_package_data=True, python_requires='~=3.7', - install_requires=['google-cloud-spanner >= 1.6, <2.0.0dev', 'frozendict'], + install_requires=[ + 'google-cloud-spanner >= 1.6, <2.0.0dev', + 'immutabledict', + ], tests_require=['absl-py', 'google-api-core', 'portpicker'], entry_points={ 'console_scripts': ['spanner-orm = spanner_orm.admin.scripts:main'] diff --git a/spanner_orm/condition.py b/spanner_orm/condition.py index 5694f70..fc0cede 100644 --- a/spanner_orm/condition.py +++ b/spanner_orm/condition.py @@ -29,9 +29,9 @@ from spanner_orm import index from spanner_orm import relationship -import frozendict from google.api_core import datetime_helpers from google.cloud.spanner_v1.proto import type_pb2 +import immutabledict T = TypeVar('T') @@ -254,7 +254,7 @@ class ArbitraryCondition(Condition): def __init__( self, sql_template: str, - substitutions: Mapping[str, Substitution] = frozendict.frozendict(), + substitutions: Mapping[str, Substitution] = immutabledict.immutabledict(), *, segment: Segment, ): From 04797ba7d0d1db914ff94dfdede7d49d84667cc5 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Mon, 22 Mar 2021 17:22:37 -0400 Subject: [PATCH 076/131] Make it possible to pass client_options to the Spanner client. --- spanner_orm/api.py | 28 +++++++++++++++++++--------- spanner_orm/tests/api_test.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 9 deletions(-) diff --git a/spanner_orm/api.py b/spanner_orm/api.py index 36d7fa4..4467eee 100644 --- a/spanner_orm/api.py +++ b/spanner_orm/api.py @@ -15,8 +15,9 @@ """Class that handles API calls to Spanner.""" import abc -from typing import Any, Callable, Iterable, Optional, TypeVar +from typing import Any, Callable, Dict, Iterable, Optional, TypeVar, Union +from google.api_core import client_options as api_client_options from google.api_core import exceptions from google.auth import credentials as auth_credentials from google.cloud import spanner @@ -105,13 +106,18 @@ def run_write(self, method: Callable[..., CallableReturn], *args: Any, class SpannerConnection: """Class that handles connecting to a Spanner database.""" - def __init__(self, - instance: str, - database: str, - project: Optional[str] = None, - credentials: Optional[auth_credentials.Credentials] = None, - pool: Optional[spanner_pool.AbstractSessionPool] = None, - create_ddl: Optional[Iterable[str]] = None): + def __init__( + self, + instance: str, + database: str, + project: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + pool: Optional[spanner_pool.AbstractSessionPool] = None, + create_ddl: Optional[Iterable[str]] = None, + *, + client_options: Union[api_client_options.ClientOptions, Dict[Any, Any], + None] = None, + ): """Connects to the specified Spanner database.""" self._instance = instance self._database = database @@ -119,12 +125,16 @@ def __init__(self, self._credentials = credentials self._pool = pool self._create_ddl = create_ddl + self._client_options = client_options self.connect() def connect(self): """Establish a new connection to the specified Spanner database.""" client = spanner.Client( - project=self._project, credentials=self._credentials) + project=self._project, + credentials=self._credentials, + client_options=self._client_options, + ) instance = client.instance(self._instance) self.database = instance.database( self._database, pool=self._pool, ddl_statements=self._create_ddl or ()) diff --git a/spanner_orm/tests/api_test.py b/spanner_orm/tests/api_test.py index a1abb25..7dab363 100644 --- a/spanner_orm/tests/api_test.py +++ b/spanner_orm/tests/api_test.py @@ -18,6 +18,8 @@ from absl.testing import parameterized from google.api_core import exceptions +from google.cloud import spanner + from spanner_orm import api from spanner_orm import error from spanner_orm.admin import api as admin_api @@ -40,6 +42,37 @@ def _connection(self): class ApiTest(parameterized.TestCase): + @mock.patch.object(spanner, 'Client', autospec=True, spec_set=True) + def test_connection_args(self, client): + client.return_value.instance.return_value.database.return_value = ( + 'fake-database') + connection = api.SpannerConnection( + instance='some-instance', + database='some-database', + project='some-project', + credentials='fake-credentials', + pool='fake-pool', + create_ddl=('fake-ddl',), + client_options=dict(fake='options'), + ) + self.assertEqual('fake-database', connection.database) + self.assertSequenceEqual( + ( + mock.call( + project='some-project', + credentials='fake-credentials', + client_options=dict(fake='options'), + ), + mock.call().instance('some-instance'), + mock.call().instance().database( + 'some-database', + pool='fake-pool', + ddl_statements=('fake-ddl',), + ), + ), + client.mock_calls, + ) + @mock.patch('google.cloud.spanner.Client') def test_api_connection(self, client): connection = self.mock_connection(client) From 360b59cc1e3ab1268a5f549a10ab91a8192b845c Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Tue, 23 Mar 2021 15:09:32 -0400 Subject: [PATCH 077/131] Deprecate connect() functions in favor of from_connection() They provide the same functionality, but from_connection() functions are more powerful since they let you store the connection information to use multiple times. It doesn't seem worth the overhead of maintaining both versions. --- README.md | 3 ++- spanner_orm/admin/api.py | 10 +++++++++- spanner_orm/api.py | 10 +++++++++- spanner_orm/tests/api_test.py | 18 ++++++++++++++++-- 4 files changed, 36 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index bdaed11..07284f2 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,8 @@ then run: To connect the Spanner ORM to an existing Spanner database: ``` python import spanner_orm -spanner_orm.connect(instance_name, database_name) +spanner_orm.from_connection( + spanner_orm.SpannerConnection(instance_name, database_name)) ``` `project` and `credentials` are optional parameters, and the standard Spanner diff --git a/spanner_orm/admin/api.py b/spanner_orm/admin/api.py index 15ce1b4..7d498a3 100644 --- a/spanner_orm/admin/api.py +++ b/spanner_orm/admin/api.py @@ -15,6 +15,8 @@ """Class that handles API calls to Spanner that deal with table metadata.""" from typing import Iterable, Optional +import warnings + from spanner_orm import api from spanner_orm import error @@ -58,7 +60,13 @@ def connect(instance: str, credentials: Optional[auth_credentials.Credentials] = None, pool: Optional[spanner_pool.AbstractSessionPool] = None, create_ddl: Optional[Iterable[str]] = None) -> SpannerAdminApi: - """Connects the global Spanner admin API to a Spanner database.""" + """Connects the global Spanner admin API to a Spanner database. + + Deprecated in favor of from_connection(). + """ + warnings.warn( + DeprecationWarning('Please use spanner_orm.from_admin_connection(' + 'spanner_orm.SpannerConnection(...))')) connection = api.SpannerConnection( instance, database, diff --git a/spanner_orm/api.py b/spanner_orm/api.py index 4467eee..52227a2 100644 --- a/spanner_orm/api.py +++ b/spanner_orm/api.py @@ -16,6 +16,7 @@ import abc from typing import Any, Callable, Dict, Iterable, Optional, TypeVar, Union +import warnings from google.api_core import client_options as api_client_options from google.api_core import exceptions @@ -160,7 +161,14 @@ def connect( project: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, pool: Optional[spanner_pool.AbstractSessionPool] = None) -> SpannerApi: - """Connects to the Spanner database and sets the global spanner_api.""" + """Connects to the Spanner database and sets the global spanner_api. + + Deprecated in favor of from_connection(). + """ + warnings.warn( + DeprecationWarning( + 'Please use ' + 'spanner_orm.from_connection(spanner_orm.SpannerConnection(...))')) connection = SpannerConnection( instance, database, project=project, credentials=credentials, pool=pool) return from_connection(connection) diff --git a/spanner_orm/tests/api_test.py b/spanner_orm/tests/api_test.py index 7dab363..22533ca 100644 --- a/spanner_orm/tests/api_test.py +++ b/spanner_orm/tests/api_test.py @@ -15,6 +15,7 @@ import logging import unittest from unittest import mock +import warnings from absl.testing import parameterized from google.api_core import exceptions @@ -76,8 +77,14 @@ def test_connection_args(self, client): @mock.patch('google.cloud.spanner.Client') def test_api_connection(self, client): connection = self.mock_connection(client) - api.connect('', '', '') + with warnings.catch_warnings(record=True) as connect_warnings: + api.connect('', '', '') self.assertEqual(api.spanner_api()._connection, connection) + self.assertLen(connect_warnings, 1) + connect_warning, = connect_warnings + self.assertIn('spanner_orm.from_connection', str(connect_warning.message)) + self.assertIs(DeprecationWarning, connect_warning.category) + self.assertEquals(api.__file__, connect_warning.filename) api.hangup() with self.assertRaises(error.SpannerError): @@ -90,8 +97,15 @@ def test_api_error_when_not_connected(self): @mock.patch('google.cloud.spanner.Client') def test_admin_api_connection(self, client): connection = self.mock_connection(client) - admin_api.connect('', '', '') + with warnings.catch_warnings(record=True) as connect_warnings: + admin_api.connect('', '', '') self.assertEqual(admin_api.spanner_admin_api()._connection, connection) + self.assertLen(connect_warnings, 1) + connect_warning, = connect_warnings + self.assertIn('spanner_orm.from_admin_connection', + str(connect_warning.message)) + self.assertIs(DeprecationWarning, connect_warning.category) + self.assertEquals(admin_api.__file__, connect_warning.filename) admin_api.hangup() with self.assertRaises(error.SpannerError): From e706b94aa0407b6be08b222936bab036f63d7f59 Mon Sep 17 00:00:00 2001 From: Gavin Duggan Date: Mon, 29 Mar 2021 13:11:49 -0700 Subject: [PATCH 078/131] Add context to field validation errors --- spanner_orm/model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/spanner_orm/model.py b/spanner_orm/model.py index e4f2d48..4147f21 100644 --- a/spanner_orm/model.py +++ b/spanner_orm/model.py @@ -140,7 +140,9 @@ def validate_value(cls, field_name, value, error_type=error.SpannerError): try: cls.fields[field_name].validate(value) except error.ValidationError as ex: - raise error_type(*ex.args) + context = f'Validation error for field {field_name!r}' + raise error_type((f'{context}: {ex.args[0]}' if ex.args else context), + *ex.args[1:]) CallableReturn = TypeVar('CallableReturn') From 5a85a320d10488751659909aaf5f5110fefeae1c Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Tue, 31 Aug 2021 18:38:14 -0400 Subject: [PATCH 079/131] Run yapf --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 05e639c..386a1b3 100644 --- a/setup.py +++ b/setup.py @@ -14,6 +14,7 @@ # limitations under the License. """spanner_orm setup file.""" from setuptools import setup + setup( name='spanner-orm', version='0.1.10', From 2f23334c8a765ba7131fdd3fb99ff5ec6b44ecef Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Tue, 31 Aug 2021 18:58:30 -0400 Subject: [PATCH 080/131] Fix type annotations --- spanner_orm/condition.py | 4 ++-- spanner_orm/tests/model_test.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/spanner_orm/condition.py b/spanner_orm/condition.py index fc0cede..0dfb518 100644 --- a/spanner_orm/condition.py +++ b/spanner_orm/condition.py @@ -412,7 +412,7 @@ def __init__( relation_or_name: Union[relationship.Relationship, foreign_key_relationship.ForeignKeyRelationship, str], - conditions: List[Condition] = None, + conditions: Optional[List[Condition]] = None, # Default argument is `False` for backwards-compatability. foreign_key_relation=False, ): @@ -921,7 +921,7 @@ def greater_than_or_equal_to(column: Union[field.Field, str], def includes(relation: Union[relationship.Relationship, foreign_key_relationship.ForeignKeyRelationship, str], - conditions: List[Condition] = None, + conditions: Optional[List[Condition]] = None, foreign_key_relation: bool = False) -> IncludesCondition: """Condition where the objects associated with a relationship are retrieved. diff --git a/spanner_orm/tests/model_test.py b/spanner_orm/tests/model_test.py index b521b34..17889cd 100644 --- a/spanner_orm/tests/model_test.py +++ b/spanner_orm/tests/model_test.py @@ -15,6 +15,7 @@ import datetime import logging import os +import typing from typing import List import unittest from unittest import mock @@ -352,7 +353,7 @@ def test_object_changes(self): }) # Make sure that changing an object on the model shows up in changes() - string_array = test_model.string_array # type: List + string_array = typing.cast(List[str], test_model.string_array) string_array.append('bat') self.assertIn('string_array', test_model.changes()) From 1e2aa6b87d0772a7c1acb87d207e5d0bb5f49fed Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Tue, 31 Aug 2021 19:33:59 -0400 Subject: [PATCH 081/131] Fix pytype import-error, and simplify pytype args Pytype should probably run the same version as the rest of the python environment. --- README.md | 3 +-- spanner_orm/testlib/spanner_emulator/__init__.py | 0 2 files changed, 1 insertion(+), 2 deletions(-) create mode 100644 spanner_orm/testlib/spanner_emulator/__init__.py diff --git a/README.md b/README.md index 07284f2..021ad8a 100644 --- a/README.md +++ b/README.md @@ -205,8 +205,7 @@ To check type annotations, run: ``` pip install pytype -# https://github.com/google/pytype/issues/80#issuecomment-385128856 -pytype -V 3.7 spanner_orm -d import-error +pytype spanner_orm ``` To check formatting, run (change `--diff` to `--in-place` to fix formatting): diff --git a/spanner_orm/testlib/spanner_emulator/__init__.py b/spanner_orm/testlib/spanner_emulator/__init__.py new file mode 100644 index 0000000..e69de29 From 8d18c0466733a1685f096a2a15b72a0fd2fc62d1 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Tue, 31 Aug 2021 19:21:58 -0400 Subject: [PATCH 082/131] Fix access to nonexistent attribute I'm mostly guessing based on context what this code meant to do. As far as I can tell, the attribute never existed. --- spanner_orm/api.py | 7 ++++++- spanner_orm/tests/api_test.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/spanner_orm/api.py b/spanner_orm/api.py index 52227a2..95f221d 100644 --- a/spanner_orm/api.py +++ b/spanner_orm/api.py @@ -40,7 +40,7 @@ def _ensure_session(self, api_method, *args, **kwargs): if not 'Session not found' in e.message: raise - spanner_api().connect() + spanner_api().spanner_connection.connect() return api_method(*args, **kwargs) @@ -147,6 +147,11 @@ class SpannerApi(SpannerReadApi, SpannerWriteApi): def __init__(self, connection: SpannerConnection): self._spanner_connection = connection + @property + def spanner_connection(self) -> SpannerConnection: + """Connection to the database.""" + return self._spanner_connection + @property def _connection(self): return self._spanner_connection.database diff --git a/spanner_orm/tests/api_test.py b/spanner_orm/tests/api_test.py index 22533ca..b129a19 100644 --- a/spanner_orm/tests/api_test.py +++ b/spanner_orm/tests/api_test.py @@ -127,7 +127,7 @@ def test_reconnect_on_expected_error(self, api_method, mock_spanner_api): exceptions.NotFound('Session not found'), 'Anything other than an exception' ] - mock_connect = mock_spanner_api.return_value.connect + mock_connect = mock_spanner_api.return_value.spanner_connection.connect getattr(mock_api, api_method)(mock_method) From d5147280c7ddcf00bdf95f1e46831bb4e2a2e634 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Tue, 31 Aug 2021 18:35:30 -0400 Subject: [PATCH 083/131] Set up GitHub Actions --- .github/workflows/test.yaml | 59 +++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 .github/workflows/test.yaml diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml new file mode 100644 index 0000000..b63c498 --- /dev/null +++ b/.github/workflows/test.yaml @@ -0,0 +1,59 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +on: + push: {} + +jobs: + test: + strategy: + matrix: + python-version: + - '3.7' + - '3.8' + - '3.9' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: Install cloud-spanner-emulator + run: | + # https://github.com/GoogleCloudPlatform/cloud-spanner-emulator#via-pre-built-linux-binaries + VERSION=1.2.0 + wget https://storage.googleapis.com/cloud-spanner-emulator/releases/${VERSION}/cloud-spanner-emulator_linux_amd64-${VERSION}.tar.gz + tar zxvf cloud-spanner-emulator_linux_amd64-${VERSION}.tar.gz + chmod u+x gateway_main emulator_main + - uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install python dependencies + run: | + pip install \ + absl-py \ + google-api-core \ + 'google-cloud-spanner >= 1.6, <2.0.0dev' \ + immutabledict \ + portpicker + - name: Check formatting + run: | + pip install yapf + yapf --diff --recursive --parallel . + - name: Check types + run: | + pip install pytype + pytype --jobs=auto --keep-going spanner_orm + - name: Test + env: + SPANNER_EMULATOR_BINARY_PATH: ${{ github.workspace }}/emulator_main + run: | + python setup.py test From 62cc0df40ab084f4cc0d23828abc414a363f4427 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Wed, 1 Sep 2021 14:58:25 -0400 Subject: [PATCH 084/131] Run the test workflow periodically --- .github/workflows/test.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index b63c498..77efffb 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -14,6 +14,8 @@ on: push: {} + schedule: + - cron: '50 13 * * *' jobs: test: From 333acf3edc4e62f64caa2dbe90814c2540922fec Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Wed, 1 Sep 2021 16:49:13 -0400 Subject: [PATCH 085/131] Remove calls to the deprecated alias assertEquals() --- spanner_orm/tests/api_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spanner_orm/tests/api_test.py b/spanner_orm/tests/api_test.py index b129a19..bbcb296 100644 --- a/spanner_orm/tests/api_test.py +++ b/spanner_orm/tests/api_test.py @@ -84,7 +84,7 @@ def test_api_connection(self, client): connect_warning, = connect_warnings self.assertIn('spanner_orm.from_connection', str(connect_warning.message)) self.assertIs(DeprecationWarning, connect_warning.category) - self.assertEquals(api.__file__, connect_warning.filename) + self.assertEqual(api.__file__, connect_warning.filename) api.hangup() with self.assertRaises(error.SpannerError): @@ -105,7 +105,7 @@ def test_admin_api_connection(self, client): self.assertIn('spanner_orm.from_admin_connection', str(connect_warning.message)) self.assertIs(DeprecationWarning, connect_warning.category) - self.assertEquals(admin_api.__file__, connect_warning.filename) + self.assertEqual(admin_api.__file__, connect_warning.filename) admin_api.hangup() with self.assertRaises(error.SpannerError): From 8b14108899172b7b4f9f2d174b14615d979b07d0 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Fri, 3 Sep 2021 15:00:39 -0400 Subject: [PATCH 086/131] Parse ABSL flags when running tests It looks like `python setup.py test`, `python -m unittest`, and `pytest` all import test modules rather than running them as __main__, so the usual parsing with absltest.main() doesn't work. This pytest fixture seems to work though. --- .github/workflows/test.yaml | 5 +++-- README.md | 2 +- spanner_orm/tests/conftest.py | 26 ++++++++++++++++++++++++++ 3 files changed, 30 insertions(+), 3 deletions(-) create mode 100644 spanner_orm/tests/conftest.py diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 77efffb..a531d84 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -45,7 +45,8 @@ jobs: google-api-core \ 'google-cloud-spanner >= 1.6, <2.0.0dev' \ immutabledict \ - portpicker + portpicker \ + pytest - name: Check formatting run: | pip install yapf @@ -58,4 +59,4 @@ jobs: env: SPANNER_EMULATOR_BINARY_PATH: ${{ github.workspace }}/emulator_main run: | - python setup.py test + pytest diff --git a/README.md b/README.md index 021ad8a..467cb37 100644 --- a/README.md +++ b/README.md @@ -218,5 +218,5 @@ yapf --diff --recursive --parallel . Then run tests with: ``` -SPANNER_EMULATOR_BINARY_PATH=$(pwd)/emulator_main python3 setup.py test +SPANNER_EMULATOR_BINARY_PATH=$(pwd)/emulator_main pytest ``` diff --git a/spanner_orm/tests/conftest.py b/spanner_orm/tests/conftest.py new file mode 100644 index 0000000..8b8ed63 --- /dev/null +++ b/spanner_orm/tests/conftest.py @@ -0,0 +1,26 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pytest fixture.""" + +import sys + +from absl import flags +import pytest + + +@pytest.fixture(scope='session', autouse=True) +def parse_flags(): + # Only pass the first item, because pytest flags shouldn't be parsed as absl + # flags. + flags.FLAGS(sys.argv[:1]) From 627b7cc2bd994e65b2c756890b382f107a1874f7 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Fri, 3 Sep 2021 17:33:43 -0400 Subject: [PATCH 087/131] Connect the admin API I'm planning to use the admin API in a test that uses the emulator. --- spanner_orm/testlib/spanner_emulator/testlib.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/spanner_orm/testlib/spanner_emulator/testlib.py b/spanner_orm/testlib/spanner_emulator/testlib.py index 1203ba1..f59cab8 100644 --- a/spanner_orm/testlib/spanner_emulator/testlib.py +++ b/spanner_orm/testlib/spanner_emulator/testlib.py @@ -116,14 +116,11 @@ def tearDownClass(cls): def run_orm_migrations(self, migrations_folder: str) -> None: """Runs ORM migrations in the given directory and connects the ORM.""" - _migrate_database_at_connection( - _make_emulator_spanner_orm_connection(self.spanner_emulator_database, - self.spanner_emulator_instance, - self.spanner_emulator_client), - migrations_folder) + connection = _make_emulator_spanner_orm_connection( + self.spanner_emulator_database, self.spanner_emulator_instance, + self.spanner_emulator_client) + _migrate_database_at_connection(connection, migrations_folder) # spanner_orm closes the connection to Spanner after migrating so we need to # reconnect before making other Spanner calls. - spanner_orm.from_connection( - _make_emulator_spanner_orm_connection(self.spanner_emulator_database, - self.spanner_emulator_instance, - self.spanner_emulator_client)) + spanner_orm.from_connection(connection) + spanner_orm.from_admin_connection(connection) From 659d6fdee65170c66ce914bd2fe3446acb907054 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Fri, 3 Sep 2021 18:03:52 -0400 Subject: [PATCH 088/131] Don't check for interleaved indexes in DropTable https://cloud.google.com/spanner/docs/data-definition-language#create-index-interleave > If `T` is the table into which the index is interleaved, then: > `T` must be a parent of the table being indexed If I'm understanding that correctly, any table with an interleaved index will also have an interleaved table too. We already check for interleaved tables. --- spanner_orm/admin/update.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/spanner_orm/admin/update.py b/spanner_orm/admin/update.py index 6000ba5..08b2b2f 100644 --- a/spanner_orm/admin/update.py +++ b/spanner_orm/admin/update.py @@ -155,10 +155,6 @@ def _validate_not_interleaved(self, if model_.interleaved == existing_model: raise error.SpannerError('Table {} has interleaved table {}'.format( self._table, model_.table)) - for index_ in model_.indexes.values(): - if index_.parent == self._table: - raise error.SpannerError('Table {} has interleaved index {}'.format( - self._table, index_.name)) class AddColumn(SchemaUpdate): From 9524002ad19fcc8a012263e61860adcc94f8bae2 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Fri, 3 Sep 2021 18:00:31 -0400 Subject: [PATCH 089/131] Rely on Cloud Spanner to validate schema changes The model registry seems to be broken during migrations, because it assumes each model name will have only one corresponding class, but each migration that affects a model can create a new class. This commit fixes a bug where if there are any interleaved tables, accessing model_.interleaved while validating DropTable raises an exception because of the duplicate names. Instead of just fixing the validation code, I think it makes more sense to remove it entirely, since Cloud Spanner itself will raise exceptions about the same issues. --- spanner_orm/admin/update.py | 22 +-- spanner_orm/tests/migrations_emulator_test.py | 165 +++++++++++++++++- 2 files changed, 165 insertions(+), 22 deletions(-) diff --git a/spanner_orm/admin/update.py b/spanner_orm/admin/update.py index 08b2b2f..444e15d 100644 --- a/spanner_orm/admin/update.py +++ b/spanner_orm/admin/update.py @@ -55,9 +55,8 @@ def execute(self) -> None: self.validate() api.spanner_admin_api().update_schema(self.ddl()) - @abc.abstractmethod def validate(self) -> None: - raise NotImplementedError + pass # TODO(dseomn): Remove this method. class CreateTable(SchemaUpdate): @@ -137,25 +136,6 @@ def __init__(self, table_name: str): def ddl(self) -> str: return 'DROP TABLE {}'.format(self._table) - def validate(self) -> None: - existing_model = metadata.SpannerMetadata.model(self._table) - if not existing_model: - raise error.SpannerError('Table {} does not exist'.format(self._table)) - - # Model indexes include the primary index - if len(existing_model.indexes) > 1: - raise error.SpannerError('Table {} has a secondary index'.format( - self._table)) - - self._validate_not_interleaved(existing_model) - - def _validate_not_interleaved(self, - existing_model: Type[model.Model]) -> None: - for model_ in metadata.SpannerMetadata.models().values(): - if model_.interleaved == existing_model: - raise error.SpannerError('Table {} has interleaved table {}'.format( - self._table, model_.table)) - class AddColumn(SchemaUpdate): """Update for adding a column to an existing table. diff --git a/spanner_orm/tests/migrations_emulator_test.py b/spanner_orm/tests/migrations_emulator_test.py index de26917..1b1140f 100644 --- a/spanner_orm/tests/migrations_emulator_test.py +++ b/spanner_orm/tests/migrations_emulator_test.py @@ -15,9 +15,14 @@ import datetime import logging import os +import textwrap +from typing import Iterable, Type import unittest +from absl.testing import absltest +from absl.testing import parameterized import spanner_orm +from spanner_orm.admin import metadata from spanner_orm.tests import models from spanner_orm.testlib.spanner_emulator import testlib as spanner_emulator_testlib @@ -25,6 +30,8 @@ class MigrationsEmulatorTest(spanner_emulator_testlib.TestCase): + """Basic tests using generic migrations.""" + TEST_MIGRATIONS_DIR = os.path.join( os.path.dirname(os.path.abspath(__file__)), 'migrations_for_emulator_test', @@ -73,6 +80,162 @@ def test_key(self): }).save() +class SpecificMigrationsEmulatorTest( + parameterized.TestCase, + spanner_emulator_testlib.TestCase, +): + """Tests of specific migrations.""" + + def setUp(self): + super().setUp() + self._migrations_dir = self.create_tempdir() + self._migration_index = None + + def _append_migrations(self, *migrations: str) -> None: + """Appends migrations to the sequence of migrations in self._migrations_dir. + + Args: + *migrations: Each string is the python code to define a single upgrade() + function. Leading indentation is stripped and migration boilerplate is + added. + """ + for migration in migrations: + if self._migration_index is None: + prev_migration_id = None + self._migration_index = 0 + else: + prev_migration_id = str(self._migration_index) + self._migration_index += 1 + migration_id = str(self._migration_index) + self._migrations_dir.create_file( + f'migration_{migration_id}.py', + '\n'.join(( + 'import spanner_orm', + f'migration_id = {migration_id!r}', + f'prev_migration_id = {prev_migration_id!r}', + textwrap.dedent(migration), + 'def downgrade(): raise NotImplementedError()', + )), + ) + + def test_drop_interleaved_table(self): + self._append_migrations( + """ + class _Parent(spanner_orm.Model): + __table__ = 'Parent' + parent_key = spanner_orm.Field( + spanner_orm.String, primary_key=True) + + def upgrade(): + return spanner_orm.CreateTable(_Parent) + """, + """ + class _Parent(spanner_orm.Model): + __table__ = 'Parent' + parent_key = spanner_orm.Field( + spanner_orm.String, primary_key=True) + + class _Child(spanner_orm.Model): + __table__ = 'Child' + __interleaved__ = _Parent + parent_key = spanner_orm.Field( + spanner_orm.String, primary_key=True) + child_key = spanner_orm.Field( + spanner_orm.String, primary_key=True) + + def upgrade(): + return spanner_orm.CreateTable(_Child) + """, + """ + def upgrade(): + return spanner_orm.DropTable('Child') + """, + ) + self.run_orm_migrations(self._migrations_dir) + self.assertCountEqual( + ('Parent',), + metadata.SpannerMetadata.tables().keys() - {'spanner_orm_migrations'}, + ) + + @parameterized.named_parameters( + dict( + testcase_name='does_not_exist', + create_migrations=(), + error_class=google_api_exceptions.NotFound, + ), + dict( + testcase_name='has_secondary_index', + create_migrations=( + """ + class _TableToDrop(spanner_orm.Model): + __table__ = 'TableToDrop' + key = spanner_orm.Field( + spanner_orm.String, primary_key=True) + value = spanner_orm.Field(spanner_orm.String) + + def upgrade(): + return spanner_orm.CreateTable(_TableToDrop) + """, + """ + def upgrade(): + return spanner_orm.CreateIndex( + table_name='TableToDrop', + index_name='value_index', + columns=['value'], + ) + """, + ), + error_class=google_api_exceptions.FailedPrecondition, + ), + dict( + testcase_name='has_interleaved_child', + create_migrations=( + """ + class _TableToDrop(spanner_orm.Model): + __table__ = 'TableToDrop' + parent_key = spanner_orm.Field( + spanner_orm.String, primary_key=True) + + def upgrade(): + return spanner_orm.CreateTable(_TableToDrop) + """, + """ + class _TableToDrop(spanner_orm.Model): + __table__ = 'TableToDrop' + parent_key = spanner_orm.Field( + spanner_orm.String, primary_key=True) + + class _Child(spanner_orm.Model): + __table__ = 'Child' + __interleaved__ = _TableToDrop + parent_key = spanner_orm.Field( + spanner_orm.String, primary_key=True) + child_key = spanner_orm.Field( + spanner_orm.String, primary_key=True) + + def upgrade(): + return spanner_orm.CreateTable(_Child) + """, + ), + error_class=google_api_exceptions.FailedPrecondition, + ), + ) + def test_drop_table_error( + self, + *, + create_migrations: Iterable[str], + error_class: Type[Exception], + ): + self._append_migrations(*create_migrations) + self.run_orm_migrations(self._migrations_dir) + self._append_migrations(""" + def upgrade(): + return spanner_orm.DropTable('TableToDrop') + """) + with self.assertRaises(error_class): + self.run_orm_migrations(self._migrations_dir) + + if __name__ == '__main__': logging.basicConfig() - unittest.main() + absltest.main() From 5b6affffef5c476281b7b12f0211c4949c7665a4 Mon Sep 17 00:00:00 2001 From: Daniel Gorelik Date: Mon, 27 Sep 2021 09:43:45 -0400 Subject: [PATCH 090/131] Add status badge --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 467cb37..a9afd7c 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,5 @@ +[![.github/workflows/test.yaml](https://github.com/google/python-spanner-orm/actions/workflows/test.yaml/badge.svg)](https://github.com/google/python-spanner-orm/actions/workflows/test.yaml) + # Google Cloud Spanner ORM This is a lightweight ORM written in Python and built on top of Cloud Spanner. From e0c0eeb73a5c149ebc4f5172beec87344108f234 Mon Sep 17 00:00:00 2001 From: Daniel Gorelik Date: Fri, 1 Oct 2021 13:16:06 -0400 Subject: [PATCH 091/131] Add BytesBase64 support --- spanner_orm/__init__.py | 1 + spanner_orm/field.py | 34 ++++++++++++++++--- spanner_orm/tests/migrations_emulator_test.py | 1 + .../create_unittest_model.py | 2 ++ spanner_orm/tests/model_test.py | 20 ++++++++--- spanner_orm/tests/models.py | 4 +++ spanner_orm/tests/update_test.py | 5 ++- 7 files changed, 56 insertions(+), 11 deletions(-) diff --git a/spanner_orm/__init__.py b/spanner_orm/__init__.py index 8ec8c35..7678314 100644 --- a/spanner_orm/__init__.py +++ b/spanner_orm/__init__.py @@ -68,6 +68,7 @@ String = field.String StringArray = field.StringArray Timestamp = field.Timestamp +BytesBase64 = field.BytesBase64 ArbitraryCondition = condition.ArbitraryCondition Column = condition.Column diff --git a/spanner_orm/field.py b/spanner_orm/field.py index b8bc6b3..837bce3 100644 --- a/spanner_orm/field.py +++ b/spanner_orm/field.py @@ -15,12 +15,13 @@ """Helper to deal with field types in Spanner interactions.""" import abc +import base64 +import binascii import datetime from typing import Any, Type -from spanner_orm import error - from google.cloud.spanner_v1.proto import type_pb2 +from spanner_orm import error class FieldType(abc.ABC): @@ -118,7 +119,7 @@ class Float(FieldType): @staticmethod def ddl() -> str: - return "FLOAT64" + return 'FLOAT64' @staticmethod def grpc_type() -> type_pb2.Type: @@ -127,7 +128,7 @@ def grpc_type() -> type_pb2.Type: @staticmethod def validate_type(value: Any) -> None: if not isinstance(value, (int, float)): - raise error.ValidationError("{} is not of type float".format(value)) + raise error.ValidationError('{} is not of type float'.format(value)) class String(FieldType): @@ -184,4 +185,27 @@ def validate_type(value: Any) -> None: raise error.ValidationError('{} is not of type datetime'.format(value)) -ALL_TYPES = [Boolean, Integer, Float, String, StringArray, Timestamp] +class BytesBase64(FieldType): + """Represents a bytes type that must be base64 encoded.""" + + @staticmethod + def ddl() -> str: + return 'BYTES(MAX)' + + @staticmethod + def grpc_type() -> type_pb2.Type: + return type_pb2.Type(code=type_pb2.BYTES) + + @staticmethod + def validate_type(value) -> None: + if not isinstance(value, bytes): + raise error.ValidationError('{} is not of type bytes'.format(value)) + # Rudimentary test to check for base64 encoding. + try: + base64.b64decode(value, altchars=None, validate=True) + except binascii.Error: + raise error.ValidationError( + '{} must be base64-encoded bytes.'.format(value)) + +ALL_TYPES = [Boolean, Integer, Float, String, StringArray, Timestamp, + BytesBase64] diff --git a/spanner_orm/tests/migrations_emulator_test.py b/spanner_orm/tests/migrations_emulator_test.py index 1b1140f..578a67c 100644 --- a/spanner_orm/tests/migrations_emulator_test.py +++ b/spanner_orm/tests/migrations_emulator_test.py @@ -70,6 +70,7 @@ def test_key(self): 'string': 'string', 'int_': 42, 'float_': 4.2, + 'bytes_': b'A1A1', 'timestamp': datetime.datetime.now(tz=datetime.timezone.utc), }).save() models.ForeignKeyTestModel({ diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py b/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py index 8be2c08..db8387b 100644 --- a/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py +++ b/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py @@ -35,6 +35,8 @@ class OriginalUnittestModelTable(spanner_orm.model.Model): float_2 = field.Field(field.Float, nullable=True) string = field.Field(field.String, primary_key=True) string_2 = field.Field(field.String, nullable=True) + bytes_ = field.Field(field.BytesBase64, primary_key=True) + bytes_2 = field.Field(field.BytesBase64, nullable=True) timestamp = field.Field(field.Timestamp) string_array = field.Field(field.StringArray, nullable=True) diff --git a/spanner_orm/tests/model_test.py b/spanner_orm/tests/model_test.py index 17889cd..4a025da 100644 --- a/spanner_orm/tests/model_test.py +++ b/spanner_orm/tests/model_test.py @@ -51,6 +51,7 @@ def test_find_calls_api(self, find): string='string', int_=1, float_=2.3, + bytes_=b'A1A1', transaction=mock_transaction, ) @@ -59,7 +60,7 @@ def test_find_calls_api(self, find): self.assertEqual(transaction, mock_transaction) self.assertEqual(table, models.UnittestModel.table) self.assertEqual(columns, models.UnittestModel.columns) - self.assertEqual(keyset.keys, [[1, 2.3, 'string']]) + self.assertEqual(keyset.keys, [[1, 2.3, 'string', b'A1A1']]) @mock.patch('spanner_orm.table_apis.find') def test_find_result(self, find): @@ -98,6 +99,7 @@ def test_find_multi_calls_api(self, find): models.UnittestModel.find_multi( [{ 'string': 'string', + 'bytes_': b'bytes', 'int_': 1, 'float_': 2.3 }], @@ -109,7 +111,7 @@ def test_find_multi_calls_api(self, find): self.assertEqual(transaction, mock_transaction) self.assertEqual(table, models.UnittestModel.table) self.assertEqual(columns, models.UnittestModel.columns) - self.assertEqual(keyset.keys, [[1, 2.3, 'string']]) + self.assertEqual(keyset.keys, [[1, 2.3, 'string', b'bytes']]) @mock.patch('spanner_orm.table_apis.find') def test_find_multi_result(self, find): @@ -226,8 +228,8 @@ def test_set_error_on_primary_key(self): test_model.key = 'error' @parameterized.parameters(('int_2', 'foo'), ('float_2', 'bar'), - ('string_2', 5), ('string_array', 'foo'), - ('timestamp', 5)) + ('string_2', 5), ('bytes_2', 'string'), + ('string_array', 'foo'), ('timestamp', 5)) def test_set_error_on_invalid_type(self, attribute, value): string_array = ['foo', 'bar'] timestamp = datetime.datetime.now(tz=datetime.timezone.utc) @@ -235,6 +237,7 @@ def test_set_error_on_invalid_type(self, attribute, value): 'int_': 0, 'float_': 0, 'string': '', + 'bytes_': b'', 'string_array': string_array, 'timestamp': timestamp }) @@ -273,6 +276,7 @@ def test_model_equates(self): 'int_': 0, 'float_': 0, 'string': '', + 'bytes_': b'', 'string_array': ['foo', 'bar'], 'timestamp': timestamp, }) @@ -280,6 +284,7 @@ def test_model_equates(self): 'int_': 0, 'float_': 0.0, 'string': '', + 'bytes_': b'', 'string_array': ['foo', 'bar'], 'timestamp': timestamp, }) @@ -290,18 +295,21 @@ def test_model_equates(self): 'int_': 0, 'float_': 0, 'string': '1', + 'bytes_': b'1111', 'timestamp': _TIMESTAMP, }), models.UnittestModel({ 'int_': 0, 'float_': 0, 'string': 'a', + 'bytes_': b'A1A1', 'timestamp': _TIMESTAMP, })), (models.UnittestModel({ 'int_': 0, 'float_': 0, 'string': '', + 'bytes_': b'A1A1', 'string_array': ['foo', 'bar'], 'timestamp': _TIMESTAMP, }), @@ -309,6 +317,7 @@ def test_model_equates(self): 'int_': 0, 'float_': 0, 'string': '', + 'bytes_': b'A1A1', 'string_array': ['bar', 'foo'], 'timestamp': _TIMESTAMP, })), @@ -324,7 +333,7 @@ def test_model_are_different(self, test_model1, test_model2): self.assertNotEqual(test_model1, test_model2) def test_id(self): - primary_key = {'string': 'foo', 'int_': 5, 'float_': 2.3} + primary_key = {'string': 'foo', 'int_': 5, 'float_': 2.3, 'bytes_': b'A1A1'} all_data = primary_key.copy() all_data.update({ 'timestamp': datetime.datetime.now(tz=datetime.timezone.utc), @@ -348,6 +357,7 @@ def test_object_changes(self): 'int_': 0, 'float_': 0, 'string': '', + 'bytes_': b'', 'string_array': array, 'timestamp': timestamp }) diff --git a/spanner_orm/tests/models.py b/spanner_orm/tests/models.py index 9f05317..d7a3db2 100644 --- a/spanner_orm/tests/models.py +++ b/spanner_orm/tests/models.py @@ -108,6 +108,8 @@ class UnittestModel(model.Model): float_2 = field.Field(field.Float, nullable=True) string = field.Field(field.String, primary_key=True) string_2 = field.Field(field.String, nullable=True) + bytes_ = field.Field(field.BytesBase64, primary_key=True) + bytes_2 = field.Field(field.BytesBase64, nullable=True) timestamp = field.Field(field.Timestamp) string_array = field.Field(field.StringArray, nullable=True) @@ -124,5 +126,7 @@ class UnittestModelWithoutSecondaryIndexes(model.Model): float_2 = field.Field(field.Float, nullable=True) string = field.Field(field.String, primary_key=True) string_2 = field.Field(field.String, nullable=True) + bytes_ = field.Field(field.BytesBase64, primary_key=True) + bytes_2 = field.Field(field.BytesBase64, nullable=True) timestamp = field.Field(field.Timestamp) string_array = field.Field(field.StringArray, nullable=True) diff --git a/spanner_orm/tests/update_test.py b/spanner_orm/tests/update_test.py index bc2d681..b906ffd 100644 --- a/spanner_orm/tests/update_test.py +++ b/spanner_orm/tests/update_test.py @@ -1,3 +1,4 @@ + # python3 # Copyright 2019 Google LLC # @@ -87,8 +88,10 @@ def test_create_table(self, get_model): test_model_ddl = ('CREATE TABLE table (int_ INT64 NOT NULL, int_2 INT64,' ' float_ FLOAT64 NOT NULL, float_2 FLOAT64,' ' string STRING(MAX) NOT NULL, string_2 STRING(MAX),' + ' bytes_ BYTES(MAX) NOT NULL, bytes_2 BYTES(MAX),' ' timestamp TIMESTAMP NOT NULL, string_array' - ' ARRAY) PRIMARY KEY (int_, float_, string)') + ' ARRAY) PRIMARY KEY ' + '(int_, float_, string, bytes_)') self.assertEqual(test_update.ddl(), test_model_ddl) @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') From 2d422bda6b9d98804b2242a5860560b23d2797b0 Mon Sep 17 00:00:00 2001 From: Daniel Gorelik Date: Fri, 1 Oct 2021 13:26:03 -0400 Subject: [PATCH 092/131] yapf fixes --- spanner_orm/field.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/spanner_orm/field.py b/spanner_orm/field.py index 837bce3..d5d864b 100644 --- a/spanner_orm/field.py +++ b/spanner_orm/field.py @@ -207,5 +207,7 @@ def validate_type(value) -> None: raise error.ValidationError( '{} must be base64-encoded bytes.'.format(value)) -ALL_TYPES = [Boolean, Integer, Float, String, StringArray, Timestamp, - BytesBase64] + +ALL_TYPES = [ + Boolean, Integer, Float, String, StringArray, Timestamp, BytesBase64 +] From 839a4d11522cfd8b6fef82b2c4a3b42752c84b38 Mon Sep 17 00:00:00 2001 From: Daniel Gorelik Date: Fri, 1 Oct 2021 13:27:59 -0400 Subject: [PATCH 093/131] tests yapf fixes --- spanner_orm/tests/update_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/spanner_orm/tests/update_test.py b/spanner_orm/tests/update_test.py index b906ffd..b6ced0a 100644 --- a/spanner_orm/tests/update_test.py +++ b/spanner_orm/tests/update_test.py @@ -1,4 +1,3 @@ - # python3 # Copyright 2019 Google LLC # From eaf3f5f8b60973fc0463770ecc12079c99e70e9d Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Thu, 18 Nov 2021 16:24:50 -0500 Subject: [PATCH 094/131] Fix pytype errors with IncludesCondition It can't infer the type of self.relation from the value of self.foreign_key_relation. --- spanner_orm/condition.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/spanner_orm/condition.py b/spanner_orm/condition.py index 0dfb518..1c97f61 100644 --- a/spanner_orm/condition.py +++ b/spanner_orm/condition.py @@ -457,18 +457,15 @@ def bind(self, model_class: Type[Any]) -> None: @property def conditions(self) -> List[Condition]: """Generate the child conditions based on the relationship constraints.""" - if not self.relation: - raise error.SpannerError( - 'Condition must be bound before conditions is called') relation_conditions = [] - if self.foreign_key_relation: + if isinstance(self.relation, + foreign_key_relationship.ForeignKeyRelationship): for pair in self.relation.constraint.columns.items(): referencing_column, referenced_column = pair relation_conditions.append( ColumnsEqualCondition(referenced_column, self.model_class, referencing_column)) - - else: + elif isinstance(self.relation, relationship.Relationship): for constraint in self.relation.constraints: # This is backward from what you might imagine because the condition # will be processed from the context of the destination model. @@ -476,17 +473,21 @@ def conditions(self) -> List[Condition]: ColumnsEqualCondition(constraint.destination_column, constraint.origin_class, constraint.origin_column)) + else: + raise error.SpannerError( + 'Condition must be bound before conditions is called') return relation_conditions + self._conditions @property def destination(self) -> Type[Any]: - if not self.relation: - raise error.SpannerError( - 'Condition must be bound before destination is called') - if self.foreign_key_relation: + if isinstance(self.relation, + foreign_key_relationship.ForeignKeyRelationship): return self.relation.constraint.referenced_table - else: + elif isinstance(self.relation, relationship.Relationship): return self.relation.destination + else: + raise error.SpannerError( + 'Condition must be bound before destination is called') @property def relation_name(self) -> str: From e924ad38289b3a670886d658eea55001e825f169 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Wed, 15 Dec 2021 16:41:21 -0500 Subject: [PATCH 095/131] Ignore pytype findings from a bug in pytype --- spanner_orm/api.py | 3 ++- spanner_orm/model.py | 3 ++- spanner_orm/table_apis.py | 3 ++- spanner_orm/tests/api_test.py | 3 ++- spanner_orm/tests/model_test.py | 3 ++- 5 files changed, 10 insertions(+), 5 deletions(-) diff --git a/spanner_orm/api.py b/spanner_orm/api.py index 95f221d..97f40e7 100644 --- a/spanner_orm/api.py +++ b/spanner_orm/api.py @@ -21,7 +21,8 @@ from google.api_core import client_options as api_client_options from google.api_core import exceptions from google.auth import credentials as auth_credentials -from google.cloud import spanner +# TODO(https://github.com/google/pytype/issues/1081): Remove pytype disable. +from google.cloud import spanner # pytype: disable=import-error from google.cloud.spanner_v1 import database as spanner_database from google.cloud.spanner_v1 import pool as spanner_pool from spanner_orm import error diff --git a/spanner_orm/model.py b/spanner_orm/model.py index 4147f21..f3fc240 100644 --- a/spanner_orm/model.py +++ b/spanner_orm/model.py @@ -31,7 +31,8 @@ from spanner_orm import table_apis from google.api_core import exceptions -from google.cloud import spanner +# TODO(https://github.com/google/pytype/issues/1081): Remove pytype disable. +from google.cloud import spanner # pytype: disable=import-error from google.cloud.spanner_v1 import transaction as spanner_transaction T = TypeVar('T') diff --git a/spanner_orm/table_apis.py b/spanner_orm/table_apis.py index 0b64df3..5609006 100644 --- a/spanner_orm/table_apis.py +++ b/spanner_orm/table_apis.py @@ -17,7 +17,8 @@ import logging from typing import Any, Dict, Iterable, List, Sequence -from google.cloud import spanner +# TODO(https://github.com/google/pytype/issues/1081): Remove pytype disable. +from google.cloud import spanner # pytype: disable=import-error from google.cloud.spanner_v1 import transaction as spanner_transaction from google.cloud.spanner_v1.proto import type_pb2 diff --git a/spanner_orm/tests/api_test.py b/spanner_orm/tests/api_test.py index bbcb296..331507b 100644 --- a/spanner_orm/tests/api_test.py +++ b/spanner_orm/tests/api_test.py @@ -19,7 +19,8 @@ from absl.testing import parameterized from google.api_core import exceptions -from google.cloud import spanner +# TODO(https://github.com/google/pytype/issues/1081): Remove pytype disable. +from google.cloud import spanner # pytype: disable=import-error from spanner_orm import api from spanner_orm import error diff --git a/spanner_orm/tests/model_test.py b/spanner_orm/tests/model_test.py index 4a025da..fe5d52f 100644 --- a/spanner_orm/tests/model_test.py +++ b/spanner_orm/tests/model_test.py @@ -22,7 +22,8 @@ from absl.testing import parameterized from google.api_core import exceptions -from google.cloud import spanner +# TODO(https://github.com/google/pytype/issues/1081): Remove pytype disable. +from google.cloud import spanner # pytype: disable=import-error from spanner_orm import error from spanner_orm import field from spanner_orm.testlib.spanner_emulator import testlib as spanner_emulator_testlib From 26b56c10c165e2d87c8950cd3a481dc0e53f290d Mon Sep 17 00:00:00 2001 From: Sam Marder Date: Wed, 15 Dec 2021 17:13:44 -0500 Subject: [PATCH 096/131] Relax google-cloud-spanner constraint We've been running this in prod with google-cloud-spanner 3.11.1 for many months and things are working fine. --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 386a1b3..6199fcd 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,7 @@ include_package_data=True, python_requires='~=3.7', install_requires=[ - 'google-cloud-spanner >= 1.6, <2.0.0dev', + 'google-cloud-spanner >= 1.6, <4', 'immutabledict', ], tests_require=['absl-py', 'google-api-core', 'portpicker'], From 71ec2db83dc5a2b6f674e88d5a07e4fa15e8b90e Mon Sep 17 00:00:00 2001 From: Sam Marder Date: Fri, 17 Dec 2021 20:10:56 +0000 Subject: [PATCH 097/131] Upgrade spanner-orm to work with google-cloud-spanner v2 --- setup.py | 4 +- spanner_orm/condition.py | 77 ++++++++++--------- spanner_orm/field.py | 33 ++++---- spanner_orm/table_apis.py | 9 ++- .../testlib/spanner_emulator/testlib.py | 6 +- spanner_orm/tests/condition_test.py | 47 +++++------ spanner_orm/tests/query_test.py | 6 +- 7 files changed, 93 insertions(+), 89 deletions(-) diff --git a/setup.py b/setup.py index 6199fcd..5e6126a 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ setup( name='spanner-orm', - version='0.1.10', + version='0.2.0', description='Basic ORM for Spanner', maintainer='Python Spanner ORM developers', maintainer_email='python-spanner-orm@google.com', @@ -26,7 +26,7 @@ include_package_data=True, python_requires='~=3.7', install_requires=[ - 'google-cloud-spanner >= 1.6, <4', + 'google-cloud-spanner >= 2, <4', 'immutabledict', ], tests_require=['absl-py', 'google-api-core', 'portpicker'], diff --git a/spanner_orm/condition.py b/spanner_orm/condition.py index 1c97f61..293261e 100644 --- a/spanner_orm/condition.py +++ b/spanner_orm/condition.py @@ -30,7 +30,8 @@ from spanner_orm import relationship from google.api_core import datetime_helpers -from google.cloud.spanner_v1.proto import type_pb2 +from google.cloud import spanner +from google.cloud import spanner_v1 import immutabledict T = TypeVar('T') @@ -105,7 +106,7 @@ def sql(self) -> str: def _sql(self) -> str: pass - def types(self) -> Dict[str, type_pb2.Type]: + def types(self) -> Dict[str, spanner_v1.Type]: """Returns parameter types to be used in the SQL query. Returns: @@ -117,7 +118,7 @@ def types(self) -> Dict[str, type_pb2.Type]: return self._types() @abc.abstractmethod - def _types(self) -> Dict[str, type_pb2.Type]: + def _types(self) -> Dict[str, spanner_v1.Type]: raise NotImplementedError @abc.abstractmethod @@ -158,7 +159,8 @@ def _validate(self, model_class: Type[Any]) -> None: ] -def _spanner_type_of_python_object(value: GuessableParamType) -> type_pb2.Type: +def _spanner_type_of_python_object( + value: GuessableParamType) -> spanner_v1.Type: """Returns the Cloud Spanner type of the given object. Args: @@ -173,31 +175,30 @@ def _spanner_type_of_python_object(value: GuessableParamType) -> type_pb2.Type: raise TypeError( 'Cannot infer type of None, because any SQL type can be NULL.') simple_type_code = { - bool: type_pb2.BOOL, - int: type_pb2.INT64, - float: type_pb2.FLOAT64, - datetime_helpers.DatetimeWithNanoseconds: type_pb2.TIMESTAMP, - datetime.datetime: type_pb2.TIMESTAMP, - datetime.date: type_pb2.DATE, - bytes: type_pb2.BYTES, - str: type_pb2.STRING, - decimal.Decimal: type_pb2.NUMERIC, + bool: spanner_v1.TypeCode.BOOL, + int: spanner_v1.TypeCode.INT64, + float: spanner_v1.TypeCode.FLOAT64, + datetime_helpers.DatetimeWithNanoseconds: spanner_v1.TypeCode.TIMESTAMP, + datetime.datetime: spanner_v1.TypeCode.TIMESTAMP, + datetime.date: spanner_v1.TypeCode.DATE, + bytes: spanner_v1.TypeCode.BYTES, + str: spanner_v1.TypeCode.STRING, + decimal.Decimal: spanner_v1.TypeCode.NUMERIC, }.get(type(value)) if simple_type_code is not None: - return type_pb2.Type(code=simple_type_code) + return spanner_v1.Type(code=simple_type_code) elif isinstance(value, (list, tuple)): element_types = tuple( _spanner_type_of_python_object(item) for item in value if item is not None) unique_element_type_count = len({ - # Protos aren't hashable, so use their serializations. - element_type.SerializeToString(deterministic=True) - for element_type in element_types + # Protos aren't hashable, so serialize them. + str(element_type) for element_type in element_types }) if unique_element_type_count == 1: - return type_pb2.Type( - code=type_pb2.ARRAY, + return spanner_v1.Type( + code=spanner_v1.TypeCode.ARRAY, array_element_type=element_types[0], ) else: @@ -211,7 +212,7 @@ def _spanner_type_of_python_object(value: GuessableParamType) -> type_pb2.Type: class Param: """Parameter for substitution into a SQL query.""" value: Any - type: type_pb2.Type + type: spanner_v1.Type @classmethod def from_value(cls: Type[T], value: GuessableParamType) -> T: @@ -220,14 +221,13 @@ def from_value(cls: Type[T], value: GuessableParamType) -> T: # BYTES must be base64-encoded, see # https://github.com/googleapis/python-spanner/blob/87789c939990794bfd91f5300bedc449fd74bd7e/google/cloud/spanner_v1/proto/type.proto#L108-L110 - if (isinstance(value, bytes) and - guessed_type == type_pb2.Type(code=type_pb2.BYTES)): + if (isinstance(value, bytes) and guessed_type == spanner.param_types.BYTES): encoded_value = base64.b64encode(value).decode() elif (isinstance(value, (list, tuple)) and all(isinstance(x, bytes) for x in value if x is not None) and - guessed_type == type_pb2.Type( - code=type_pb2.ARRAY, - array_element_type=type_pb2.Type(code=type_pb2.BYTES), + guessed_type == spanner_v1.Type( + code=spanner_v1.TypeCode.ARRAY, + array_element_type=spanner.param_types.BYTES, )): encoded_value = tuple( None if item is None else base64.b64encode(item).decode() @@ -299,7 +299,7 @@ def _params(self) -> Dict[str, Any]: if isinstance(v, Param) } - def _types(self) -> Dict[str, type_pb2.Type]: + def _types(self) -> Dict[str, spanner_v1.Type]: """See base class.""" return { self.key(k): v.type @@ -345,7 +345,7 @@ def _sql(self) -> str: other_table=self.destination_model_class.table, other_column=self.destination_column) - def _types(self) -> Dict[str, type_pb2.Type]: + def _types(self) -> Dict[str, spanner_v1.Type]: return {} def _validate(self, model_class: Type[Any]) -> None: @@ -389,7 +389,7 @@ def segment(self) -> Segment: def _sql(self) -> str: return '@{{FORCE_INDEX={}}}'.format(self.name) - def _types(self) -> Dict[str, type_pb2.Type]: + def _types(self) -> Dict[str, spanner_v1.Type]: return {} def _validate(self, model_class: Type[Any]) -> None: @@ -509,7 +509,7 @@ def segment(self) -> Segment: def _sql(self) -> str: return '' - def _types(self) -> Dict[str, type_pb2.Type]: + def _types(self) -> Dict[str, spanner_v1.Type]: return {} def _validate(self, model_class: Type[Any]) -> None: @@ -568,10 +568,10 @@ def _sql(self) -> str: limit_key=self._limit_key, offset_key=self._offset_key) return 'LIMIT @{limit_key}'.format(limit_key=self._limit_key) - def _types(self) -> Dict[str, type_pb2.Type]: - types = {self._limit_key: type_pb2.Type(code=type_pb2.INT64)} + def _types(self) -> Dict[str, spanner_v1.Type]: + types = {self._limit_key: spanner.param_types.INT64} if self.offset: - types[self._offset_key] = type_pb2.Type(code=type_pb2.INT64) + types[self._offset_key] = spanner.param_types.INT64 return types def _validate(self, model_class: Type[Any]) -> None: @@ -624,7 +624,7 @@ def _sql(self) -> str: def segment(self) -> Segment: return Segment.WHERE - def _types(self) -> type_pb2.Type: + def _types(self) -> spanner_v1.Type: result = {} for condition in self.all_conditions: condition.suffix = str(int(self.suffix or 0) + len(result)) @@ -669,7 +669,7 @@ def _sql(self) -> str: def segment(self) -> Segment: return Segment.ORDER_BY - def _types(self) -> type_pb2.Type: + def _types(self) -> spanner_v1.Type: return {} def _validate(self, model_class: Type[Any]) -> None: @@ -714,7 +714,7 @@ def _sql(self) -> str: operator=self.operator, column_key=self._column_key) - def _types(self) -> type_pb2.Type: + def _types(self) -> spanner_v1.Type: return {self._column_key: self.model_class.fields[self.column].grpc_type()} def _validate(self, model_class: Type[Any]) -> None: @@ -740,9 +740,10 @@ def _sql(self) -> str: operator=self.operator, column_key=self._column_key) - def _types(self) -> type_pb2.Type: + def _types(self) -> spanner_v1.Type: grpc_type = self.model_class.fields[self.column].grpc_type() - list_type = type_pb2.Type(code=type_pb2.ARRAY, array_element_type=grpc_type) + list_type = spanner_v1.Type( + code=spanner_v1.TypeCode.ARRAY, array_element_type=grpc_type) return {self._column_key: list_type} def _validate(self, model_class: Type[Any]) -> None: @@ -782,7 +783,7 @@ def _sql(self) -> str: operator=self.nullable_operator) return super()._sql() - def _types(self) -> type_pb2.Type: + def _types(self) -> spanner_v1.Type: if self.is_null(): return {} return super()._types() diff --git a/spanner_orm/field.py b/spanner_orm/field.py index d5d864b..ebc12fe 100644 --- a/spanner_orm/field.py +++ b/spanner_orm/field.py @@ -20,7 +20,8 @@ import datetime from typing import Any, Type -from google.cloud.spanner_v1.proto import type_pb2 +from google.cloud import spanner +from google.cloud import spanner_v1 from spanner_orm import error @@ -34,7 +35,7 @@ def ddl() -> str: @staticmethod @abc.abstractmethod - def grpc_type() -> type_pb2.Type: + def grpc_type() -> spanner_v1.Type: raise NotImplementedError @staticmethod @@ -88,8 +89,8 @@ def ddl() -> str: return 'BOOL' @staticmethod - def grpc_type() -> type_pb2.Type: - return type_pb2.Type(code=type_pb2.BOOL) + def grpc_type() -> spanner_v1.Type: + return spanner.param_types.BOOL @staticmethod def validate_type(value: Any) -> None: @@ -105,8 +106,8 @@ def ddl() -> str: return 'INT64' @staticmethod - def grpc_type() -> type_pb2.Type: - return type_pb2.Type(code=type_pb2.INT64) + def grpc_type() -> spanner_v1.Type: + return spanner.param_types.INT64 @staticmethod def validate_type(value: Any) -> None: @@ -122,8 +123,8 @@ def ddl() -> str: return 'FLOAT64' @staticmethod - def grpc_type() -> type_pb2.Type: - return type_pb2.Type(code=type_pb2.FLOAT64) + def grpc_type() -> spanner_v1.Type: + return spanner.param_types.FLOAT64 @staticmethod def validate_type(value: Any) -> None: @@ -139,8 +140,8 @@ def ddl() -> str: return 'STRING(MAX)' @staticmethod - def grpc_type() -> type_pb2.Type: - return type_pb2.Type(code=type_pb2.STRING) + def grpc_type() -> spanner_v1.Type: + return spanner.param_types.STRING @staticmethod def validate_type(value) -> None: @@ -156,8 +157,8 @@ def ddl() -> str: return 'ARRAY' @staticmethod - def grpc_type() -> type_pb2.Type: - return type_pb2.Type(code=type_pb2.ARRAY) + def grpc_type() -> spanner_v1.Type: + return spanner.param_types.Array(spanner.param_types.STRING) @staticmethod def validate_type(value: Any) -> None: @@ -176,8 +177,8 @@ def ddl() -> str: return 'TIMESTAMP' @staticmethod - def grpc_type() -> type_pb2.Type: - return type_pb2.Type(code=type_pb2.TIMESTAMP) + def grpc_type() -> spanner_v1.Type: + return spanner.param_types.TIMESTAMP @staticmethod def validate_type(value: Any) -> None: @@ -193,8 +194,8 @@ def ddl() -> str: return 'BYTES(MAX)' @staticmethod - def grpc_type() -> type_pb2.Type: - return type_pb2.Type(code=type_pb2.BYTES) + def grpc_type() -> spanner_v1.Type: + return spanner.param_types.BYTES @staticmethod def validate_type(value) -> None: diff --git a/spanner_orm/table_apis.py b/spanner_orm/table_apis.py index 5609006..73f9097 100644 --- a/spanner_orm/table_apis.py +++ b/spanner_orm/table_apis.py @@ -19,8 +19,8 @@ # TODO(https://github.com/google/pytype/issues/1081): Remove pytype disable. from google.cloud import spanner # pytype: disable=import-error +from google.cloud import spanner_v1 from google.cloud.spanner_v1 import transaction as spanner_transaction -from google.cloud.spanner_v1.proto import type_pb2 _logger = logging.getLogger(__name__) @@ -50,9 +50,10 @@ def find(transaction: spanner_transaction.Transaction, table_name: str, return list(stream_results) -def sql_query(transaction: spanner_transaction.Transaction, query: str, - parameters: Dict[str, Any], - parameter_types: Dict[str, type_pb2.Type]) -> List[Sequence[Any]]: +def sql_query( + transaction: spanner_transaction.Transaction, query: str, + parameters: Dict[str, Any], + parameter_types: Dict[str, spanner_v1.Type]) -> List[Sequence[Any]]: """Executes a given SQL query against the Spanner database. This isn't technically read-only, but it's necessary to implement the read- diff --git a/spanner_orm/testlib/spanner_emulator/testlib.py b/spanner_orm/testlib/spanner_emulator/testlib.py index f59cab8..a52a559 100644 --- a/spanner_orm/testlib/spanner_emulator/testlib.py +++ b/spanner_orm/testlib/spanner_emulator/testlib.py @@ -55,9 +55,9 @@ def _get_instance(spanner_client: client.Client) -> instance.Instance: Args: spanner_client: An initialized spanner client. """ - existing_instances = list(spanner_client.list_instances()) - if existing_instances: - return existing_instances[0] + existing_instances_pb = list(spanner_client.list_instances()) + if existing_instances_pb: + return instance.Instance.from_pb(existing_instances_pb[0], spanner_client) # The emulator has one default config. config = list(spanner_client.list_instance_configs())[0] diff --git a/spanner_orm/tests/condition_test.py b/spanner_orm/tests/condition_test.py index 6d4a84c..0673211 100644 --- a/spanner_orm/tests/condition_test.py +++ b/spanner_orm/tests/condition_test.py @@ -22,7 +22,8 @@ from absl.testing import parameterized from google.api_core import datetime_helpers -from google.cloud.spanner_v1.proto import type_pb2 +from google.cloud import spanner +from google.cloud import spanner_v1 import spanner_orm from spanner_orm import condition @@ -45,30 +46,30 @@ def setUp(self): )) @parameterized.parameters( - (True, type_pb2.Type(code=type_pb2.BOOL)), - (0, type_pb2.Type(code=type_pb2.INT64)), - (0.0, type_pb2.Type(code=type_pb2.FLOAT64)), + (True, spanner_v1.param_types.BOOL), + (0, spanner_v1.param_types.INT64), + (0.0, spanner_v1.param_types.FLOAT64), ( datetime_helpers.DatetimeWithNanoseconds(2021, 1, 5), - type_pb2.Type(code=type_pb2.TIMESTAMP), + spanner_v1.param_types.TIMESTAMP, ), - (datetime.datetime(2021, 1, 5), type_pb2.Type(code=type_pb2.TIMESTAMP)), - (datetime.date(2021, 1, 5), type_pb2.Type(code=type_pb2.DATE)), - (b'\x01', type_pb2.Type(code=type_pb2.BYTES)), - ('foo', type_pb2.Type(code=type_pb2.STRING)), - (decimal.Decimal('1.23'), type_pb2.Type(code=type_pb2.NUMERIC)), + (datetime.datetime(2021, 1, 5), spanner_v1.param_types.TIMESTAMP), + (datetime.date(2021, 1, 5), spanner_v1.param_types.DATE), + (b'\x01', spanner_v1.param_types.BYTES), + ('foo', spanner_v1.param_types.STRING), + (decimal.Decimal('1.23'), spanner_v1.param_types.NUMERIC), ( (0, 1), - type_pb2.Type( - code=type_pb2.ARRAY, - array_element_type=type_pb2.Type(code=type_pb2.INT64), + spanner_v1.Type( + code=spanner_v1.TypeCode.ARRAY, + array_element_type=spanner_v1.param_types.INT64, ), ), ( ['a', None, 'b'], - type_pb2.Type( - code=type_pb2.ARRAY, - array_element_type=type_pb2.Type(code=type_pb2.STRING), + spanner_v1.Type( + code=spanner_v1.TypeCode.ARRAY, + array_element_type=spanner_v1.param_types.STRING, ), ), ) @@ -160,8 +161,8 @@ def test_param_from_value_correctly_encodes(self, tautology): key_param0='some-key', ), dict( - true_param0=type_pb2.Type(code=type_pb2.BOOL), - key_param0=type_pb2.Type(code=type_pb2.STRING), + true_param0=spanner_v1.param_types.BOOL, + key_param0=spanner_v1.param_types.STRING, ), ('SmallTestModel.key = ' 'IF(@true_param0, @key_param0, SmallTestModel.value_1)'), @@ -248,7 +249,7 @@ def test_arbitrary_condition_validation_error( condition.OrCondition( [condition.equal_to(models.SmallTestModel.key, 'a')]), dict(key0='a'), - dict(key0=type_pb2.Type(code=type_pb2.STRING)), + dict(key0=spanner_v1.param_types.STRING), '((SmallTestModel.key = @key0))', 'a', ), @@ -271,10 +272,10 @@ def test_arbitrary_condition_validation_error( value_13='b', ), dict( - key0=type_pb2.Type(code=type_pb2.STRING), - value_11=type_pb2.Type(code=type_pb2.STRING), - key2=type_pb2.Type(code=type_pb2.STRING), - value_13=type_pb2.Type(code=type_pb2.STRING), + key0=spanner_v1.param_types.STRING, + value_11=spanner_v1.param_types.STRING, + key2=spanner_v1.param_types.STRING, + value_13=spanner_v1.param_types.STRING, ), ('(' '(SmallTestModel.key = @key0 AND SmallTestModel.value_1 = @value_11)' diff --git a/spanner_orm/tests/query_test.py b/spanner_orm/tests/query_test.py index b851d6b..2644dba 100644 --- a/spanner_orm/tests/query_test.py +++ b/spanner_orm/tests/query_test.py @@ -24,7 +24,7 @@ from spanner_orm import query from spanner_orm.tests import models -from google.cloud.spanner_v1.proto import type_pb2 +from google.cloud import spanner_v1 def now(): @@ -178,8 +178,8 @@ def test_query_where_list_comparison(self, column, values, grpc_type): column_key = '{}0'.format(column) expected_sql = ' WHERE table.{} {} UNNEST(@{})'.format( column, current_condition.operator, column_key) - list_type = type_pb2.Type( - code=type_pb2.ARRAY, array_element_type=grpc_type) + list_type = spanner_v1.Type( + code=spanner_v1.TypeCode.ARRAY, array_element_type=grpc_type) self.assertEndsWith(select_query.sql(), expected_sql) self.assertEqual(select_query.parameters(), {column_key: values}) self.assertEqual(select_query.types(), {column_key: list_type}) From 325b911b92581163b8da30bdb9916f26d73b0e85 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Wed, 9 Feb 2022 17:43:07 -0500 Subject: [PATCH 098/131] Add __repr__() to Model. This makes it a lot easier to see what's wrong when a unit test assertion about models fails. --- spanner_orm/model.py | 3 +++ spanner_orm/tests/model_test.py | 5 +++++ 2 files changed, 8 insertions(+) diff --git a/spanner_orm/model.py b/spanner_orm/model.py index f3fc240..e1cf687 100644 --- a/spanner_orm/model.py +++ b/spanner_orm/model.py @@ -196,6 +196,9 @@ def __eq__(self, other: Any) -> Union[bool, type(NotImplemented)]: return NotImplemented return self.values == other.values + def __repr__(self) -> str: + return f'{self.__class__.__qualname__}({self.values!r})' + @classmethod def spanner_api(cls) -> api.SpannerApi: if not cls.table: diff --git a/spanner_orm/tests/model_test.py b/spanner_orm/tests/model_test.py index fe5d52f..c36c52d 100644 --- a/spanner_orm/tests/model_test.py +++ b/spanner_orm/tests/model_test.py @@ -333,6 +333,11 @@ def test_model_equates(self): def test_model_are_different(self, test_model1, test_model2): self.assertNotEqual(test_model1, test_model2) + def test_repr(self): + self.assertEqual( + "SmallTestModel({'key': 'a', 'value_1': 'b', 'value_2': None})", + repr(models.SmallTestModel(dict(key='a', value_1='b', value_2=None)))) + def test_id(self): primary_key = {'string': 'foo', 'int_': 5, 'float_': 2.3, 'bytes_': b'A1A1'} all_data = primary_key.copy() From a1caa026bc082f90777f6a91a757146e43158f6f Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Tue, 1 Mar 2022 17:01:34 -0500 Subject: [PATCH 099/131] Add conditions to use NULL_FILTERED indexes safely --- spanner_orm/__init__.py | 1 + spanner_orm/condition.py | 86 ++++++++++++++++--- spanner_orm/tests/condition_test.py | 10 +++ ..._null_filtered_index_model_760ec5fae5da.py | 40 +++++++++ ...ed_index_model_value_index_69a8f072dacf.py | 28 ++++++ spanner_orm/tests/models.py | 10 +++ 6 files changed, 163 insertions(+), 12 deletions(-) create mode 100644 spanner_orm/tests/migrations_for_emulator_test/create_null_filtered_index_model_760ec5fae5da.py create mode 100644 spanner_orm/tests/migrations_for_emulator_test/create_null_filtered_index_model_value_index_69a8f072dacf.py diff --git a/spanner_orm/__init__.py b/spanner_orm/__init__.py index 7678314..193ab45 100644 --- a/spanner_orm/__init__.py +++ b/spanner_orm/__init__.py @@ -81,6 +81,7 @@ contains = condition.contains equal_to = condition.equal_to force_index = condition.force_index +force_null_filtered_index = condition.force_null_filtered_index greater_than = condition.greater_than greater_than_or_equal_to = condition.greater_than_or_equal_to in_list = condition.in_list diff --git a/spanner_orm/condition.py b/spanner_orm/condition.py index 1c97f61..4be0ea5 100644 --- a/spanner_orm/condition.py +++ b/spanner_orm/condition.py @@ -21,7 +21,7 @@ import decimal import enum import string -from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union from spanner_orm import error from spanner_orm import field @@ -364,10 +364,10 @@ def _validate(self, model_class: Type[Any]) -> None: origin.name, dest.name)) -class ForceIndexCondition(Condition): - """Used to indicate which index should be used in a Spanner query.""" +class _IndexCondition(Condition): + """Base class for conditions based on an Index.""" - def __init__(self, index_or_name: Union[Type[index.Index], str]): + def __init__(self, index_or_name: Union[index.Index, str]): super().__init__() if isinstance(index_or_name, index.Index): self.name = index_or_name.name @@ -380,6 +380,27 @@ def bind(self, model_class: Type[Any]) -> None: super().bind(model_class) self.index = self.model_class.indexes[self.name] + def _validate(self, model_class: Type[Any]) -> None: + if self.name not in model_class.indexes: + raise error.ValidationError('{} is not an index on {}'.format( + self.name, model_class.table)) + if self.index and self.index != model_class.indexes[self.name]: + raise error.ValidationError('{} does not belong to {}'.format( + self.index.name, model_class.table)) + + +class ForceIndexCondition(_IndexCondition): + """Used to indicate which index should be used in a Spanner query.""" + + def __init__( + self, + index_or_name: Union[index.Index, str], + *, + extra_hints: Sequence[str] = (), + ): + super().__init__(index_or_name) + self._extra_hints = extra_hints + def _params(self) -> Dict[str, Any]: return {} @@ -387,23 +408,35 @@ def segment(self) -> Segment: return Segment.FROM def _sql(self) -> str: - return '@{{FORCE_INDEX={}}}'.format(self.name) + hints = (f'FORCE_INDEX={self.name}', *self._extra_hints) + return f'@{{{",".join(hints)}}}' def _types(self) -> Dict[str, type_pb2.Type]: return {} def _validate(self, model_class: Type[Any]) -> None: - if self.name not in model_class.indexes: - raise error.ValidationError('{} is not an index on {}'.format( - self.name, model_class.table)) - if self.index and self.index != model_class.indexes[self.name]: - raise error.ValidationError('{} does not belong to {}'.format( - self.index.name, model_class.table)) - + super()._validate(model_class) if model_class.indexes[self.name].primary: raise error.ValidationError('Cannot force query using primary index') +class _IndexIgnoreNullsCondition(_IndexCondition): + """Condition to filter NULL values in any column of an index.""" + + def _params(self) -> Dict[str, Any]: + return {} + + def segment(self) -> Segment: + return Segment.WHERE + + def _sql(self) -> str: + return '({})'.format(' AND '.join( + f'{column} IS NOT NULL' for column in self.index.columns)) + + def _types(self) -> Dict[str, type_pb2.Type]: + return {} + + class IncludesCondition(Condition): """Used to include related model_classs via a relation in a Spanner query.""" @@ -888,6 +921,35 @@ def force_index(forced_index: Union[index.Index, str]) -> ForceIndexCondition: return ForceIndexCondition(forced_index) +def force_null_filtered_index( + forced_index: Union[index.Index, str]) -> Sequence[Condition]: + """Returns conditions to force the query to use the given NULL_FILTERED index. + + In Cloud Spanner, a query against a NULL_FILTERED index is tested to see if it + can use safely use that index. If using the index would result in incorrect + results (e.g., by ignoring NULL values that would be in the same query without + using the index), it's an error. However, the Cloud Spanner Emulator + doesn't support that check: + https://github.com/GoogleCloudPlatform/cloud-spanner-emulator/blob/e887ff5569684e6e45ce7c90d0fdfb7b1faa1491/common/errors.cc#L1790-L1800 + + For queries that can safely ignore any NULL values covered by the index, this + function returns conditions that both filter out all relevant NULLs (avoiding + the potential error in Cloud Spanner) and disable the check in Cloud Spanner + Emulator. + + Args: + forced_index: NULL_FILTERED index to use. + """ + return ( + ForceIndexCondition( + forced_index, + extra_hints=( + 'spanner_emulator.disable_query_null_filtered_index_check=true', + )), + _IndexIgnoreNullsCondition(forced_index), + ) + + def greater_than(column: Union[field.Field, str], value: Any) -> ComparisonCondition: """Condition where the specified column is greater than the given value. diff --git a/spanner_orm/tests/condition_test.py b/spanner_orm/tests/condition_test.py index 6d4a84c..21af6e1 100644 --- a/spanner_orm/tests/condition_test.py +++ b/spanner_orm/tests/condition_test.py @@ -333,6 +333,16 @@ def test_contains( )), ) + def test_force_null_filtered_index(self): + non_null_model = models.NullFilteredIndexModel( + dict(key='a', value_1='a', value_2=1)) + non_null_model.save() + models.NullFilteredIndexModel(dict(key='b', value_1=None, value_2=2)).save() + self.assertCountEqual((non_null_model,), + models.NullFilteredIndexModel.where( + *spanner_orm.force_null_filtered_index( + models.NullFilteredIndexModel.value_index))) + if __name__ == '__main__': logging.basicConfig() diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_null_filtered_index_model_760ec5fae5da.py b/spanner_orm/tests/migrations_for_emulator_test/create_null_filtered_index_model_760ec5fae5da.py new file mode 100644 index 0000000..494a000 --- /dev/null +++ b/spanner_orm/tests/migrations_for_emulator_test/create_null_filtered_index_model_760ec5fae5da.py @@ -0,0 +1,40 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Spanner ORM migration: create_null_filtered_index_model. + +Migration ID: '760ec5fae5da' +Created: 2022-03-01 16:50:32-05:00 +""" + +import spanner_orm + +migration_id = '760ec5fae5da' +prev_migration_id = 'f735d6b706d4' + + +class _NullFilteredIndexModel(spanner_orm.Model): + __table__ = 'NullFilteredIndexModel' + key = spanner_orm.Field(spanner_orm.String, primary_key=True) + value_1 = spanner_orm.Field(spanner_orm.String, nullable=True) + value_2 = spanner_orm.Field(spanner_orm.Integer) + + +def upgrade() -> spanner_orm.MigrationUpdate: + """See spanner_orm migrations interface.""" + return spanner_orm.CreateTable(_NullFilteredIndexModel) + + +def downgrade() -> spanner_orm.MigrationUpdate: + """See spanner_orm migrations interface.""" + return spanner_orm.DropTable(_NullFilteredIndexModel.__table__) diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_null_filtered_index_model_value_index_69a8f072dacf.py b/spanner_orm/tests/migrations_for_emulator_test/create_null_filtered_index_model_value_index_69a8f072dacf.py new file mode 100644 index 0000000..e0c16bf --- /dev/null +++ b/spanner_orm/tests/migrations_for_emulator_test/create_null_filtered_index_model_value_index_69a8f072dacf.py @@ -0,0 +1,28 @@ +"""Spanner ORM migration: create_null_filtered_index_model_value_index. + +Migration ID: '69a8f072dacf' +Created: 2022-03-01 16:53:59-05:00 +""" + +import spanner_orm + +migration_id = '69a8f072dacf' +prev_migration_id = '760ec5fae5da' + + +def upgrade() -> spanner_orm.MigrationUpdate: + """See spanner_orm migrations interface.""" + return spanner_orm.CreateIndex( + table_name='NullFilteredIndexModel', + index_name='value_index', + columns=['value_1', 'value_2'], + null_filtered=True, + ) + + +def downgrade() -> spanner_orm.MigrationUpdate: + """See spanner_orm migrations interface.""" + return spanner_orm.DropIndex( + table_name='NullFilteredIndexModel', + index_name='value_index', + ) diff --git a/spanner_orm/tests/models.py b/spanner_orm/tests/models.py index d7a3db2..697a582 100644 --- a/spanner_orm/tests/models.py +++ b/spanner_orm/tests/models.py @@ -130,3 +130,13 @@ class UnittestModelWithoutSecondaryIndexes(model.Model): bytes_2 = field.Field(field.BytesBase64, nullable=True) timestamp = field.Field(field.Timestamp) string_array = field.Field(field.StringArray, nullable=True) + + +class NullFilteredIndexModel(model.Model): + """Model class for testing NULL_FILTERED indexes.""" + + __table__ = 'NullFilteredIndexModel' + key = field.Field(field.String, primary_key=True) + value_1 = field.Field(field.String, nullable=True) + value_2 = field.Field(field.Integer) + value_index = index.Index(['value_1', 'value_2'], null_filtered=True) From 9fe194e066618048cd0a8e95ed439514f965457d Mon Sep 17 00:00:00 2001 From: Sam Marder Date: Tue, 8 Mar 2022 10:46:48 -0500 Subject: [PATCH 100/131] Better element type uniqueness check --- spanner_orm/condition.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/spanner_orm/condition.py b/spanner_orm/condition.py index 293261e..cb37f3f 100644 --- a/spanner_orm/condition.py +++ b/spanner_orm/condition.py @@ -192,11 +192,8 @@ def _spanner_type_of_python_object( _spanner_type_of_python_object(item) for item in value if item is not None) - unique_element_type_count = len({ - # Protos aren't hashable, so serialize them. - str(element_type) for element_type in element_types - }) - if unique_element_type_count == 1: + if element_types and all( + a == b for a, b in zip(element_types, element_types[1:])): return spanner_v1.Type( code=spanner_v1.TypeCode.ARRAY, array_element_type=element_types[0], From 7fffa68457c125be69c51e340a2562c401812040 Mon Sep 17 00:00:00 2001 From: Sam Marder Date: Tue, 8 Mar 2022 11:06:12 -0500 Subject: [PATCH 101/131] Update CI version of google-cloud-spanner --- .github/workflows/test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index a531d84..cf57672 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -43,7 +43,7 @@ jobs: pip install \ absl-py \ google-api-core \ - 'google-cloud-spanner >= 1.6, <2.0.0dev' \ + 'google-cloud-spanner >= 2, <4' \ immutabledict \ portpicker \ pytest From 2ee17bab3fc7a0931bf34deec713a404779f8865 Mon Sep 17 00:00:00 2001 From: Sam Marder Date: Mon, 14 Mar 2022 15:52:44 -0400 Subject: [PATCH 102/131] Pin pytype in GitHub actions --- .github/workflows/test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index cf57672..6b2743e 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -53,7 +53,7 @@ jobs: yapf --diff --recursive --parallel . - name: Check types run: | - pip install pytype + pip install pytype==2021.11.29 pytype --jobs=auto --keep-going spanner_orm - name: Test env: From dbf2235da03d7b547ed4c4cefce591427ca35486 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Mon, 14 Mar 2022 16:26:22 -0400 Subject: [PATCH 103/131] Pin pytype to a version before a bug we're seeing a lot of https://github.com/google/pytype/issues/1081 is making https://github.com/google/python-spanner-orm/pull/171 more difficult, so we decided to pin pytype rather than add more disable comments. --- .github/workflows/test.yaml | 4 +++- spanner_orm/api.py | 3 +-- spanner_orm/model.py | 3 +-- spanner_orm/table_apis.py | 3 +-- spanner_orm/tests/api_test.py | 3 +-- spanner_orm/tests/model_test.py | 3 +-- 6 files changed, 8 insertions(+), 11 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index a531d84..c4a528e 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -53,7 +53,9 @@ jobs: yapf --diff --recursive --parallel . - name: Check types run: | - pip install pytype + # TODO(https://github.com/google/pytype/issues/1081): Remove the version + # pin. + pip install pytype==2021.11.29 pytype --jobs=auto --keep-going spanner_orm - name: Test env: diff --git a/spanner_orm/api.py b/spanner_orm/api.py index 97f40e7..95f221d 100644 --- a/spanner_orm/api.py +++ b/spanner_orm/api.py @@ -21,8 +21,7 @@ from google.api_core import client_options as api_client_options from google.api_core import exceptions from google.auth import credentials as auth_credentials -# TODO(https://github.com/google/pytype/issues/1081): Remove pytype disable. -from google.cloud import spanner # pytype: disable=import-error +from google.cloud import spanner from google.cloud.spanner_v1 import database as spanner_database from google.cloud.spanner_v1 import pool as spanner_pool from spanner_orm import error diff --git a/spanner_orm/model.py b/spanner_orm/model.py index e1cf687..bd9be61 100644 --- a/spanner_orm/model.py +++ b/spanner_orm/model.py @@ -31,8 +31,7 @@ from spanner_orm import table_apis from google.api_core import exceptions -# TODO(https://github.com/google/pytype/issues/1081): Remove pytype disable. -from google.cloud import spanner # pytype: disable=import-error +from google.cloud import spanner from google.cloud.spanner_v1 import transaction as spanner_transaction T = TypeVar('T') diff --git a/spanner_orm/table_apis.py b/spanner_orm/table_apis.py index 5609006..0b64df3 100644 --- a/spanner_orm/table_apis.py +++ b/spanner_orm/table_apis.py @@ -17,8 +17,7 @@ import logging from typing import Any, Dict, Iterable, List, Sequence -# TODO(https://github.com/google/pytype/issues/1081): Remove pytype disable. -from google.cloud import spanner # pytype: disable=import-error +from google.cloud import spanner from google.cloud.spanner_v1 import transaction as spanner_transaction from google.cloud.spanner_v1.proto import type_pb2 diff --git a/spanner_orm/tests/api_test.py b/spanner_orm/tests/api_test.py index 331507b..bbcb296 100644 --- a/spanner_orm/tests/api_test.py +++ b/spanner_orm/tests/api_test.py @@ -19,8 +19,7 @@ from absl.testing import parameterized from google.api_core import exceptions -# TODO(https://github.com/google/pytype/issues/1081): Remove pytype disable. -from google.cloud import spanner # pytype: disable=import-error +from google.cloud import spanner from spanner_orm import api from spanner_orm import error diff --git a/spanner_orm/tests/model_test.py b/spanner_orm/tests/model_test.py index c36c52d..4275d29 100644 --- a/spanner_orm/tests/model_test.py +++ b/spanner_orm/tests/model_test.py @@ -22,8 +22,7 @@ from absl.testing import parameterized from google.api_core import exceptions -# TODO(https://github.com/google/pytype/issues/1081): Remove pytype disable. -from google.cloud import spanner # pytype: disable=import-error +from google.cloud import spanner from spanner_orm import error from spanner_orm import field from spanner_orm.testlib.spanner_emulator import testlib as spanner_emulator_testlib From 5f04392342e83f704b01ea2844809bc13149825d Mon Sep 17 00:00:00 2001 From: Sam Marder Date: Tue, 15 Mar 2022 16:02:38 -0400 Subject: [PATCH 104/131] Remove extra methods --- spanner_orm/condition.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/spanner_orm/condition.py b/spanner_orm/condition.py index 94bf730..b1e5a09 100644 --- a/spanner_orm/condition.py +++ b/spanner_orm/condition.py @@ -377,18 +377,6 @@ def bind(self, model_class: Type[Any]) -> None: super().bind(model_class) self.index = self.model_class.indexes[self.name] - def _params(self) -> Dict[str, Any]: - return {} - - def segment(self) -> Segment: - return Segment.FROM - - def _sql(self) -> str: - return '@{{FORCE_INDEX={}}}'.format(self.name) - - def _types(self) -> Dict[str, spanner_v1.Type]: - return {} - def _validate(self, model_class: Type[Any]) -> None: if self.name not in model_class.indexes: raise error.ValidationError('{} is not an index on {}'.format( From 7472083767f3983167b35bf5a6c2405c65dd7b60 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Fri, 16 Sep 2022 17:13:28 -0400 Subject: [PATCH 105/131] Remove some legacy things from when python 2 was around --- setup.py | 1 - spanner_orm/__init__.py | 1 - spanner_orm/admin/api.py | 1 - spanner_orm/admin/column.py | 1 - spanner_orm/admin/index.py | 1 - spanner_orm/admin/index_column.py | 1 - spanner_orm/admin/metadata.py | 3 +-- spanner_orm/admin/migration.py | 1 - spanner_orm/admin/migration_executor.py | 1 - spanner_orm/admin/migration_manager.py | 1 - spanner_orm/admin/migration_status.py | 1 - spanner_orm/admin/schema.py | 1 - spanner_orm/admin/scripts.py | 1 - spanner_orm/admin/table.py | 1 - spanner_orm/admin/update.py | 1 - spanner_orm/api.py | 1 - spanner_orm/condition.py | 1 - spanner_orm/decorator.py | 1 - spanner_orm/error.py | 1 - spanner_orm/field.py | 3 +-- spanner_orm/foreign_key_relationship.py | 3 +-- spanner_orm/index.py | 3 +-- spanner_orm/metadata.py | 3 +-- spanner_orm/model.py | 1 - spanner_orm/query.py | 1 - spanner_orm/registry.py | 3 +-- spanner_orm/relationship.py | 3 +-- spanner_orm/table_apis.py | 1 - spanner_orm/testlib/spanner_emulator/emulator.py | 1 - spanner_orm/testlib/spanner_emulator/testlib.py | 1 - spanner_orm/tests/admin_test.py | 1 - spanner_orm/tests/api_test.py | 1 - spanner_orm/tests/condition_test.py | 1 - spanner_orm/tests/decorator_test.py | 1 - spanner_orm/tests/metadata_test.py | 1 - spanner_orm/tests/migrations/test_1_4a7a7dee0718.py | 1 - spanner_orm/tests/migrations/test_2_5c078bbb4d43.py | 1 - spanner_orm/tests/migrations/test_3_eceb25f170dd.py | 1 - spanner_orm/tests/migrations_emulator_test.py | 1 - .../create_foreign_key_test_model.py | 1 - .../migrations_for_emulator_test/create_small_test_model.py | 1 - .../migrations_for_emulator_test/create_unittest_model.py | 1 - spanner_orm/tests/migrations_test.py | 1 - spanner_orm/tests/model_test.py | 1 - spanner_orm/tests/models.py | 1 - spanner_orm/tests/query_test.py | 1 - spanner_orm/tests/update_test.py | 1 - 47 files changed, 7 insertions(+), 54 deletions(-) diff --git a/setup.py b/setup.py index 5e6126a..9ec386e 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/__init__.py b/spanner_orm/__init__.py index 193ab45..83e283a 100644 --- a/spanner_orm/__init__.py +++ b/spanner_orm/__init__.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/admin/api.py b/spanner_orm/admin/api.py index 7d498a3..85c3062 100644 --- a/spanner_orm/admin/api.py +++ b/spanner_orm/admin/api.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/admin/column.py b/spanner_orm/admin/column.py index 167be6c..b40a3a3 100644 --- a/spanner_orm/admin/column.py +++ b/spanner_orm/admin/column.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/admin/index.py b/spanner_orm/admin/index.py index 95f4804..baadd9d 100644 --- a/spanner_orm/admin/index.py +++ b/spanner_orm/admin/index.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/admin/index_column.py b/spanner_orm/admin/index_column.py index 210be13..5b06937 100644 --- a/spanner_orm/admin/index_column.py +++ b/spanner_orm/admin/index_column.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/admin/metadata.py b/spanner_orm/admin/metadata.py index 988fe2b..5777634 100644 --- a/spanner_orm/admin/metadata.py +++ b/spanner_orm/admin/metadata.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,7 +27,7 @@ from spanner_orm.admin import table -class SpannerMetadata(object): +class SpannerMetadata: """Gathers information about a table from Spanner.""" @classmethod diff --git a/spanner_orm/admin/migration.py b/spanner_orm/admin/migration.py index 29d359e..7f63b06 100644 --- a/spanner_orm/admin/migration.py +++ b/spanner_orm/admin/migration.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/admin/migration_executor.py b/spanner_orm/admin/migration_executor.py index 0d90a8c..da5d51a 100644 --- a/spanner_orm/admin/migration_executor.py +++ b/spanner_orm/admin/migration_executor.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/admin/migration_manager.py b/spanner_orm/admin/migration_manager.py index d9da5c3..3d883bd 100644 --- a/spanner_orm/admin/migration_manager.py +++ b/spanner_orm/admin/migration_manager.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/admin/migration_status.py b/spanner_orm/admin/migration_status.py index cc9eb7b..f128739 100644 --- a/spanner_orm/admin/migration_status.py +++ b/spanner_orm/admin/migration_status.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/admin/schema.py b/spanner_orm/admin/schema.py index c39d553..5a5a1d8 100644 --- a/spanner_orm/admin/schema.py +++ b/spanner_orm/admin/schema.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/admin/scripts.py b/spanner_orm/admin/scripts.py index 6a1131a..2b9857f 100644 --- a/spanner_orm/admin/scripts.py +++ b/spanner_orm/admin/scripts.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/admin/table.py b/spanner_orm/admin/table.py index c07a1df..ac98e40 100644 --- a/spanner_orm/admin/table.py +++ b/spanner_orm/admin/table.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/admin/update.py b/spanner_orm/admin/update.py index 444e15d..6599011 100644 --- a/spanner_orm/admin/update.py +++ b/spanner_orm/admin/update.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/api.py b/spanner_orm/api.py index 95f221d..d049617 100644 --- a/spanner_orm/api.py +++ b/spanner_orm/api.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/condition.py b/spanner_orm/condition.py index b1e5a09..a699169 100644 --- a/spanner_orm/condition.py +++ b/spanner_orm/condition.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/decorator.py b/spanner_orm/decorator.py index a2dd7bc..5e3b378 100644 --- a/spanner_orm/decorator.py +++ b/spanner_orm/decorator.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/error.py b/spanner_orm/error.py index e1fe4d8..62ce862 100644 --- a/spanner_orm/error.py +++ b/spanner_orm/error.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/field.py b/spanner_orm/field.py index ebc12fe..f9dfbfa 100644 --- a/spanner_orm/field.py +++ b/spanner_orm/field.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -44,7 +43,7 @@ def validate_type(value: Any) -> None: raise NotImplementedError -class Field(object): +class Field: """Represents a column in a table as a field in a model.""" def __init__(self, diff --git a/spanner_orm/foreign_key_relationship.py b/spanner_orm/foreign_key_relationship.py index 8e1cb8d..9586a2b 100644 --- a/spanner_orm/foreign_key_relationship.py +++ b/spanner_orm/foreign_key_relationship.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -27,7 +26,7 @@ class ForeignKeyRelationshipConstraint: referenced_table: Type[Any] -class ForeignKeyRelationship(object): +class ForeignKeyRelationship: """Helps define a foreign key relationship between two models.""" def __init__(self, referenced_table_name: str, columns: Mapping[str, str]): diff --git a/spanner_orm/index.py b/spanner_orm/index.py index 1a1ac35..7702b3a 100644 --- a/spanner_orm/index.py +++ b/spanner_orm/index.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -19,7 +18,7 @@ from spanner_orm import error -class Index(object): +class Index: """Represents an index on a Model.""" PRIMARY_INDEX = 'PRIMARY_KEY' diff --git a/spanner_orm/metadata.py b/spanner_orm/metadata.py index 52075ff..794a7c6 100644 --- a/spanner_orm/metadata.py +++ b/spanner_orm/metadata.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -38,7 +37,7 @@ from spanner_orm import relationship -class ModelMetadata(object): +class ModelMetadata: """Hold information about a Model extracted from the class attributes.""" def __init__(self, diff --git a/spanner_orm/model.py b/spanner_orm/model.py index bd9be61..2269fe7 100644 --- a/spanner_orm/model.py +++ b/spanner_orm/model.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/query.py b/spanner_orm/query.py index 8b4273c..85a35aa 100644 --- a/spanner_orm/query.py +++ b/spanner_orm/query.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/registry.py b/spanner_orm/registry.py index 1e78894..a1544a3 100644 --- a/spanner_orm/registry.py +++ b/spanner_orm/registry.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,7 +27,7 @@ def add(self, reference: Type[Any]) -> None: self.references.append(reference) -class Registry(object): +class Registry: def __init__(self): self._registered = {} # type: Dict[str, RegistryComponent] diff --git a/spanner_orm/relationship.py b/spanner_orm/relationship.py index d4947ef..ca3d1c2 100644 --- a/spanner_orm/relationship.py +++ b/spanner_orm/relationship.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -29,7 +28,7 @@ class RelationshipConstraint: origin_column: str -class Relationship(object): +class Relationship: """Helps define a foreign key relationship between two models.""" def __init__(self, diff --git a/spanner_orm/table_apis.py b/spanner_orm/table_apis.py index fe808b8..9c87223 100644 --- a/spanner_orm/table_apis.py +++ b/spanner_orm/table_apis.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/testlib/spanner_emulator/emulator.py b/spanner_orm/testlib/spanner_emulator/emulator.py index f624920..2cafa8a 100644 --- a/spanner_orm/testlib/spanner_emulator/emulator.py +++ b/spanner_orm/testlib/spanner_emulator/emulator.py @@ -1,4 +1,3 @@ -# Lint as: python3 # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/testlib/spanner_emulator/testlib.py b/spanner_orm/testlib/spanner_emulator/testlib.py index a52a559..1f7ca10 100644 --- a/spanner_orm/testlib/spanner_emulator/testlib.py +++ b/spanner_orm/testlib/spanner_emulator/testlib.py @@ -1,4 +1,3 @@ -# Lint as: python3 # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/tests/admin_test.py b/spanner_orm/tests/admin_test.py index 9700634..22e890e 100644 --- a/spanner_orm/tests/admin_test.py +++ b/spanner_orm/tests/admin_test.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/tests/api_test.py b/spanner_orm/tests/api_test.py index bbcb296..bf2d9a1 100644 --- a/spanner_orm/tests/api_test.py +++ b/spanner_orm/tests/api_test.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/tests/condition_test.py b/spanner_orm/tests/condition_test.py index c892448..5fdd5bb 100644 --- a/spanner_orm/tests/condition_test.py +++ b/spanner_orm/tests/condition_test.py @@ -1,4 +1,3 @@ -# Lint as: python3 # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/tests/decorator_test.py b/spanner_orm/tests/decorator_test.py index 93d84dc..d41cee7 100644 --- a/spanner_orm/tests/decorator_test.py +++ b/spanner_orm/tests/decorator_test.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/tests/metadata_test.py b/spanner_orm/tests/metadata_test.py index 16b43d7..cd1f66d 100644 --- a/spanner_orm/tests/metadata_test.py +++ b/spanner_orm/tests/metadata_test.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/tests/migrations/test_1_4a7a7dee0718.py b/spanner_orm/tests/migrations/test_1_4a7a7dee0718.py index cc39fea..29078a8 100644 --- a/spanner_orm/tests/migrations/test_1_4a7a7dee0718.py +++ b/spanner_orm/tests/migrations/test_1_4a7a7dee0718.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/tests/migrations/test_2_5c078bbb4d43.py b/spanner_orm/tests/migrations/test_2_5c078bbb4d43.py index 6a347b0..31f6a4d 100644 --- a/spanner_orm/tests/migrations/test_2_5c078bbb4d43.py +++ b/spanner_orm/tests/migrations/test_2_5c078bbb4d43.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/tests/migrations/test_3_eceb25f170dd.py b/spanner_orm/tests/migrations/test_3_eceb25f170dd.py index 44f7152..8510398 100644 --- a/spanner_orm/tests/migrations/test_3_eceb25f170dd.py +++ b/spanner_orm/tests/migrations/test_3_eceb25f170dd.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/tests/migrations_emulator_test.py b/spanner_orm/tests/migrations_emulator_test.py index 578a67c..4f936bb 100644 --- a/spanner_orm/tests/migrations_emulator_test.py +++ b/spanner_orm/tests/migrations_emulator_test.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py b/spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py index 7e4c291..bdd3935 100644 --- a/spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py +++ b/spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py @@ -1,4 +1,3 @@ -# Lint as: python3 # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_small_test_model.py b/spanner_orm/tests/migrations_for_emulator_test/create_small_test_model.py index 0902bdc..c4b019b 100644 --- a/spanner_orm/tests/migrations_for_emulator_test/create_small_test_model.py +++ b/spanner_orm/tests/migrations_for_emulator_test/create_small_test_model.py @@ -1,4 +1,3 @@ -# Lint as: python3 # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py b/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py index db8387b..15113ad 100644 --- a/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py +++ b/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py @@ -1,4 +1,3 @@ -# Lint as: python3 # Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/tests/migrations_test.py b/spanner_orm/tests/migrations_test.py index 446cce4..e596cc4 100644 --- a/spanner_orm/tests/migrations_test.py +++ b/spanner_orm/tests/migrations_test.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/tests/model_test.py b/spanner_orm/tests/model_test.py index 4275d29..560d9c1 100644 --- a/spanner_orm/tests/model_test.py +++ b/spanner_orm/tests/model_test.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/tests/models.py b/spanner_orm/tests/models.py index 697a582..896b450 100644 --- a/spanner_orm/tests/models.py +++ b/spanner_orm/tests/models.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/tests/query_test.py b/spanner_orm/tests/query_test.py index 2644dba..97a33b9 100644 --- a/spanner_orm/tests/query_test.py +++ b/spanner_orm/tests/query_test.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/tests/update_test.py b/spanner_orm/tests/update_test.py index b6ced0a..736c451 100644 --- a/spanner_orm/tests/update_test.py +++ b/spanner_orm/tests/update_test.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); From 7cb2a60573683cab8a78dc8fee548bf86d35b13b Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Tue, 20 Sep 2022 13:56:18 -0400 Subject: [PATCH 106/131] Add missing docstrings and type annotations --- spanner_orm/field.py | 52 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 47 insertions(+), 5 deletions(-) diff --git a/spanner_orm/field.py b/spanner_orm/field.py index f9dfbfa..4db7129 100644 --- a/spanner_orm/field.py +++ b/spanner_orm/field.py @@ -17,7 +17,7 @@ import base64 import binascii import datetime -from typing import Any, Type +from typing import Any, Optional, Type from google.cloud import spanner from google.cloud import spanner_v1 @@ -30,49 +30,70 @@ class FieldType(abc.ABC): @staticmethod @abc.abstractmethod def ddl() -> str: + """Returns the DDL for this type.""" raise NotImplementedError @staticmethod @abc.abstractmethod def grpc_type() -> spanner_v1.Type: + """Returns the type as used in Cloud Spanner's gRPC API.""" raise NotImplementedError @staticmethod @abc.abstractmethod def validate_type(value: Any) -> None: + """Raises error.ValidationError if value doesn't match the type.""" raise NotImplementedError class Field: - """Represents a column in a table as a field in a model.""" + """Represents a column in a table as a field in a model. + + Attributes: + name: Name of the column, or None if this hasn't been bound to a column yet. + """ + name: Optional[str] def __init__(self, field_type: Type[FieldType], nullable: bool = False, primary_key: bool = False): + """Initializer. + + Args: + field_type: Type of the field. + nullable: Whether the field can be NULL. + primary_key: Whether the field is part of the table's primary key. + """ self.name = None self._type = field_type self._nullable = nullable self._primary_key = primary_key def ddl(self) -> str: + """Returns DDL for the column.""" if self._nullable: return self._type.ddl() return '{field_type} NOT NULL'.format(field_type=self._type.ddl()) def field_type(self) -> Type[FieldType]: + """Returns the type of the field.""" return self._type def grpc_type(self) -> str: + """Returns the type as used in Cloud Spanner's gRPC API.""" return self._type.grpc_type() def nullable(self) -> bool: + """Returns whether the field can be NULL.""" return self._nullable def primary_key(self) -> bool: + """Returns whether the field is part of the table's primary key.""" return self._primary_key - def validate(self, value) -> None: + def validate(self, value: Any) -> None: + """Raises error.ValidationError if value isn't compatible with the field.""" if value is None: if not self._nullable: raise error.ValidationError('None set for non-nullable field') @@ -85,14 +106,17 @@ class Boolean(FieldType): @staticmethod def ddl() -> str: + """See base class.""" return 'BOOL' @staticmethod def grpc_type() -> spanner_v1.Type: + """See base class.""" return spanner.param_types.BOOL @staticmethod def validate_type(value: Any) -> None: + """See base class.""" if not isinstance(value, bool): raise error.ValidationError('{} is not of type bool'.format(value)) @@ -102,14 +126,17 @@ class Integer(FieldType): @staticmethod def ddl() -> str: + """See base class.""" return 'INT64' @staticmethod def grpc_type() -> spanner_v1.Type: + """See base class.""" return spanner.param_types.INT64 @staticmethod def validate_type(value: Any) -> None: + """See base class.""" if not isinstance(value, int): raise error.ValidationError('{} is not of type int'.format(value)) @@ -119,14 +146,17 @@ class Float(FieldType): @staticmethod def ddl() -> str: + """See base class.""" return 'FLOAT64' @staticmethod def grpc_type() -> spanner_v1.Type: + """See base class.""" return spanner.param_types.FLOAT64 @staticmethod def validate_type(value: Any) -> None: + """See base class.""" if not isinstance(value, (int, float)): raise error.ValidationError('{} is not of type float'.format(value)) @@ -136,14 +166,17 @@ class String(FieldType): @staticmethod def ddl() -> str: + """See base class.""" return 'STRING(MAX)' @staticmethod def grpc_type() -> spanner_v1.Type: + """See base class.""" return spanner.param_types.STRING @staticmethod - def validate_type(value) -> None: + def validate_type(value: Any) -> None: + """See base class.""" if not isinstance(value, str): raise error.ValidationError('{} is not of type str'.format(value)) @@ -153,14 +186,17 @@ class StringArray(FieldType): @staticmethod def ddl() -> str: + """See base class.""" return 'ARRAY' @staticmethod def grpc_type() -> spanner_v1.Type: + """See base class.""" return spanner.param_types.Array(spanner.param_types.STRING) @staticmethod def validate_type(value: Any) -> None: + """See base class.""" if not isinstance(value, list): raise error.ValidationError('{} is not of type list'.format(value)) for item in value: @@ -173,14 +209,17 @@ class Timestamp(FieldType): @staticmethod def ddl() -> str: + """See base class.""" return 'TIMESTAMP' @staticmethod def grpc_type() -> spanner_v1.Type: + """See base class.""" return spanner.param_types.TIMESTAMP @staticmethod def validate_type(value: Any) -> None: + """See base class.""" if not isinstance(value, datetime.datetime): raise error.ValidationError('{} is not of type datetime'.format(value)) @@ -190,14 +229,17 @@ class BytesBase64(FieldType): @staticmethod def ddl() -> str: + """See base class.""" return 'BYTES(MAX)' @staticmethod def grpc_type() -> spanner_v1.Type: + """See base class.""" return spanner.param_types.BYTES @staticmethod - def validate_type(value) -> None: + def validate_type(value: Any) -> None: + """See base class.""" if not isinstance(value, bytes): raise error.ValidationError('{} is not of type bytes'.format(value)) # Rudimentary test to check for base64 encoding. From 6484b22aa0b73c3ec3d1674cae5e7cfbbe9a46f4 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Tue, 20 Sep 2022 14:07:23 -0400 Subject: [PATCH 107/131] Make some minor things less error-prone 1. Boolean flag-like positional arguments are unreadable and error-prone when parameters are added or removed, so make them keyword only. 2. Mutable global "constants" can be modified by accident. --- spanner_orm/field.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/spanner_orm/field.py b/spanner_orm/field.py index 4db7129..ac8481d 100644 --- a/spanner_orm/field.py +++ b/spanner_orm/field.py @@ -54,10 +54,13 @@ class Field: """ name: Optional[str] - def __init__(self, - field_type: Type[FieldType], - nullable: bool = False, - primary_key: bool = False): + def __init__( + self, + field_type: Type[FieldType], + *, + nullable: bool = False, + primary_key: bool = False, + ): """Initializer. Args: @@ -250,6 +253,12 @@ def validate_type(value: Any) -> None: '{} must be base64-encoded bytes.'.format(value)) -ALL_TYPES = [ - Boolean, Integer, Float, String, StringArray, Timestamp, BytesBase64 -] +ALL_TYPES = ( + Boolean, + Integer, + Float, + String, + StringArray, + Timestamp, + BytesBase64, +) From 1e9d594032a3a0dda8da219295ca93f67d901c47 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Tue, 20 Sep 2022 14:39:52 -0400 Subject: [PATCH 108/131] Improve string formatting 1. Use f-strings. 2. Use repr() conversions (`!r`) to more clearly separate the value causing an error from the error message and to make types more apparent. E.g., f'{value} is not an int' would be really confusing if value is the string '1'. --- spanner_orm/field.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/spanner_orm/field.py b/spanner_orm/field.py index ac8481d..75051a1 100644 --- a/spanner_orm/field.py +++ b/spanner_orm/field.py @@ -77,7 +77,7 @@ def ddl(self) -> str: """Returns DDL for the column.""" if self._nullable: return self._type.ddl() - return '{field_type} NOT NULL'.format(field_type=self._type.ddl()) + return f'{self._type.ddl()} NOT NULL' def field_type(self) -> Type[FieldType]: """Returns the type of the field.""" @@ -121,7 +121,7 @@ def grpc_type() -> spanner_v1.Type: def validate_type(value: Any) -> None: """See base class.""" if not isinstance(value, bool): - raise error.ValidationError('{} is not of type bool'.format(value)) + raise error.ValidationError(f'{value!r} is not of type bool') class Integer(FieldType): @@ -141,7 +141,7 @@ def grpc_type() -> spanner_v1.Type: def validate_type(value: Any) -> None: """See base class.""" if not isinstance(value, int): - raise error.ValidationError('{} is not of type int'.format(value)) + raise error.ValidationError(f'{value!r} is not of type int') class Float(FieldType): @@ -161,7 +161,7 @@ def grpc_type() -> spanner_v1.Type: def validate_type(value: Any) -> None: """See base class.""" if not isinstance(value, (int, float)): - raise error.ValidationError('{} is not of type float'.format(value)) + raise error.ValidationError(f'{value!r} is not of type float') class String(FieldType): @@ -181,7 +181,7 @@ def grpc_type() -> spanner_v1.Type: def validate_type(value: Any) -> None: """See base class.""" if not isinstance(value, str): - raise error.ValidationError('{} is not of type str'.format(value)) + raise error.ValidationError(f'{value!r} is not of type str') class StringArray(FieldType): @@ -201,10 +201,10 @@ def grpc_type() -> spanner_v1.Type: def validate_type(value: Any) -> None: """See base class.""" if not isinstance(value, list): - raise error.ValidationError('{} is not of type list'.format(value)) + raise error.ValidationError(f'{value!r} is not of type list') for item in value: if not isinstance(item, str): - raise error.ValidationError('{} is not of type str'.format(item)) + raise error.ValidationError(f'{item!r} is not of type str') class Timestamp(FieldType): @@ -224,7 +224,7 @@ def grpc_type() -> spanner_v1.Type: def validate_type(value: Any) -> None: """See base class.""" if not isinstance(value, datetime.datetime): - raise error.ValidationError('{} is not of type datetime'.format(value)) + raise error.ValidationError(f'{value!r} is not of type datetime') class BytesBase64(FieldType): @@ -244,13 +244,12 @@ def grpc_type() -> spanner_v1.Type: def validate_type(value: Any) -> None: """See base class.""" if not isinstance(value, bytes): - raise error.ValidationError('{} is not of type bytes'.format(value)) + raise error.ValidationError(f'{value!r} is not of type bytes') # Rudimentary test to check for base64 encoding. try: base64.b64decode(value, altchars=None, validate=True) except binascii.Error: - raise error.ValidationError( - '{} must be base64-encoded bytes.'.format(value)) + raise error.ValidationError(f'{value!r} must be base64-encoded bytes.') ALL_TYPES = ( From efde90ae69c0f5c04a4837c4359efbba3c7d24ff Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Tue, 20 Sep 2022 17:21:21 -0400 Subject: [PATCH 109/131] Unpin pytype The import-error bug is annoying, but 1) it would be nice to get updated pytype changes again and 2) this is holding us back from supporting python 3.10. --- .github/workflows/test.yaml | 4 +--- spanner_orm/admin/migration_manager.py | 3 ++- spanner_orm/api.py | 3 ++- spanner_orm/condition.py | 5 +++-- spanner_orm/field.py | 5 +++-- spanner_orm/model.py | 3 ++- spanner_orm/table_apis.py | 10 +++++++--- spanner_orm/tests/api_test.py | 3 ++- spanner_orm/tests/condition_test.py | 15 ++++++++++++++- spanner_orm/tests/model_test.py | 3 ++- spanner_orm/tests/query_test.py | 6 ++++-- 11 files changed, 42 insertions(+), 18 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index eee9708..cf57672 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -53,9 +53,7 @@ jobs: yapf --diff --recursive --parallel . - name: Check types run: | - # TODO(https://github.com/google/pytype/issues/1081): Remove the version - # pin. - pip install pytype==2021.11.29 + pip install pytype pytype --jobs=auto --keep-going spanner_orm - name: Test env: diff --git a/spanner_orm/admin/migration_manager.py b/spanner_orm/admin/migration_manager.py index 3d883bd..6373f20 100644 --- a/spanner_orm/admin/migration_manager.py +++ b/spanner_orm/admin/migration_manager.py @@ -74,7 +74,8 @@ def _migration_from_file(self, filename: str) -> migration.Migration: path = os.path.join(self.basedir, filename) module = importlib.util.module_from_spec( importlib.util.spec_from_file_location(module_name, path)) - importlib.machinery.SourceFileLoader(module_name, path).exec_module(module) + # TODO(https://github.com/google/pytype/issues/1289): Re-enable pyi-error. + importlib.machinery.SourceFileLoader(module_name, path).exec_module(module) # pytype: disable=pyi-error try: result = migration.Migration(module.migration_id, module.prev_migration_id, diff --git a/spanner_orm/api.py b/spanner_orm/api.py index d049617..1c3b987 100644 --- a/spanner_orm/api.py +++ b/spanner_orm/api.py @@ -17,10 +17,11 @@ from typing import Any, Callable, Dict, Iterable, Optional, TypeVar, Union import warnings +# TODO(https://github.com/google/pytype/issues/1081): Re-enable import-error. from google.api_core import client_options as api_client_options from google.api_core import exceptions from google.auth import credentials as auth_credentials -from google.cloud import spanner +from google.cloud import spanner # pytype: disable=import-error from google.cloud.spanner_v1 import database as spanner_database from google.cloud.spanner_v1 import pool as spanner_pool from spanner_orm import error diff --git a/spanner_orm/condition.py b/spanner_orm/condition.py index a699169..5cd2e90 100644 --- a/spanner_orm/condition.py +++ b/spanner_orm/condition.py @@ -28,9 +28,10 @@ from spanner_orm import index from spanner_orm import relationship +# TODO(https://github.com/google/pytype/issues/1081): Re-enable import-error. from google.api_core import datetime_helpers -from google.cloud import spanner -from google.cloud import spanner_v1 +from google.cloud import spanner # pytype: disable=import-error +from google.cloud import spanner_v1 # pytype: disable=import-error import immutabledict T = TypeVar('T') diff --git a/spanner_orm/field.py b/spanner_orm/field.py index 75051a1..50025d9 100644 --- a/spanner_orm/field.py +++ b/spanner_orm/field.py @@ -19,8 +19,9 @@ import datetime from typing import Any, Optional, Type -from google.cloud import spanner -from google.cloud import spanner_v1 +# TODO(https://github.com/google/pytype/issues/1081): Re-enable import-error. +from google.cloud import spanner # pytype: disable=import-error +from google.cloud import spanner_v1 # pytype: disable=import-error from spanner_orm import error diff --git a/spanner_orm/model.py b/spanner_orm/model.py index 2269fe7..f05f6e0 100644 --- a/spanner_orm/model.py +++ b/spanner_orm/model.py @@ -29,8 +29,9 @@ from spanner_orm import relationship from spanner_orm import table_apis +# TODO(https://github.com/google/pytype/issues/1081): Re-enable import-error. from google.api_core import exceptions -from google.cloud import spanner +from google.cloud import spanner # pytype: disable=import-error from google.cloud.spanner_v1 import transaction as spanner_transaction T = TypeVar('T') diff --git a/spanner_orm/table_apis.py b/spanner_orm/table_apis.py index 9c87223..9315e81 100644 --- a/spanner_orm/table_apis.py +++ b/spanner_orm/table_apis.py @@ -16,7 +16,8 @@ import logging from typing import Any, Dict, Iterable, List, Sequence -from google.cloud import spanner +# TODO(https://github.com/google/pytype/issues/1081): Re-enable import-error. +from google.cloud import spanner # pytype: disable=import-error from google.cloud import spanner_v1 from google.cloud.spanner_v1 import transaction as spanner_transaction @@ -49,9 +50,12 @@ def find(transaction: spanner_transaction.Transaction, table_name: str, def sql_query( - transaction: spanner_transaction.Transaction, query: str, + transaction: spanner_transaction.Transaction, + query: str, parameters: Dict[str, Any], - parameter_types: Dict[str, spanner_v1.Type]) -> List[Sequence[Any]]: + # TODO(https://github.com/google/pytype/issues/1287): Re-enable module-attr. + parameter_types: Dict[str, spanner_v1.Type], # pytype: disable=module-attr +) -> List[Sequence[Any]]: """Executes a given SQL query against the Spanner database. This isn't technically read-only, but it's necessary to implement the read- diff --git a/spanner_orm/tests/api_test.py b/spanner_orm/tests/api_test.py index bf2d9a1..a2e682c 100644 --- a/spanner_orm/tests/api_test.py +++ b/spanner_orm/tests/api_test.py @@ -16,9 +16,10 @@ from unittest import mock import warnings +# TODO(https://github.com/google/pytype/issues/1081): Re-enable import-error. from absl.testing import parameterized from google.api_core import exceptions -from google.cloud import spanner +from google.cloud import spanner # pytype: disable=import-error from spanner_orm import api from spanner_orm import error diff --git a/spanner_orm/tests/condition_test.py b/spanner_orm/tests/condition_test.py index 5fdd5bb..bccbbcc 100644 --- a/spanner_orm/tests/condition_test.py +++ b/spanner_orm/tests/condition_test.py @@ -19,9 +19,10 @@ import os import unittest +# TODO(https://github.com/google/pytype/issues/1081): Re-enable import-error. from absl.testing import parameterized from google.api_core import datetime_helpers -from google.cloud import spanner +from google.cloud import spanner # pytype: disable=import-error from google.cloud import spanner_v1 import spanner_orm @@ -45,6 +46,9 @@ def setUp(self): )) @parameterized.parameters( + # TODO(https://github.com/google/pytype/issues/1287): Re-enable + # module-attr. + # pytype: disable=module-attr (True, spanner_v1.param_types.BOOL), (0, spanner_v1.param_types.INT64), (0.0, spanner_v1.param_types.FLOAT64), @@ -71,6 +75,7 @@ def setUp(self): array_element_type=spanner_v1.param_types.STRING, ), ), + # pytype: enable=module-attr ) def test_param_from_value(self, value, expected_type): param = condition.Param.from_value(value) @@ -132,6 +137,9 @@ def test_param_from_value_correctly_encodes(self, tautology): self.assertCountEqual((test_model,), models.SmallTestModel.where(tautology)) @parameterized.named_parameters( + # TODO(https://github.com/google/pytype/issues/1287): Re-enable + # module-attr. + # pytype: disable=module-attr ( 'minimal', condition.ArbitraryCondition( @@ -167,6 +175,7 @@ def test_param_from_value_correctly_encodes(self, tautology): 'IF(@true_param0, @key_param0, SmallTestModel.value_1)'), ('some-key',), ), + # pytype: enable=module-attr ) def test_arbitrary_condition( self, @@ -227,6 +236,9 @@ def test_arbitrary_condition_validation_error( models.SmallTestModel.where(condition_) @parameterized.named_parameters( + # TODO(https://github.com/google/pytype/issues/1287): Re-enable + # module-attr. + # pytype: disable=module-attr ( 'empty_or', condition.OrCondition(), @@ -283,6 +295,7 @@ def test_arbitrary_condition_validation_error( ')'), 'ab', ), + # pytype: enable=module-attr ) def test_or_condition( self, diff --git a/spanner_orm/tests/model_test.py b/spanner_orm/tests/model_test.py index 560d9c1..07f5c24 100644 --- a/spanner_orm/tests/model_test.py +++ b/spanner_orm/tests/model_test.py @@ -19,9 +19,10 @@ import unittest from unittest import mock +# TODO(https://github.com/google/pytype/issues/1081): Re-enable import-error. from absl.testing import parameterized from google.api_core import exceptions -from google.cloud import spanner +from google.cloud import spanner # pytype: disable=import-error from spanner_orm import error from spanner_orm import field from spanner_orm.testlib.spanner_emulator import testlib as spanner_emulator_testlib diff --git a/spanner_orm/tests/query_test.py b/spanner_orm/tests/query_test.py index 97a33b9..46f3745 100644 --- a/spanner_orm/tests/query_test.py +++ b/spanner_orm/tests/query_test.py @@ -177,8 +177,10 @@ def test_query_where_list_comparison(self, column, values, grpc_type): column_key = '{}0'.format(column) expected_sql = ' WHERE table.{} {} UNNEST(@{})'.format( column, current_condition.operator, column_key) - list_type = spanner_v1.Type( - code=spanner_v1.TypeCode.ARRAY, array_element_type=grpc_type) + # TODO(https://github.com/google/pytype/issues/1287): Re-enable module-attr. + list_type = spanner_v1.Type( # pytype: disable=module-attr + code=spanner_v1.TypeCode.ARRAY, # pytype: disable=module-attr + array_element_type=grpc_type) self.assertEndsWith(select_query.sql(), expected_sql) self.assertEqual(select_query.parameters(), {column_key: values}) self.assertEqual(select_query.types(), {column_key: list_type}) From 2bc4e2ccbd7f6c54e68efe9674ebac600a1eeeec Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Tue, 20 Sep 2022 17:12:06 -0400 Subject: [PATCH 110/131] Bump python versions --- .github/workflows/test.yaml | 3 +-- README.md | 6 +++--- setup.py | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index cf57672..b1a04dc 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -22,9 +22,8 @@ jobs: strategy: matrix: python-version: - - '3.7' - - '3.8' - '3.9' + - '3.10' runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 diff --git a/README.md b/README.md index a9afd7c..580a8ed 100644 --- a/README.md +++ b/README.md @@ -9,8 +9,8 @@ This is not an officially supported Google product. ### How to install -Make sure that Python 3.7 is the default version of python for your environment, -then run: +Make sure that Python 3.9 or higher is the default version of python for your +environment, then run: ```pip install git+https://github.com/google/python-spanner-orm#egg=spanner_orm``` ### Connecting @@ -183,7 +183,7 @@ or the corresponding ```MigrationExecutor``` method should be used. ## Tests -Note: we suggest using a Python 3.7 +Note: we suggest using a Python 3.9 [virtualenv](https://docs.python.org/3/library/venv.html) for running tests and type checking. diff --git a/setup.py b/setup.py index 9ec386e..f1813da 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ url='https://github.com/google/python-spanner-orm', packages=['spanner_orm', 'spanner_orm.admin'], include_package_data=True, - python_requires='~=3.7', + python_requires='~=3.9', install_requires=[ 'google-cloud-spanner >= 2, <4', 'immutabledict', From 3fe10f01fe9e2a76bfc86c4c24680c30aa992288 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Tue, 20 Sep 2022 14:36:34 -0400 Subject: [PATCH 111/131] Use instances of FieldType instead of classes Aside from , this makes it possible to have types with parameters. E.g., this makes it possible to add code for something like Array(String(length=20)) for the type ARRAY. That way we don't need to define a new FooArray class for each type of array, and we can add support for lengths other than MAX more easily. This could break any users that use FieldType for anything other than passing to Field, but hopefully the vast majority of uses just pass it to Field. --- README.md | 6 +- spanner_orm/admin/column.py | 9 +- spanner_orm/admin/update.py | 2 +- spanner_orm/condition.py | 2 +- spanner_orm/field.py | 143 ++++++++++++++++++-------------- spanner_orm/tests/admin_test.py | 4 +- spanner_orm/tests/field_test.py | 67 +++++++++++++++ spanner_orm/tests/model_test.py | 3 +- spanner_orm/tests/query_test.py | 39 ++++----- 9 files changed, 179 insertions(+), 96 deletions(-) create mode 100644 spanner_orm/tests/field_test.py diff --git a/README.md b/README.md index 580a8ed..23b36a9 100644 --- a/README.md +++ b/README.md @@ -45,9 +45,9 @@ class TestModel(spanner_orm.Model): # is the type of field. The primary key is constructed by the fields labeled # with primary_key=True in the order they appear in the class. # The name of the column is the same as the name of the class attribute - id = spanner_orm.Field(spanner_orm.String, primary_key=True) - value = spanner_orm.Field(spanner_orm.Integer, nullable=True) - number = spanner_orm.Field(spanner_orm.Float, nullable=True) + id = spanner_orm.Field(spanner_orm.String(), primary_key=True) + value = spanner_orm.Field(spanner_orm.Integer(), nullable=True) + number = spanner_orm.Field(spanner_orm.Float(), nullable=True) # Secondary indexes are specified in a similar manner to fields: value_index = spanner_orm.Index(['value']) diff --git a/spanner_orm/admin/column.py b/spanner_orm/admin/column.py index b40a3a3..60f5017 100644 --- a/spanner_orm/admin/column.py +++ b/spanner_orm/admin/column.py @@ -36,10 +36,5 @@ class ColumnSchema(schema.InformationSchema): def nullable(self) -> bool: return self.is_nullable == 'YES' - def field_type(self) -> Type[field.FieldType]: - for field_type in field.ALL_TYPES: - if self.spanner_type == field_type.ddl(): - return field_type - - raise error.SpannerError('No corresponding Type for {}'.format( - self.spanner_type)) + def field_type(self) -> field.FieldType: + return field.field_type_from_ddl(self.spanner_type) diff --git a/spanner_orm/admin/update.py b/spanner_orm/admin/update.py index 6599011..f716e08 100644 --- a/spanner_orm/admin/update.py +++ b/spanner_orm/admin/update.py @@ -220,7 +220,7 @@ def validate(self) -> None: old_field = model_.fields[self._column] # Validate that the only alteration is to change column nullability - if self._field.field_type() != old_field.field_type(): + if self._field.field_type().ddl() != old_field.field_type().ddl(): raise error.SpannerError('Column {} is changing type'.format( self._column)) if self._field.nullable() == old_field.nullable(): diff --git a/spanner_orm/condition.py b/spanner_orm/condition.py index 5cd2e90..7bd8fc1 100644 --- a/spanner_orm/condition.py +++ b/spanner_orm/condition.py @@ -355,7 +355,7 @@ def _validate(self, model_class: Type[Any]) -> None: self.destination_column, self.destination_model_class.table)) dest = self.destination_model_class.fields[self.destination_column] - if (origin.field_type() != dest.field_type() or + if (not origin.field_type().comparable_with(dest.field_type()) or origin.nullable() != dest.nullable()): raise error.ValidationError('Types of {} and {} do not match'.format( origin.name, dest.name)) diff --git a/spanner_orm/field.py b/spanner_orm/field.py index 50025d9..d010050 100644 --- a/spanner_orm/field.py +++ b/spanner_orm/field.py @@ -17,7 +17,8 @@ import base64 import binascii import datetime -from typing import Any, Optional, Type +from typing import Any, Optional, Type, Union +import warnings # TODO(https://github.com/google/pytype/issues/1081): Re-enable import-error. from google.cloud import spanner # pytype: disable=import-error @@ -28,24 +29,26 @@ class FieldType(abc.ABC): """Base class for column types for Spanner interactions.""" - @staticmethod @abc.abstractmethod - def ddl() -> str: + def ddl(self) -> str: """Returns the DDL for this type.""" raise NotImplementedError - @staticmethod @abc.abstractmethod - def grpc_type() -> spanner_v1.Type: + def grpc_type(self) -> spanner_v1.Type: """Returns the type as used in Cloud Spanner's gRPC API.""" raise NotImplementedError - @staticmethod @abc.abstractmethod - def validate_type(value: Any) -> None: + def validate_type(self, value: Any) -> None: """Raises error.ValidationError if value doesn't match the type.""" raise NotImplementedError + def comparable_with(self, other: 'FieldType') -> bool: + """Returns whether two types are comparable.""" + # https://cloud.google.com/spanner/docs/reference/standard-sql/data-types#comparable_data_types + return type(self) == type(other) + class Field: """Represents a column in a table as a field in a model. @@ -57,7 +60,7 @@ class Field: def __init__( self, - field_type: Type[FieldType], + field_type: Union[FieldType, Type[FieldType]], *, nullable: bool = False, primary_key: bool = False, @@ -65,12 +68,19 @@ def __init__( """Initializer. Args: - field_type: Type of the field. + field_type: Type of the field. Passing a class instead of an instance of + that class is deprecated. nullable: Whether the field can be NULL. primary_key: Whether the field is part of the table's primary key. """ self.name = None - self._type = field_type + if isinstance(field_type, FieldType): + self._type = field_type + else: + warnings.warn( + DeprecationWarning( + 'Pass an instance of FieldType instead of a class.')) + self._type = field_type() self._nullable = nullable self._primary_key = primary_key @@ -80,7 +90,7 @@ def ddl(self) -> str: return self._type.ddl() return f'{self._type.ddl()} NOT NULL' - def field_type(self) -> Type[FieldType]: + def field_type(self) -> FieldType: """Returns the type of the field.""" return self._type @@ -108,19 +118,19 @@ def validate(self, value: Any) -> None: class Boolean(FieldType): """Represents a boolean type.""" - @staticmethod - def ddl() -> str: + def ddl(self) -> str: """See base class.""" + del self # Unused. return 'BOOL' - @staticmethod - def grpc_type() -> spanner_v1.Type: + def grpc_type(self) -> spanner_v1.Type: """See base class.""" + del self # Unused. return spanner.param_types.BOOL - @staticmethod - def validate_type(value: Any) -> None: + def validate_type(self, value: Any) -> None: """See base class.""" + del self # Unused. if not isinstance(value, bool): raise error.ValidationError(f'{value!r} is not of type bool') @@ -128,19 +138,19 @@ def validate_type(value: Any) -> None: class Integer(FieldType): """Represents an integer type.""" - @staticmethod - def ddl() -> str: + def ddl(self) -> str: """See base class.""" + del self # Unused. return 'INT64' - @staticmethod - def grpc_type() -> spanner_v1.Type: + def grpc_type(self) -> spanner_v1.Type: """See base class.""" + del self # Unused. return spanner.param_types.INT64 - @staticmethod - def validate_type(value: Any) -> None: + def validate_type(self, value: Any) -> None: """See base class.""" + del self # Unused. if not isinstance(value, int): raise error.ValidationError(f'{value!r} is not of type int') @@ -148,19 +158,19 @@ def validate_type(value: Any) -> None: class Float(FieldType): """Represents a float type.""" - @staticmethod - def ddl() -> str: + def ddl(self) -> str: """See base class.""" + del self # Unused. return 'FLOAT64' - @staticmethod - def grpc_type() -> spanner_v1.Type: + def grpc_type(self) -> spanner_v1.Type: """See base class.""" + del self # Unused. return spanner.param_types.FLOAT64 - @staticmethod - def validate_type(value: Any) -> None: + def validate_type(self, value: Any) -> None: """See base class.""" + del self # Unused. if not isinstance(value, (int, float)): raise error.ValidationError(f'{value!r} is not of type float') @@ -168,19 +178,19 @@ def validate_type(value: Any) -> None: class String(FieldType): """Represents a string type.""" - @staticmethod - def ddl() -> str: + def ddl(self) -> str: """See base class.""" + del self # Unused. return 'STRING(MAX)' - @staticmethod - def grpc_type() -> spanner_v1.Type: + def grpc_type(self) -> spanner_v1.Type: """See base class.""" + del self # Unused. return spanner.param_types.STRING - @staticmethod - def validate_type(value: Any) -> None: + def validate_type(self, value: Any) -> None: """See base class.""" + del self # Unused. if not isinstance(value, str): raise error.ValidationError(f'{value!r} is not of type str') @@ -188,19 +198,19 @@ def validate_type(value: Any) -> None: class StringArray(FieldType): """Represents an array of strings type.""" - @staticmethod - def ddl() -> str: + def ddl(self) -> str: """See base class.""" + del self # Unused. return 'ARRAY' - @staticmethod - def grpc_type() -> spanner_v1.Type: + def grpc_type(self) -> spanner_v1.Type: """See base class.""" + del self # Unused. return spanner.param_types.Array(spanner.param_types.STRING) - @staticmethod - def validate_type(value: Any) -> None: + def validate_type(self, value: Any) -> None: """See base class.""" + del self # Unused. if not isinstance(value, list): raise error.ValidationError(f'{value!r} is not of type list') for item in value: @@ -211,19 +221,19 @@ def validate_type(value: Any) -> None: class Timestamp(FieldType): """Represents a timestamp type.""" - @staticmethod - def ddl() -> str: + def ddl(self) -> str: """See base class.""" + del self # Unused. return 'TIMESTAMP' - @staticmethod - def grpc_type() -> spanner_v1.Type: + def grpc_type(self) -> spanner_v1.Type: """See base class.""" + del self # Unused. return spanner.param_types.TIMESTAMP - @staticmethod - def validate_type(value: Any) -> None: + def validate_type(self, value: Any) -> None: """See base class.""" + del self # Unused. if not isinstance(value, datetime.datetime): raise error.ValidationError(f'{value!r} is not of type datetime') @@ -231,19 +241,19 @@ def validate_type(value: Any) -> None: class BytesBase64(FieldType): """Represents a bytes type that must be base64 encoded.""" - @staticmethod - def ddl() -> str: + def ddl(self) -> str: """See base class.""" + del self # Unused. return 'BYTES(MAX)' - @staticmethod - def grpc_type() -> spanner_v1.Type: + def grpc_type(self) -> spanner_v1.Type: """See base class.""" + del self # Unused. return spanner.param_types.BYTES - @staticmethod - def validate_type(value: Any) -> None: + def validate_type(self, value: Any) -> None: """See base class.""" + del self # Unused. if not isinstance(value, bytes): raise error.ValidationError(f'{value!r} is not of type bytes') # Rudimentary test to check for base64 encoding. @@ -253,12 +263,21 @@ def validate_type(value: Any) -> None: raise error.ValidationError(f'{value!r} must be base64-encoded bytes.') -ALL_TYPES = ( - Boolean, - Integer, - Float, - String, - StringArray, - Timestamp, - BytesBase64, -) +def field_type_from_ddl(ddl: str) -> FieldType: + """Returns the the field type for the given DDL expression.""" + if ddl == 'BOOL': + return Boolean() + elif ddl == 'INT64': + return Integer() + elif ddl == 'FLOAT64': + return Float() + elif ddl == 'STRING(MAX)': + return String() + elif ddl == 'ARRAY': + return StringArray() + elif ddl == 'TIMESTAMP': + return Timestamp() + elif ddl == 'BYTES(MAX)': + return BytesBase64() + else: + raise error.SpannerError(f'Invalid or unimplemented DDL type: {ddl!r}') diff --git a/spanner_orm/tests/admin_test.py b/spanner_orm/tests/admin_test.py index 22e890e..b28ce7c 100644 --- a/spanner_orm/tests/admin_test.py +++ b/spanner_orm/tests/admin_test.py @@ -108,8 +108,8 @@ def test_metadata(self, tables, columns, index_columns, indexes): self.assertEqual(meta.table, model.table) self.assertEqual(meta.columns, model.columns) for row in model.columns: - self.assertEqual(meta.fields[row].field_type(), - model.fields[row].field_type()) + self.assertEqual(meta.fields[row].field_type().ddl(), + model.fields[row].field_type().ddl()) self.assertEqual(meta.fields[row].nullable(), model.fields[row].nullable()) self.assertEqual(meta.primary_keys, model.primary_keys) diff --git a/spanner_orm/tests/field_test.py b/spanner_orm/tests/field_test.py new file mode 100644 index 0000000..03138ce --- /dev/null +++ b/spanner_orm/tests/field_test.py @@ -0,0 +1,67 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for field.""" + +import warnings + +from spanner_orm import error +from spanner_orm import field +from absl.testing import absltest +from absl.testing import parameterized + + +class FieldTest(parameterized.TestCase): + + @parameterized.parameters( + (field.Boolean(), field.Boolean(), True), + (field.Boolean(), field.String(), False), + ) + def test_field_type_comparable_with( + self, + field_type_1: field.FieldType, + field_type_2: field.FieldType, + expected_comparable: bool, + ): + self.assertEqual( + field_type_1.comparable_with(field_type_2), expected_comparable) + self.assertEqual( + field_type_2.comparable_with(field_type_1), expected_comparable) + + def test_field_field_type_is_class(self): + with warnings.catch_warnings(record=True) as actual_warnings: + self.assertIsInstance( + field.Field(field.String).field_type(), field.String) + self.assertLen(actual_warnings, 1) + self.assertIn('instance of FieldType', str(actual_warnings[0].message)) + self.assertIs(actual_warnings[0].category, DeprecationWarning) + + @parameterized.parameters( + 'BOOL', + 'INT64', + 'FLOAT64', + 'STRING(MAX)', + 'ARRAY', + 'TIMESTAMP', + 'BYTES(MAX)', + ) + def test_ddl_to_field_type_to_ddl(self, ddl: str): + self.assertEqual(field.field_type_from_ddl(ddl).ddl(), ddl) + + def test_field_type_from_ddl_invalid(self): + with self.assertRaisesRegex(error.SpannerError, 'DDL type'): + field.field_type_from_ddl('UNICORN(MAX)') + + +if __name__ == '__main__': + absltest.main() diff --git a/spanner_orm/tests/model_test.py b/spanner_orm/tests/model_test.py index 07f5c24..4547167 100644 --- a/spanner_orm/tests/model_test.py +++ b/spanner_orm/tests/model_test.py @@ -374,7 +374,8 @@ def test_object_changes(self): def test_field_exists_on_model_class(self): self.assertIsInstance(models.SmallTestModel.key, field.Field) - self.assertEqual(models.SmallTestModel.key.field_type(), field.String) + self.assertEqual(models.SmallTestModel.key.field_type().ddl(), + 'STRING(MAX)') self.assertFalse(models.SmallTestModel.key.nullable()) self.assertEqual(models.SmallTestModel.key.name, 'key') diff --git a/spanner_orm/tests/query_test.py b/spanner_orm/tests/query_test.py index 46f3745..ab3c2b7 100644 --- a/spanner_orm/tests/query_test.py +++ b/spanner_orm/tests/query_test.py @@ -42,7 +42,7 @@ def test_where(self, sql_query): expected_sql = 'SELECT .* FROM table WHERE table.int_ = @int_0' self.assertRegex(sql, expected_sql) self.assertEqual(parameters, {'int_0': 3}) - self.assertEqual(types, {'int_0': field.Integer.grpc_type()}) + self.assertEqual(types, {'int_0': field.Integer().grpc_type()}) @mock.patch('spanner_orm.table_apis.sql_query') def test_count(self, sql_query): @@ -56,7 +56,7 @@ def test_count(self, sql_query): column, column_key) self.assertRegex(sql, expected_sql) self.assertEqual({column_key: value}, parameters) - self.assertEqual(types, {column_key: field.Integer.grpc_type()}) + self.assertEqual(types, {column_key: field.Integer().grpc_type()}) def test_count_allows_force_index(self): force_index = condition.force_index('test_index') @@ -81,7 +81,7 @@ def test_query_limit(self): self.assertEndsWith(select_query.sql(), ' LIMIT @{}'.format(key)) self.assertEqual(select_query.parameters(), {key: value}) - self.assertEqual(select_query.types(), {key: field.Integer.grpc_type()}) + self.assertEqual(select_query.types(), {key: field.Integer().grpc_type()}) select_query = self.select() self.assertNotRegex(select_query.sql(), 'LIMIT') @@ -97,10 +97,11 @@ def test_query_limit_offset(self): limit_key: limit, offset_key: offset }) - self.assertEqual(select_query.types(), { - limit_key: field.Integer.grpc_type(), - offset_key: field.Integer.grpc_type() - }) + self.assertEqual( + select_query.types(), { + limit_key: field.Integer().grpc_type(), + offset_key: field.Integer().grpc_type() + }) def test_query_order_by(self): order = ('int_', condition.OrderType.DESC) @@ -124,9 +125,9 @@ def test_query_order_by_with_object(self): select_query = self.select() self.assertNotRegex(select_query.sql(), 'ORDER BY') - @parameterized.parameters(('int_', 5, field.Integer.grpc_type()), - ('string', 'foo', field.String.grpc_type()), - ('timestamp', now(), field.Timestamp.grpc_type())) + @parameterized.parameters(('int_', 5, field.Integer().grpc_type()), + ('string', 'foo', field.String().grpc_type()), + ('timestamp', now(), field.Timestamp().grpc_type())) def test_query_where_comparison(self, column, value, grpc_type): condition_generators = [ condition.greater_than, condition.not_less_than, condition.less_than, @@ -144,9 +145,9 @@ def test_query_where_comparison(self, column, value, grpc_type): self.assertEqual(select_query.types(), {column_key: grpc_type}) @parameterized.parameters( - (models.UnittestModel.int_, 5, field.Integer.grpc_type()), - (models.UnittestModel.string, 'foo', field.String.grpc_type()), - (models.UnittestModel.timestamp, now(), field.Timestamp.grpc_type())) + (models.UnittestModel.int_, 5, field.Integer().grpc_type()), + (models.UnittestModel.string, 'foo', field.String().grpc_type()), + (models.UnittestModel.timestamp, now(), field.Timestamp().grpc_type())) def test_query_where_comparison_with_object(self, column, value, grpc_type): condition_generators = [ condition.greater_than, condition.not_less_than, condition.less_than, @@ -164,10 +165,10 @@ def test_query_where_comparison_with_object(self, column, value, grpc_type): self.assertEqual(select_query.types(), {column_key: grpc_type}) @parameterized.parameters( - ('int_', [1, 2, 3], field.Integer.grpc_type()), - ('int_', (4, 5, 6), field.Integer.grpc_type()), - ('string', ['a', 'b', 'c'], field.String.grpc_type()), - ('timestamp', [now()], field.Timestamp.grpc_type())) + ('int_', [1, 2, 3], field.Integer().grpc_type()), + ('int_', (4, 5, 6), field.Integer().grpc_type()), + ('string', ['a', 'b', 'c'], field.String().grpc_type()), + ('timestamp', [now()], field.Timestamp().grpc_type())) def test_query_where_list_comparison(self, column, values, grpc_type): condition_generators = [condition.in_list, condition.not_in_list] for condition_generator in condition_generators: @@ -461,8 +462,8 @@ def test_or(self): self.assertEndsWith(select_query.sql(), expected_sql) self.assertEqual(select_query.parameters(), {'int_0': 1, 'int_1': 2}) self.assertEqual(select_query.types(), { - 'int_0': field.Integer.grpc_type(), - 'int_1': field.Integer.grpc_type() + 'int_0': field.Integer().grpc_type(), + 'int_1': field.Integer().grpc_type() }) From a4bcfa1fe007c3885382160f7fee7aae8acab59b Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Wed, 21 Sep 2022 12:30:46 -0400 Subject: [PATCH 112/131] Fix typo --- spanner_orm/field.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spanner_orm/field.py b/spanner_orm/field.py index d010050..7304fdc 100644 --- a/spanner_orm/field.py +++ b/spanner_orm/field.py @@ -264,7 +264,7 @@ def validate_type(self, value: Any) -> None: def field_type_from_ddl(ddl: str) -> FieldType: - """Returns the the field type for the given DDL expression.""" + """Returns the field type for the given DDL expression.""" if ddl == 'BOOL': return Boolean() elif ddl == 'INT64': From 5b2ade3289cd23e4fbb22558d122c6be23a70c99 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Wed, 21 Sep 2022 13:40:16 -0400 Subject: [PATCH 113/131] Fix return type of Field.grpc_type() --- spanner_orm/field.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spanner_orm/field.py b/spanner_orm/field.py index 7304fdc..ecb3894 100644 --- a/spanner_orm/field.py +++ b/spanner_orm/field.py @@ -94,7 +94,7 @@ def field_type(self) -> FieldType: """Returns the type of the field.""" return self._type - def grpc_type(self) -> str: + def grpc_type(self) -> spanner_v1.Type: """Returns the type as used in Cloud Spanner's gRPC API.""" return self._type.grpc_type() From 96dacf9835538659c9c534d05c92090a19b1c3c0 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Wed, 21 Sep 2022 13:43:34 -0400 Subject: [PATCH 114/131] Add more tests for spanner_orm.field I didn't add any tests for StringArray because I'm about to make other changes to it. --- spanner_orm/tests/field_test.py | 73 ++++++++++++++++++++++++++++++++- 1 file changed, 71 insertions(+), 2 deletions(-) diff --git a/spanner_orm/tests/field_test.py b/spanner_orm/tests/field_test.py index 03138ce..ce02569 100644 --- a/spanner_orm/tests/field_test.py +++ b/spanner_orm/tests/field_test.py @@ -13,16 +13,85 @@ # limitations under the License. """Tests for field.""" +import base64 +import datetime +from typing import Any import warnings -from spanner_orm import error -from spanner_orm import field +# TODO(https://github.com/google/pytype/issues/1081): Re-enable import-error. from absl.testing import absltest from absl.testing import parameterized +from google.cloud import spanner # pytype: disable=import-error +from google.cloud import spanner_v1 # pytype: disable=import-error +from spanner_orm import error +from spanner_orm import field class FieldTest(parameterized.TestCase): + @parameterized.parameters( + (field.Boolean(), 'BOOL'), + (field.Integer(), 'INT64'), + (field.Float(), 'FLOAT64'), + (field.String(), 'STRING(MAX)'), + (field.Timestamp(), 'TIMESTAMP'), + (field.BytesBase64(), 'BYTES(MAX)'), + ) + def test_field_type_ddl( + self, + field_type: field.FieldType, + ddl: str, + ): + self.assertEqual(field_type.ddl(), ddl) + + @parameterized.parameters( + (field.Boolean(), spanner.param_types.BOOL), + (field.Integer(), spanner.param_types.INT64), + (field.Float(), spanner.param_types.FLOAT64), + (field.String(), spanner.param_types.STRING), + (field.Timestamp(), spanner.param_types.TIMESTAMP), + (field.BytesBase64(), spanner.param_types.BYTES), + ) + def test_field_type_grpc_type( + self, + field_type: field.FieldType, + grpc_type: spanner_v1.Type, + ): + self.assertEqual(field_type.grpc_type(), grpc_type) + + @parameterized.parameters( + (field.Boolean(), True), + (field.Integer(), 1), + (field.Float(), 1), + (field.Float(), 1.0), + (field.String(), 'foo'), + (field.Timestamp(), datetime.datetime(2022, 9, 21)), + (field.BytesBase64(), base64.b64encode(b'\x00')), + ) + def test_field_type_validate_type_ok( + self, + field_type: field.FieldType, + value: Any, + ): + field_type.validate_type(value) + + @parameterized.parameters( + (field.Boolean(), 1), + (field.Integer(), 1.0), + (field.Float(), '1.0'), + (field.String(), b'foo'), + (field.Timestamp(), datetime.date(2022, 9, 21)), + (field.BytesBase64(), base64.b64encode(b'\x00').decode('utf-8')), + (field.BytesBase64(), b'!'), + ) + def test_field_type_validate_type_error( + self, + field_type: field.FieldType, + value: Any, + ): + with self.assertRaises(error.ValidationError): + field_type.validate_type(value) + @parameterized.parameters( (field.Boolean(), field.Boolean(), True), (field.Boolean(), field.String(), False), From 924e01b70d919dd9553d8fe43f4cda5f06fbaee0 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Wed, 21 Sep 2022 17:05:30 -0400 Subject: [PATCH 115/131] Add a generic Array type and deprecate StringArray --- spanner_orm/__init__.py | 1 + spanner_orm/field.py | 76 ++++++++++++++++++++++----------- spanner_orm/tests/field_test.py | 29 ++++++++++++- 3 files changed, 80 insertions(+), 26 deletions(-) diff --git a/spanner_orm/__init__.py b/spanner_orm/__init__.py index 83e283a..5b27fc6 100644 --- a/spanner_orm/__init__.py +++ b/spanner_orm/__init__.py @@ -68,6 +68,7 @@ StringArray = field.StringArray Timestamp = field.Timestamp BytesBase64 = field.BytesBase64 +Array = field.Array ArbitraryCondition = condition.ArbitraryCondition Column = condition.Column diff --git a/spanner_orm/field.py b/spanner_orm/field.py index ecb3894..c04997b 100644 --- a/spanner_orm/field.py +++ b/spanner_orm/field.py @@ -17,6 +17,7 @@ import base64 import binascii import datetime +import re from typing import Any, Optional, Type, Union import warnings @@ -195,29 +196,6 @@ def validate_type(self, value: Any) -> None: raise error.ValidationError(f'{value!r} is not of type str') -class StringArray(FieldType): - """Represents an array of strings type.""" - - def ddl(self) -> str: - """See base class.""" - del self # Unused. - return 'ARRAY' - - def grpc_type(self) -> spanner_v1.Type: - """See base class.""" - del self # Unused. - return spanner.param_types.Array(spanner.param_types.STRING) - - def validate_type(self, value: Any) -> None: - """See base class.""" - del self # Unused. - if not isinstance(value, list): - raise error.ValidationError(f'{value!r} is not of type list') - for item in value: - if not isinstance(item, str): - raise error.ValidationError(f'{item!r} is not of type str') - - class Timestamp(FieldType): """Represents a timestamp type.""" @@ -263,6 +241,54 @@ def validate_type(self, value: Any) -> None: raise error.ValidationError(f'{value!r} must be base64-encoded bytes.') +class Array(FieldType): + """Represents an array type.""" + + def __init__(self, element_type: FieldType): + """Initializer. + + Args: + element_type: Type of the values in the array. Can't be an Array type + itself. + """ + if isinstance(element_type, Array): + # https://cloud.google.com/spanner/docs/reference/standard-sql/data-types#array_type + raise error.SpannerError( + 'Cloud Spanner does not support arrays of arrays.') + self._element_type = element_type + + def ddl(self) -> str: + """See base class.""" + return f'ARRAY<{self._element_type.ddl()}>' + + def grpc_type(self) -> spanner_v1.Type: + """See base class.""" + return spanner.param_types.Array(self._element_type.grpc_type()) + + def validate_type(self, value: Any) -> None: + """See base class.""" + if not isinstance(value, list): + raise error.ValidationError(f'{value!r} is not of type list') + for element in value: + self._element_type.validate_type(element) + + def comparable_with(self, other: FieldType) -> bool: + """See base class.""" + # Running `select [1, 2] = [1, 2];` in Cloud Spanner gives this error: Query + # failed: Equality is not defined for arguments of type ARRAY at line + # 3, column 8 + return False + + +class StringArray(Array): + """Deprecated way to represent an array of strings type.""" + + def __init__(self): + super().__init__(String()) + warnings.warn( + DeprecationWarning('Use Array(String()) instead of StringArray().')) + + def field_type_from_ddl(ddl: str) -> FieldType: """Returns the field type for the given DDL expression.""" if ddl == 'BOOL': @@ -273,11 +299,11 @@ def field_type_from_ddl(ddl: str) -> FieldType: return Float() elif ddl == 'STRING(MAX)': return String() - elif ddl == 'ARRAY': - return StringArray() elif ddl == 'TIMESTAMP': return Timestamp() elif ddl == 'BYTES(MAX)': return BytesBase64() + elif (match := re.fullmatch(r'ARRAY<(.*)>', ddl)) is not None: + return Array(field_type_from_ddl(match.group(1))) else: raise error.SpannerError(f'Invalid or unimplemented DDL type: {ddl!r}') diff --git a/spanner_orm/tests/field_test.py b/spanner_orm/tests/field_test.py index ce02569..6543d4a 100644 --- a/spanner_orm/tests/field_test.py +++ b/spanner_orm/tests/field_test.py @@ -36,6 +36,8 @@ class FieldTest(parameterized.TestCase): (field.String(), 'STRING(MAX)'), (field.Timestamp(), 'TIMESTAMP'), (field.BytesBase64(), 'BYTES(MAX)'), + (field.Array(field.Boolean()), 'ARRAY'), + (field.Array(field.String()), 'ARRAY'), ) def test_field_type_ddl( self, @@ -51,6 +53,10 @@ def test_field_type_ddl( (field.String(), spanner.param_types.STRING), (field.Timestamp(), spanner.param_types.TIMESTAMP), (field.BytesBase64(), spanner.param_types.BYTES), + (field.Array(field.Boolean()), + spanner.param_types.Array(spanner.param_types.BOOL)), + (field.Array(field.String()), + spanner.param_types.Array(spanner.param_types.STRING)), ) def test_field_type_grpc_type( self, @@ -67,6 +73,7 @@ def test_field_type_grpc_type( (field.String(), 'foo'), (field.Timestamp(), datetime.datetime(2022, 9, 21)), (field.BytesBase64(), base64.b64encode(b'\x00')), + (field.Array(field.Boolean()), [True]), ) def test_field_type_validate_type_ok( self, @@ -83,6 +90,8 @@ def test_field_type_validate_type_ok( (field.Timestamp(), datetime.date(2022, 9, 21)), (field.BytesBase64(), base64.b64encode(b'\x00').decode('utf-8')), (field.BytesBase64(), b'!'), + (field.Array(field.Boolean()), {True}), + (field.Array(field.Boolean()), [1]), ) def test_field_type_validate_type_error( self, @@ -95,6 +104,8 @@ def test_field_type_validate_type_error( @parameterized.parameters( (field.Boolean(), field.Boolean(), True), (field.Boolean(), field.String(), False), + (field.Array(field.Integer()), field.Array(field.Integer()), False), + (field.Array(field.Integer()), field.Integer(), False), ) def test_field_type_comparable_with( self, @@ -115,14 +126,30 @@ def test_field_field_type_is_class(self): self.assertIn('instance of FieldType', str(actual_warnings[0].message)) self.assertIs(actual_warnings[0].category, DeprecationWarning) + def test_array_of_array_is_invalid(self): + with self.assertRaisesRegex(error.SpannerError, 'arrays of arrays'): + field.Array(field.Array(field.String())) + + def test_string_array_is_deprecated_and_equivalent_to_array_of_string(self): + with warnings.catch_warnings(record=True) as actual_warnings: + string_array = field.StringArray() + array_of_string = field.Array(field.String()) + self.assertLen(actual_warnings, 1) + self.assertIn('Use Array(String()) instead', + str(actual_warnings[0].message)) + self.assertIs(actual_warnings[0].category, DeprecationWarning) + self.assertEqual(string_array.ddl(), array_of_string.ddl()) + self.assertEqual(string_array.grpc_type(), array_of_string.grpc_type()) + @parameterized.parameters( 'BOOL', 'INT64', 'FLOAT64', 'STRING(MAX)', - 'ARRAY', 'TIMESTAMP', 'BYTES(MAX)', + 'ARRAY', + 'ARRAY', ) def test_ddl_to_field_type_to_ddl(self, ddl: str): self.assertEqual(field.field_type_from_ddl(ddl).ddl(), ddl) From 95a058b3a2cf0add8d6ed29b4fea295736e15454 Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Fri, 23 Sep 2022 15:55:43 -0400 Subject: [PATCH 116/131] Re-add support for python 3.8 After some discussion, I think we settled on trying a little bit harder to support more versions of python. --- .github/workflows/test.yaml | 1 + README.md | 4 ++-- setup.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index b1a04dc..977ed65 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -22,6 +22,7 @@ jobs: strategy: matrix: python-version: + - '3.8' - '3.9' - '3.10' runs-on: ubuntu-latest diff --git a/README.md b/README.md index 23b36a9..ee1bed2 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ This is not an officially supported Google product. ### How to install -Make sure that Python 3.9 or higher is the default version of python for your +Make sure that Python 3.8 or higher is the default version of python for your environment, then run: ```pip install git+https://github.com/google/python-spanner-orm#egg=spanner_orm``` @@ -183,7 +183,7 @@ or the corresponding ```MigrationExecutor``` method should be used. ## Tests -Note: we suggest using a Python 3.9 +Note: we suggest using a Python 3.8 [virtualenv](https://docs.python.org/3/library/venv.html) for running tests and type checking. diff --git a/setup.py b/setup.py index f1813da..0bcdb5f 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ url='https://github.com/google/python-spanner-orm', packages=['spanner_orm', 'spanner_orm.admin'], include_package_data=True, - python_requires='~=3.9', + python_requires='~=3.8', install_requires=[ 'google-cloud-spanner >= 2, <4', 'immutabledict', From e8602e969e1bb0e59a83087407e5c70c70e2a06d Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Wed, 28 Sep 2022 13:44:18 -0400 Subject: [PATCH 117/131] Remove pytype disable comments for fixed bugs --- spanner_orm/admin/migration_manager.py | 3 +-- spanner_orm/api.py | 3 +-- spanner_orm/condition.py | 5 ++--- spanner_orm/field.py | 5 ++--- spanner_orm/model.py | 3 +-- spanner_orm/table_apis.py | 6 ++---- spanner_orm/tests/api_test.py | 3 +-- spanner_orm/tests/condition_test.py | 15 +-------------- spanner_orm/tests/field_test.py | 5 ++--- spanner_orm/tests/model_test.py | 3 +-- spanner_orm/tests/query_test.py | 6 ++---- 11 files changed, 16 insertions(+), 41 deletions(-) diff --git a/spanner_orm/admin/migration_manager.py b/spanner_orm/admin/migration_manager.py index 6373f20..3d883bd 100644 --- a/spanner_orm/admin/migration_manager.py +++ b/spanner_orm/admin/migration_manager.py @@ -74,8 +74,7 @@ def _migration_from_file(self, filename: str) -> migration.Migration: path = os.path.join(self.basedir, filename) module = importlib.util.module_from_spec( importlib.util.spec_from_file_location(module_name, path)) - # TODO(https://github.com/google/pytype/issues/1289): Re-enable pyi-error. - importlib.machinery.SourceFileLoader(module_name, path).exec_module(module) # pytype: disable=pyi-error + importlib.machinery.SourceFileLoader(module_name, path).exec_module(module) try: result = migration.Migration(module.migration_id, module.prev_migration_id, diff --git a/spanner_orm/api.py b/spanner_orm/api.py index 1c3b987..d049617 100644 --- a/spanner_orm/api.py +++ b/spanner_orm/api.py @@ -17,11 +17,10 @@ from typing import Any, Callable, Dict, Iterable, Optional, TypeVar, Union import warnings -# TODO(https://github.com/google/pytype/issues/1081): Re-enable import-error. from google.api_core import client_options as api_client_options from google.api_core import exceptions from google.auth import credentials as auth_credentials -from google.cloud import spanner # pytype: disable=import-error +from google.cloud import spanner from google.cloud.spanner_v1 import database as spanner_database from google.cloud.spanner_v1 import pool as spanner_pool from spanner_orm import error diff --git a/spanner_orm/condition.py b/spanner_orm/condition.py index 7bd8fc1..0ddbbaa 100644 --- a/spanner_orm/condition.py +++ b/spanner_orm/condition.py @@ -28,10 +28,9 @@ from spanner_orm import index from spanner_orm import relationship -# TODO(https://github.com/google/pytype/issues/1081): Re-enable import-error. from google.api_core import datetime_helpers -from google.cloud import spanner # pytype: disable=import-error -from google.cloud import spanner_v1 # pytype: disable=import-error +from google.cloud import spanner +from google.cloud import spanner_v1 import immutabledict T = TypeVar('T') diff --git a/spanner_orm/field.py b/spanner_orm/field.py index c04997b..bb92d50 100644 --- a/spanner_orm/field.py +++ b/spanner_orm/field.py @@ -21,9 +21,8 @@ from typing import Any, Optional, Type, Union import warnings -# TODO(https://github.com/google/pytype/issues/1081): Re-enable import-error. -from google.cloud import spanner # pytype: disable=import-error -from google.cloud import spanner_v1 # pytype: disable=import-error +from google.cloud import spanner +from google.cloud import spanner_v1 from spanner_orm import error diff --git a/spanner_orm/model.py b/spanner_orm/model.py index f05f6e0..2269fe7 100644 --- a/spanner_orm/model.py +++ b/spanner_orm/model.py @@ -29,9 +29,8 @@ from spanner_orm import relationship from spanner_orm import table_apis -# TODO(https://github.com/google/pytype/issues/1081): Re-enable import-error. from google.api_core import exceptions -from google.cloud import spanner # pytype: disable=import-error +from google.cloud import spanner from google.cloud.spanner_v1 import transaction as spanner_transaction T = TypeVar('T') diff --git a/spanner_orm/table_apis.py b/spanner_orm/table_apis.py index 9315e81..ea98410 100644 --- a/spanner_orm/table_apis.py +++ b/spanner_orm/table_apis.py @@ -16,8 +16,7 @@ import logging from typing import Any, Dict, Iterable, List, Sequence -# TODO(https://github.com/google/pytype/issues/1081): Re-enable import-error. -from google.cloud import spanner # pytype: disable=import-error +from google.cloud import spanner from google.cloud import spanner_v1 from google.cloud.spanner_v1 import transaction as spanner_transaction @@ -53,8 +52,7 @@ def sql_query( transaction: spanner_transaction.Transaction, query: str, parameters: Dict[str, Any], - # TODO(https://github.com/google/pytype/issues/1287): Re-enable module-attr. - parameter_types: Dict[str, spanner_v1.Type], # pytype: disable=module-attr + parameter_types: Dict[str, spanner_v1.Type], ) -> List[Sequence[Any]]: """Executes a given SQL query against the Spanner database. diff --git a/spanner_orm/tests/api_test.py b/spanner_orm/tests/api_test.py index a2e682c..bf2d9a1 100644 --- a/spanner_orm/tests/api_test.py +++ b/spanner_orm/tests/api_test.py @@ -16,10 +16,9 @@ from unittest import mock import warnings -# TODO(https://github.com/google/pytype/issues/1081): Re-enable import-error. from absl.testing import parameterized from google.api_core import exceptions -from google.cloud import spanner # pytype: disable=import-error +from google.cloud import spanner from spanner_orm import api from spanner_orm import error diff --git a/spanner_orm/tests/condition_test.py b/spanner_orm/tests/condition_test.py index bccbbcc..5fdd5bb 100644 --- a/spanner_orm/tests/condition_test.py +++ b/spanner_orm/tests/condition_test.py @@ -19,10 +19,9 @@ import os import unittest -# TODO(https://github.com/google/pytype/issues/1081): Re-enable import-error. from absl.testing import parameterized from google.api_core import datetime_helpers -from google.cloud import spanner # pytype: disable=import-error +from google.cloud import spanner from google.cloud import spanner_v1 import spanner_orm @@ -46,9 +45,6 @@ def setUp(self): )) @parameterized.parameters( - # TODO(https://github.com/google/pytype/issues/1287): Re-enable - # module-attr. - # pytype: disable=module-attr (True, spanner_v1.param_types.BOOL), (0, spanner_v1.param_types.INT64), (0.0, spanner_v1.param_types.FLOAT64), @@ -75,7 +71,6 @@ def setUp(self): array_element_type=spanner_v1.param_types.STRING, ), ), - # pytype: enable=module-attr ) def test_param_from_value(self, value, expected_type): param = condition.Param.from_value(value) @@ -137,9 +132,6 @@ def test_param_from_value_correctly_encodes(self, tautology): self.assertCountEqual((test_model,), models.SmallTestModel.where(tautology)) @parameterized.named_parameters( - # TODO(https://github.com/google/pytype/issues/1287): Re-enable - # module-attr. - # pytype: disable=module-attr ( 'minimal', condition.ArbitraryCondition( @@ -175,7 +167,6 @@ def test_param_from_value_correctly_encodes(self, tautology): 'IF(@true_param0, @key_param0, SmallTestModel.value_1)'), ('some-key',), ), - # pytype: enable=module-attr ) def test_arbitrary_condition( self, @@ -236,9 +227,6 @@ def test_arbitrary_condition_validation_error( models.SmallTestModel.where(condition_) @parameterized.named_parameters( - # TODO(https://github.com/google/pytype/issues/1287): Re-enable - # module-attr. - # pytype: disable=module-attr ( 'empty_or', condition.OrCondition(), @@ -295,7 +283,6 @@ def test_arbitrary_condition_validation_error( ')'), 'ab', ), - # pytype: enable=module-attr ) def test_or_condition( self, diff --git a/spanner_orm/tests/field_test.py b/spanner_orm/tests/field_test.py index 6543d4a..ecc518c 100644 --- a/spanner_orm/tests/field_test.py +++ b/spanner_orm/tests/field_test.py @@ -18,11 +18,10 @@ from typing import Any import warnings -# TODO(https://github.com/google/pytype/issues/1081): Re-enable import-error. from absl.testing import absltest from absl.testing import parameterized -from google.cloud import spanner # pytype: disable=import-error -from google.cloud import spanner_v1 # pytype: disable=import-error +from google.cloud import spanner +from google.cloud import spanner_v1 from spanner_orm import error from spanner_orm import field diff --git a/spanner_orm/tests/model_test.py b/spanner_orm/tests/model_test.py index 4547167..28be41c 100644 --- a/spanner_orm/tests/model_test.py +++ b/spanner_orm/tests/model_test.py @@ -19,10 +19,9 @@ import unittest from unittest import mock -# TODO(https://github.com/google/pytype/issues/1081): Re-enable import-error. from absl.testing import parameterized from google.api_core import exceptions -from google.cloud import spanner # pytype: disable=import-error +from google.cloud import spanner from spanner_orm import error from spanner_orm import field from spanner_orm.testlib.spanner_emulator import testlib as spanner_emulator_testlib diff --git a/spanner_orm/tests/query_test.py b/spanner_orm/tests/query_test.py index ab3c2b7..ff18627 100644 --- a/spanner_orm/tests/query_test.py +++ b/spanner_orm/tests/query_test.py @@ -178,10 +178,8 @@ def test_query_where_list_comparison(self, column, values, grpc_type): column_key = '{}0'.format(column) expected_sql = ' WHERE table.{} {} UNNEST(@{})'.format( column, current_condition.operator, column_key) - # TODO(https://github.com/google/pytype/issues/1287): Re-enable module-attr. - list_type = spanner_v1.Type( # pytype: disable=module-attr - code=spanner_v1.TypeCode.ARRAY, # pytype: disable=module-attr - array_element_type=grpc_type) + list_type = spanner_v1.Type( + code=spanner_v1.TypeCode.ARRAY, array_element_type=grpc_type) self.assertEndsWith(select_query.sql(), expected_sql) self.assertEqual(select_query.parameters(), {column_key: values}) self.assertEqual(select_query.types(), {column_key: list_type}) From e55eaa8c3907210c03ded720706b23c7fffabb3d Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Mon, 3 Oct 2022 17:30:11 -0400 Subject: [PATCH 118/131] Trigger tests for PRs Most of our PRs are from branches in this repo, so the push trigger covered those. But I don't think the push trigger covers pushes to branches in other repos that have open PRs in this repo. --- .github/workflows/test.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 977ed65..e25a9f7 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -13,6 +13,7 @@ # limitations under the License. on: + pull_request: {} push: {} schedule: - cron: '50 13 * * *' From e3b85a6932a990e1929e28f8d9335ec5ecb725bf Mon Sep 17 00:00:00 2001 From: wlliu Date: Tue, 13 Sep 2022 13:30:56 -0700 Subject: [PATCH 119/131] Initial commit for custom length --- .gitignore | 1 + spanner_orm/field.py | 56 +++++++++++++++++++ ...create_custom_length_field_f959b767457d.py | 32 +++++++++++ spanner_orm/tests/models.py | 3 + spanner_orm/tests/update_test.py | 9 ++- 5 files changed, 98 insertions(+), 3 deletions(-) create mode 100644 spanner_orm/tests/migrations_for_emulator_test/create_custom_length_field_f959b767457d.py diff --git a/.gitignore b/.gitignore index 0e9dea3..6007829 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ __pycache__ .eggs *.egg-info .pytype +env # Files that may or may not be added to the repo while acquiring the Spanner # emulator. diff --git a/spanner_orm/field.py b/spanner_orm/field.py index bb92d50..06aa2bd 100644 --- a/spanner_orm/field.py +++ b/spanner_orm/field.py @@ -83,6 +83,14 @@ def __init__( self._type = field_type() self._nullable = nullable self._primary_key = primary_key + self._length = length + + if self._length < 0: + raise error.ValidationError('length can not be less than zero') + + if not self._type.support_length() and self._length: + raise error.ValidationError('length can not be set on field {}'.format( + self._type)) def ddl(self) -> str: """Returns DDL for the column.""" @@ -134,6 +142,14 @@ def validate_type(self, value: Any) -> None: if not isinstance(value, bool): raise error.ValidationError(f'{value!r} is not of type bool') + @staticmethod + def matches(type: str) -> bool: + return type == 'BOOL' + + @staticmethod + def support_length() -> bool: + return False + class Integer(FieldType): """Represents an integer type.""" @@ -154,6 +170,14 @@ def validate_type(self, value: Any) -> None: if not isinstance(value, int): raise error.ValidationError(f'{value!r} is not of type int') + @staticmethod + def matches(type: str) -> bool: + return type == 'INT64' + + @staticmethod + def support_length() -> bool: + return False + class Float(FieldType): """Represents a float type.""" @@ -174,6 +198,14 @@ def validate_type(self, value: Any) -> None: if not isinstance(value, (int, float)): raise error.ValidationError(f'{value!r} is not of type float') + @staticmethod + def matches(type: str) -> bool: + return type == 'FLOAT64' + + @staticmethod + def support_length() -> bool: + return False + class String(FieldType): """Represents a string type.""" @@ -194,6 +226,14 @@ def validate_type(self, value: Any) -> None: if not isinstance(value, str): raise error.ValidationError(f'{value!r} is not of type str') + @staticmethod + def matches(type: str) -> bool: + return type.startswith('ARRAY') + + @staticmethod + def support_length() -> bool: + return True + class Timestamp(FieldType): """Represents a timestamp type.""" @@ -214,6 +254,14 @@ def validate_type(self, value: Any) -> None: if not isinstance(value, datetime.datetime): raise error.ValidationError(f'{value!r} is not of type datetime') + @staticmethod + def matches(type: str) -> bool: + return type == 'TIMESTAMP' + + @staticmethod + def support_length() -> bool: + return False + class BytesBase64(FieldType): """Represents a bytes type that must be base64 encoded.""" @@ -239,6 +287,14 @@ def validate_type(self, value: Any) -> None: except binascii.Error: raise error.ValidationError(f'{value!r} must be base64-encoded bytes.') + @staticmethod + def matches(type: str) -> bool: + return type[0:6] == 'BYTES(' and type[-1] == ')' + + @staticmethod + def support_length() -> bool: + return True + class Array(FieldType): """Represents an array type.""" diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_custom_length_field_f959b767457d.py b/spanner_orm/tests/migrations_for_emulator_test/create_custom_length_field_f959b767457d.py new file mode 100644 index 0000000..196d703 --- /dev/null +++ b/spanner_orm/tests/migrations_for_emulator_test/create_custom_length_field_f959b767457d.py @@ -0,0 +1,32 @@ +"""Spanner ORM migration: create_custom_length_field. + +Migration ID: 'f959b767457d' +Created: 2022-09-13 13:28:34-07:00 +""" + +import spanner_orm + +migration_id = 'f959b767457d' +prev_migration_id = '69a8f072dacf' + + +class OriginalTeeTable(spanner_orm.model.Model): + """ORM Model with the original schema for the Commands table. + Don't update this model, create new migrations instead. + """ + + __table__ = 'Tee' + id = spanner_orm.Field(spanner_orm.String, primary_key=True) + cus_str = spanner_orm.Field(spanner_orm.String, length=555) + cus_bytes = spanner_orm.Field(spanner_orm.BytesBase64, length=12) + cus_strarr = spanner_orm.Field(spanner_orm.StringArray, length=24) + + +def upgrade() -> spanner_orm.CreateTable: + """Creates the original Commands table.""" + return spanner_orm.CreateTable(OriginalTeeTable) + + +def downgrade() -> spanner_orm.DropTable: + """Drops the original Commands table.""" + return spanner_orm.DropTable(OriginalTeeTable.__table__) \ No newline at end of file diff --git a/spanner_orm/tests/models.py b/spanner_orm/tests/models.py index 896b450..1e38485 100644 --- a/spanner_orm/tests/models.py +++ b/spanner_orm/tests/models.py @@ -125,10 +125,13 @@ class UnittestModelWithoutSecondaryIndexes(model.Model): float_2 = field.Field(field.Float, nullable=True) string = field.Field(field.String, primary_key=True) string_2 = field.Field(field.String, nullable=True) + string_3 = field.Field(field.String, nullable=True, length=20) bytes_ = field.Field(field.BytesBase64, primary_key=True) bytes_2 = field.Field(field.BytesBase64, nullable=True) + bytes_3 = field.Field(field.BytesBase64, nullable=True, length=20) timestamp = field.Field(field.Timestamp) string_array = field.Field(field.StringArray, nullable=True) + string_array_2 = field.Field(field.StringArray, nullable=True, length=20) class NullFilteredIndexModel(model.Model): diff --git a/spanner_orm/tests/update_test.py b/spanner_orm/tests/update_test.py index 736c451..6486cde 100644 --- a/spanner_orm/tests/update_test.py +++ b/spanner_orm/tests/update_test.py @@ -86,10 +86,13 @@ def test_create_table(self, get_model): test_model_ddl = ('CREATE TABLE table (int_ INT64 NOT NULL, int_2 INT64,' ' float_ FLOAT64 NOT NULL, float_2 FLOAT64,' ' string STRING(MAX) NOT NULL, string_2 STRING(MAX),' + ' string_3 STRING(20),' ' bytes_ BYTES(MAX) NOT NULL, bytes_2 BYTES(MAX),' - ' timestamp TIMESTAMP NOT NULL, string_array' - ' ARRAY) PRIMARY KEY ' - '(int_, float_, string, bytes_)') + ' bytes_3 BYTES(20),' + ' timestamp TIMESTAMP NOT NULL,' + ' string_array ARRAY,' + ' string_array_2 ARRAY)' + ' PRIMARY KEY (int_, float_, string, bytes_)') self.assertEqual(test_update.ddl(), test_model_ddl) @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') From f940b9ca1eba42ef1b0eb7c278b10290199ffca4 Mon Sep 17 00:00:00 2001 From: wlliu Date: Tue, 13 Sep 2022 13:34:25 -0700 Subject: [PATCH 120/131] Add new line --- .../create_custom_length_field_f959b767457d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_custom_length_field_f959b767457d.py b/spanner_orm/tests/migrations_for_emulator_test/create_custom_length_field_f959b767457d.py index 196d703..867ac7d 100644 --- a/spanner_orm/tests/migrations_for_emulator_test/create_custom_length_field_f959b767457d.py +++ b/spanner_orm/tests/migrations_for_emulator_test/create_custom_length_field_f959b767457d.py @@ -29,4 +29,4 @@ def upgrade() -> spanner_orm.CreateTable: def downgrade() -> spanner_orm.DropTable: """Drops the original Commands table.""" - return spanner_orm.DropTable(OriginalTeeTable.__table__) \ No newline at end of file + return spanner_orm.DropTable(OriginalTeeTable.__table__) From b5148b8b22b7ace55095758cdd40d5cf089f2da2 Mon Sep 17 00:00:00 2001 From: wlliu Date: Tue, 13 Sep 2022 15:31:40 -0700 Subject: [PATCH 121/131] Address comments --- spanner_orm/field.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/spanner_orm/field.py b/spanner_orm/field.py index 06aa2bd..ad85b3e 100644 --- a/spanner_orm/field.py +++ b/spanner_orm/field.py @@ -85,12 +85,12 @@ def __init__( self._primary_key = primary_key self._length = length - if self._length < 0: - raise error.ValidationError('length can not be less than zero') - if not self._type.support_length() and self._length: raise error.ValidationError('length can not be set on field {}'.format( self._type)) + + if self._length < 0: + raise error.ValidationError('length can not be less than zero') def ddl(self) -> str: """Returns DDL for the column.""" @@ -289,7 +289,7 @@ def validate_type(self, value: Any) -> None: @staticmethod def matches(type: str) -> bool: - return type[0:6] == 'BYTES(' and type[-1] == ')' + return type.startswith('BYTES(') and type.endswith(')') @staticmethod def support_length() -> bool: From 2f124c159b371336ad1e8d38f43d31ac56df6247 Mon Sep 17 00:00:00 2001 From: wlliu Date: Fri, 16 Sep 2022 12:14:59 -0700 Subject: [PATCH 122/131] Address comments --- spanner_orm/field.py | 42 ++++++++----------- .../create_unittest_model.py | 3 ++ spanner_orm/tests/model_test.py | 7 ++-- spanner_orm/tests/models.py | 3 ++ 4 files changed, 27 insertions(+), 28 deletions(-) diff --git a/spanner_orm/field.py b/spanner_orm/field.py index ad85b3e..8417f62 100644 --- a/spanner_orm/field.py +++ b/spanner_orm/field.py @@ -85,13 +85,13 @@ def __init__( self._primary_key = primary_key self._length = length - if not self._type.support_length() and self._length: - raise error.ValidationError('length can not be set on field {}'.format( - self._type)) - - if self._length < 0: + if self._length and self._length < 0: raise error.ValidationError('length can not be less than zero') + if not self._type.supports_length() and self._length: + raise error.ValidationError( + f'length can not be set on field {self._type}') + def ddl(self) -> str: """Returns DDL for the column.""" if self._nullable: @@ -146,10 +146,6 @@ def validate_type(self, value: Any) -> None: def matches(type: str) -> bool: return type == 'BOOL' - @staticmethod - def support_length() -> bool: - return False - class Integer(FieldType): """Represents an integer type.""" @@ -174,10 +170,6 @@ def validate_type(self, value: Any) -> None: def matches(type: str) -> bool: return type == 'INT64' - @staticmethod - def support_length() -> bool: - return False - class Float(FieldType): """Represents a float type.""" @@ -202,10 +194,6 @@ def validate_type(self, value: Any) -> None: def matches(type: str) -> bool: return type == 'FLOAT64' - @staticmethod - def support_length() -> bool: - return False - class String(FieldType): """Represents a string type.""" @@ -228,10 +216,14 @@ def validate_type(self, value: Any) -> None: @staticmethod def matches(type: str) -> bool: - return type.startswith('ARRAY') + val = re.findall('ARRAY<(.*)>', type) + # We expect exact one matching result if type matches ARRAY. + if not val or len(val) != 1: + return False + return String.matches(val[0]) @staticmethod - def support_length() -> bool: + def supports_length() -> bool: return True @@ -258,10 +250,6 @@ def validate_type(self, value: Any) -> None: def matches(type: str) -> bool: return type == 'TIMESTAMP' - @staticmethod - def support_length() -> bool: - return False - class BytesBase64(FieldType): """Represents a bytes type that must be base64 encoded.""" @@ -289,10 +277,14 @@ def validate_type(self, value: Any) -> None: @staticmethod def matches(type: str) -> bool: - return type.startswith('BYTES(') and type.endswith(')') + val = re.findall('BYTES\((.*)\)', type) + # We expect exact one matching result if type matches BYTE. + if not val or len(val) != 1: + return False + return val[0] == 'MAX' or val[0].isnumeric() @staticmethod - def support_length() -> bool: + def supports_length() -> bool: return True diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py b/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py index 15113ad..7a2699d 100644 --- a/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py +++ b/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py @@ -34,10 +34,13 @@ class OriginalUnittestModelTable(spanner_orm.model.Model): float_2 = field.Field(field.Float, nullable=True) string = field.Field(field.String, primary_key=True) string_2 = field.Field(field.String, nullable=True) + string_3 = field.Field(field.String, nullable=True, length=20) bytes_ = field.Field(field.BytesBase64, primary_key=True) bytes_2 = field.Field(field.BytesBase64, nullable=True) + bytes_3 = field.Field(field.BytesBase64, nullable=True, length=20) timestamp = field.Field(field.Timestamp) string_array = field.Field(field.StringArray, nullable=True) + string_array_2 = field.Field(field.StringArray, nullable=True, length=20) def upgrade() -> spanner_orm.CreateTable: diff --git a/spanner_orm/tests/model_test.py b/spanner_orm/tests/model_test.py index 28be41c..6ff8e84 100644 --- a/spanner_orm/tests/model_test.py +++ b/spanner_orm/tests/model_test.py @@ -226,9 +226,10 @@ def test_set_error_on_primary_key(self): with self.assertRaises(AttributeError): test_model.key = 'error' - @parameterized.parameters(('int_2', 'foo'), ('float_2', 'bar'), - ('string_2', 5), ('bytes_2', 'string'), - ('string_array', 'foo'), ('timestamp', 5)) + @parameterized.parameters( + ('int_2', 'foo'), ('float_2', 'bar'), ('string_2', 5), ('string_3', 5), + ('bytes_2', 'string'), ('bytes_2', 'string2'), ('string_array', 'foo'), + ('timestamp', 5)) def test_set_error_on_invalid_type(self, attribute, value): string_array = ['foo', 'bar'] timestamp = datetime.datetime.now(tz=datetime.timezone.utc) diff --git a/spanner_orm/tests/models.py b/spanner_orm/tests/models.py index 1e38485..451c810 100644 --- a/spanner_orm/tests/models.py +++ b/spanner_orm/tests/models.py @@ -107,10 +107,13 @@ class UnittestModel(model.Model): float_2 = field.Field(field.Float, nullable=True) string = field.Field(field.String, primary_key=True) string_2 = field.Field(field.String, nullable=True) + string_3 = field.Field(field.String, nullable=True, length=20) bytes_ = field.Field(field.BytesBase64, primary_key=True) bytes_2 = field.Field(field.BytesBase64, nullable=True) + bytes_3 = field.Field(field.BytesBase64, nullable=True, length=20) timestamp = field.Field(field.Timestamp) string_array = field.Field(field.StringArray, nullable=True) + string_array_2 = field.Field(field.StringArray, nullable=True, length=20) test_index = index.Index(['string_2']) From 86f8dae50beab57844abc14474a7dd93d96634bc Mon Sep 17 00:00:00 2001 From: wlliu Date: Fri, 16 Sep 2022 14:15:46 -0700 Subject: [PATCH 123/131] address more comments --- spanner_orm/field.py | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/spanner_orm/field.py b/spanner_orm/field.py index 8417f62..00358df 100644 --- a/spanner_orm/field.py +++ b/spanner_orm/field.py @@ -85,12 +85,12 @@ def __init__( self._primary_key = primary_key self._length = length - if self._length and self._length < 0: - raise error.ValidationError('length can not be less than zero') - - if not self._type.supports_length() and self._length: - raise error.ValidationError( - f'length can not be set on field {self._type}') + if self._length is not None: + if self._length < 0: + raise error.ValidationError('length can not be less than zero') + if not self._type.supports_length(): + raise error.ValidationError( + f'length can not be set on field {self._type}') def ddl(self) -> str: """Returns DDL for the column.""" @@ -216,11 +216,7 @@ def validate_type(self, value: Any) -> None: @staticmethod def matches(type: str) -> bool: - val = re.findall('ARRAY<(.*)>', type) - # We expect exact one matching result if type matches ARRAY. - if not val or len(val) != 1: - return False - return String.matches(val[0]) + return re.fullmatch('ARRAY', type) is not None @staticmethod def supports_length() -> bool: @@ -277,11 +273,7 @@ def validate_type(self, value: Any) -> None: @staticmethod def matches(type: str) -> bool: - val = re.findall('BYTES\((.*)\)', type) - # We expect exact one matching result if type matches BYTE. - if not val or len(val) != 1: - return False - return val[0] == 'MAX' or val[0].isnumeric() + return re.fullmatch('BYTES\((?:[0-9]+|MAX)\)', type) is not None @staticmethod def supports_length() -> bool: From 810dfb295667b3c226ffbc1c57f16a9f6ed3214e Mon Sep 17 00:00:00 2001 From: wlliu Date: Thu, 22 Sep 2022 18:48:32 -0700 Subject: [PATCH 124/131] Rebase --- spanner_orm/field.py | 48 +++---------------- ...create_custom_length_field_f959b767457d.py | 4 +- .../create_unittest_model.py | 5 +- spanner_orm/tests/model_test.py | 3 +- spanner_orm/tests/models.py | 10 ++-- spanner_orm/tests/update_test.py | 1 - 6 files changed, 14 insertions(+), 57 deletions(-) diff --git a/spanner_orm/field.py b/spanner_orm/field.py index 00358df..695ff78 100644 --- a/spanner_orm/field.py +++ b/spanner_orm/field.py @@ -83,14 +83,6 @@ def __init__( self._type = field_type() self._nullable = nullable self._primary_key = primary_key - self._length = length - - if self._length is not None: - if self._length < 0: - raise error.ValidationError('length can not be less than zero') - if not self._type.supports_length(): - raise error.ValidationError( - f'length can not be set on field {self._type}') def ddl(self) -> str: """Returns DDL for the column.""" @@ -142,10 +134,6 @@ def validate_type(self, value: Any) -> None: if not isinstance(value, bool): raise error.ValidationError(f'{value!r} is not of type bool') - @staticmethod - def matches(type: str) -> bool: - return type == 'BOOL' - class Integer(FieldType): """Represents an integer type.""" @@ -166,10 +154,6 @@ def validate_type(self, value: Any) -> None: if not isinstance(value, int): raise error.ValidationError(f'{value!r} is not of type int') - @staticmethod - def matches(type: str) -> bool: - return type == 'INT64' - class Float(FieldType): """Represents a float type.""" @@ -190,17 +174,17 @@ def validate_type(self, value: Any) -> None: if not isinstance(value, (int, float)): raise error.ValidationError(f'{value!r} is not of type float') - @staticmethod - def matches(type: str) -> bool: - return type == 'FLOAT64' - class String(FieldType): """Represents a string type.""" + def __init__(self, size: Optional[int] = None): + self._size = size + def ddl(self) -> str: """See base class.""" - del self # Unused. + if self._size is not None: + return f'STRING({self._size})' return 'STRING(MAX)' def grpc_type(self) -> spanner_v1.Type: @@ -214,14 +198,6 @@ def validate_type(self, value: Any) -> None: if not isinstance(value, str): raise error.ValidationError(f'{value!r} is not of type str') - @staticmethod - def matches(type: str) -> bool: - return re.fullmatch('ARRAY', type) is not None - - @staticmethod - def supports_length() -> bool: - return True - class Timestamp(FieldType): """Represents a timestamp type.""" @@ -242,10 +218,6 @@ def validate_type(self, value: Any) -> None: if not isinstance(value, datetime.datetime): raise error.ValidationError(f'{value!r} is not of type datetime') - @staticmethod - def matches(type: str) -> bool: - return type == 'TIMESTAMP' - class BytesBase64(FieldType): """Represents a bytes type that must be base64 encoded.""" @@ -271,14 +243,6 @@ def validate_type(self, value: Any) -> None: except binascii.Error: raise error.ValidationError(f'{value!r} must be base64-encoded bytes.') - @staticmethod - def matches(type: str) -> bool: - return re.fullmatch('BYTES\((?:[0-9]+|MAX)\)', type) is not None - - @staticmethod - def supports_length() -> bool: - return True - class Array(FieldType): """Represents an array type.""" @@ -336,7 +300,7 @@ def field_type_from_ddl(ddl: str) -> FieldType: return Integer() elif ddl == 'FLOAT64': return Float() - elif ddl == 'STRING(MAX)': + elif re.fullmatch(r'STRING\((?:[0-9]+|MAX)\)', ddl) is not None: return String() elif ddl == 'TIMESTAMP': return Timestamp() diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_custom_length_field_f959b767457d.py b/spanner_orm/tests/migrations_for_emulator_test/create_custom_length_field_f959b767457d.py index 867ac7d..ace117e 100644 --- a/spanner_orm/tests/migrations_for_emulator_test/create_custom_length_field_f959b767457d.py +++ b/spanner_orm/tests/migrations_for_emulator_test/create_custom_length_field_f959b767457d.py @@ -17,9 +17,7 @@ class OriginalTeeTable(spanner_orm.model.Model): __table__ = 'Tee' id = spanner_orm.Field(spanner_orm.String, primary_key=True) - cus_str = spanner_orm.Field(spanner_orm.String, length=555) - cus_bytes = spanner_orm.Field(spanner_orm.BytesBase64, length=12) - cus_strarr = spanner_orm.Field(spanner_orm.StringArray, length=24) + cus_str = spanner_orm.Field(spanner_orm.String(20)) def upgrade() -> spanner_orm.CreateTable: diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py b/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py index 7a2699d..a68766a 100644 --- a/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py +++ b/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py @@ -34,13 +34,12 @@ class OriginalUnittestModelTable(spanner_orm.model.Model): float_2 = field.Field(field.Float, nullable=True) string = field.Field(field.String, primary_key=True) string_2 = field.Field(field.String, nullable=True) - string_3 = field.Field(field.String, nullable=True, length=20) + string_3 = field.Field(field.String(20), nullable=True) bytes_ = field.Field(field.BytesBase64, primary_key=True) bytes_2 = field.Field(field.BytesBase64, nullable=True) - bytes_3 = field.Field(field.BytesBase64, nullable=True, length=20) timestamp = field.Field(field.Timestamp) string_array = field.Field(field.StringArray, nullable=True) - string_array_2 = field.Field(field.StringArray, nullable=True, length=20) + string_array_2 = field.Field(field.Array(field.String(20)), nullable=True) def upgrade() -> spanner_orm.CreateTable: diff --git a/spanner_orm/tests/model_test.py b/spanner_orm/tests/model_test.py index 6ff8e84..1dc3c10 100644 --- a/spanner_orm/tests/model_test.py +++ b/spanner_orm/tests/model_test.py @@ -228,8 +228,7 @@ def test_set_error_on_primary_key(self): @parameterized.parameters( ('int_2', 'foo'), ('float_2', 'bar'), ('string_2', 5), ('string_3', 5), - ('bytes_2', 'string'), ('bytes_2', 'string2'), ('string_array', 'foo'), - ('timestamp', 5)) + ('bytes_2', 'string'), ('string_array', 'foo'), ('timestamp', 5)) def test_set_error_on_invalid_type(self, attribute, value): string_array = ['foo', 'bar'] timestamp = datetime.datetime.now(tz=datetime.timezone.utc) diff --git a/spanner_orm/tests/models.py b/spanner_orm/tests/models.py index 451c810..81c889f 100644 --- a/spanner_orm/tests/models.py +++ b/spanner_orm/tests/models.py @@ -107,13 +107,12 @@ class UnittestModel(model.Model): float_2 = field.Field(field.Float, nullable=True) string = field.Field(field.String, primary_key=True) string_2 = field.Field(field.String, nullable=True) - string_3 = field.Field(field.String, nullable=True, length=20) + string_3 = field.Field(field.String(20), nullable=True) bytes_ = field.Field(field.BytesBase64, primary_key=True) bytes_2 = field.Field(field.BytesBase64, nullable=True) - bytes_3 = field.Field(field.BytesBase64, nullable=True, length=20) timestamp = field.Field(field.Timestamp) string_array = field.Field(field.StringArray, nullable=True) - string_array_2 = field.Field(field.StringArray, nullable=True, length=20) + string_array_2 = field.Field(field.Array(field.String(20)), nullable=True) test_index = index.Index(['string_2']) @@ -128,13 +127,12 @@ class UnittestModelWithoutSecondaryIndexes(model.Model): float_2 = field.Field(field.Float, nullable=True) string = field.Field(field.String, primary_key=True) string_2 = field.Field(field.String, nullable=True) - string_3 = field.Field(field.String, nullable=True, length=20) + string_3 = field.Field(field.String(20), nullable=True) bytes_ = field.Field(field.BytesBase64, primary_key=True) bytes_2 = field.Field(field.BytesBase64, nullable=True) - bytes_3 = field.Field(field.BytesBase64, nullable=True, length=20) timestamp = field.Field(field.Timestamp) string_array = field.Field(field.StringArray, nullable=True) - string_array_2 = field.Field(field.StringArray, nullable=True, length=20) + string_array_2 = field.Field(field.Array(field.String(20)), nullable=True) class NullFilteredIndexModel(model.Model): diff --git a/spanner_orm/tests/update_test.py b/spanner_orm/tests/update_test.py index 6486cde..f0639ab 100644 --- a/spanner_orm/tests/update_test.py +++ b/spanner_orm/tests/update_test.py @@ -88,7 +88,6 @@ def test_create_table(self, get_model): ' string STRING(MAX) NOT NULL, string_2 STRING(MAX),' ' string_3 STRING(20),' ' bytes_ BYTES(MAX) NOT NULL, bytes_2 BYTES(MAX),' - ' bytes_3 BYTES(20),' ' timestamp TIMESTAMP NOT NULL,' ' string_array ARRAY,' ' string_array_2 ARRAY)' From a2154ff659748570e4b25a90a358f6fc2afb5957 Mon Sep 17 00:00:00 2001 From: wlliu Date: Thu, 22 Sep 2022 18:53:12 -0700 Subject: [PATCH 125/131] Add description --- spanner_orm/field.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/spanner_orm/field.py b/spanner_orm/field.py index 695ff78..2f8a795 100644 --- a/spanner_orm/field.py +++ b/spanner_orm/field.py @@ -179,6 +179,13 @@ class String(FieldType): """Represents a string type.""" def __init__(self, size: Optional[int] = None): + """Initializer. + + Args: + size: Size of the String. MAX is used if not specified. + """ + if size < 0: + raise error.ValidationError('string size can not be negative') self._size = size def ddl(self) -> str: From 8cd95d60b7fd0845e7edbe57293f65761b7ca866 Mon Sep 17 00:00:00 2001 From: wlliu Date: Mon, 26 Sep 2022 15:22:29 -0700 Subject: [PATCH 126/131] support byte64 --- spanner_orm/field.py | 19 +++++++++++++++---- ...create_custom_length_field_f959b767457d.py | 3 +++ .../create_unittest_model.py | 1 + spanner_orm/tests/models.py | 2 ++ spanner_orm/tests/update_test.py | 1 + 5 files changed, 22 insertions(+), 4 deletions(-) diff --git a/spanner_orm/field.py b/spanner_orm/field.py index 2f8a795..1583727 100644 --- a/spanner_orm/field.py +++ b/spanner_orm/field.py @@ -184,8 +184,8 @@ def __init__(self, size: Optional[int] = None): Args: size: Size of the String. MAX is used if not specified. """ - if size < 0: - raise error.ValidationError('string size can not be negative') + if size is not None and size <= 0: + raise error.ValidationError('string size must be positive') self._size = size def ddl(self) -> str: @@ -229,9 +229,20 @@ def validate_type(self, value: Any) -> None: class BytesBase64(FieldType): """Represents a bytes type that must be base64 encoded.""" + def __init__(self, size: Optional[int] = None): + """Initializer. + + Args: + size: Size of the Byte64. MAX is used if not specified. + """ + if size is not None and size <= 0: + raise error.ValidationError('Byte64 size must be positive') + self._size = size + def ddl(self) -> str: """See base class.""" - del self # Unused. + if self._size is not None: + return f'BYTES({self._size})' return 'BYTES(MAX)' def grpc_type(self) -> spanner_v1.Type: @@ -311,7 +322,7 @@ def field_type_from_ddl(ddl: str) -> FieldType: return String() elif ddl == 'TIMESTAMP': return Timestamp() - elif ddl == 'BYTES(MAX)': + elif re.fullmatch(r'BYTES\((?:[0-9]+|MAX)\)', ddl) is not None: return BytesBase64() elif (match := re.fullmatch(r'ARRAY<(.*)>', ddl)) is not None: return Array(field_type_from_ddl(match.group(1))) diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_custom_length_field_f959b767457d.py b/spanner_orm/tests/migrations_for_emulator_test/create_custom_length_field_f959b767457d.py index ace117e..04d81f9 100644 --- a/spanner_orm/tests/migrations_for_emulator_test/create_custom_length_field_f959b767457d.py +++ b/spanner_orm/tests/migrations_for_emulator_test/create_custom_length_field_f959b767457d.py @@ -18,6 +18,9 @@ class OriginalTeeTable(spanner_orm.model.Model): __table__ = 'Tee' id = spanner_orm.Field(spanner_orm.String, primary_key=True) cus_str = spanner_orm.Field(spanner_orm.String(20)) + cus_arr_str = spanner_orm.Field(spanner_orm.Array(spanner_orm.String(4))) + cus_byt = spanner_orm.Field(spanner_orm.BytesBase64(20)) + cus_arr_bye = spanner_orm.Field(spanner_orm.Array(spanner_orm.BytesBase64(4))) def upgrade() -> spanner_orm.CreateTable: diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py b/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py index a68766a..315d972 100644 --- a/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py +++ b/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py @@ -37,6 +37,7 @@ class OriginalUnittestModelTable(spanner_orm.model.Model): string_3 = field.Field(field.String(20), nullable=True) bytes_ = field.Field(field.BytesBase64, primary_key=True) bytes_2 = field.Field(field.BytesBase64, nullable=True) + bytes_3 = field.Field(field.BytesBase64(20), nullable=True) timestamp = field.Field(field.Timestamp) string_array = field.Field(field.StringArray, nullable=True) string_array_2 = field.Field(field.Array(field.String(20)), nullable=True) diff --git a/spanner_orm/tests/models.py b/spanner_orm/tests/models.py index 81c889f..5f4a370 100644 --- a/spanner_orm/tests/models.py +++ b/spanner_orm/tests/models.py @@ -110,6 +110,7 @@ class UnittestModel(model.Model): string_3 = field.Field(field.String(20), nullable=True) bytes_ = field.Field(field.BytesBase64, primary_key=True) bytes_2 = field.Field(field.BytesBase64, nullable=True) + bytes_3 = field.Field(field.BytesBase64(20), nullable=True) timestamp = field.Field(field.Timestamp) string_array = field.Field(field.StringArray, nullable=True) string_array_2 = field.Field(field.Array(field.String(20)), nullable=True) @@ -130,6 +131,7 @@ class UnittestModelWithoutSecondaryIndexes(model.Model): string_3 = field.Field(field.String(20), nullable=True) bytes_ = field.Field(field.BytesBase64, primary_key=True) bytes_2 = field.Field(field.BytesBase64, nullable=True) + bytes_3 = field.Field(field.BytesBase64(20), nullable=True) timestamp = field.Field(field.Timestamp) string_array = field.Field(field.StringArray, nullable=True) string_array_2 = field.Field(field.Array(field.String(20)), nullable=True) diff --git a/spanner_orm/tests/update_test.py b/spanner_orm/tests/update_test.py index f0639ab..6486cde 100644 --- a/spanner_orm/tests/update_test.py +++ b/spanner_orm/tests/update_test.py @@ -88,6 +88,7 @@ def test_create_table(self, get_model): ' string STRING(MAX) NOT NULL, string_2 STRING(MAX),' ' string_3 STRING(20),' ' bytes_ BYTES(MAX) NOT NULL, bytes_2 BYTES(MAX),' + ' bytes_3 BYTES(20),' ' timestamp TIMESTAMP NOT NULL,' ' string_array ARRAY,' ' string_array_2 ARRAY)' From 11db6cc37b244ef64f0539154a320fd01191e28e Mon Sep 17 00:00:00 2001 From: wlliu Date: Wed, 28 Sep 2022 17:04:07 -0700 Subject: [PATCH 127/131] Adress comments and add tests in field_test.py --- spanner_orm/field.py | 36 ++++++++++--------- spanner_orm/tests/field_test.py | 17 +++++++++ ...create_custom_length_field_f959b767457d.py | 10 +++--- 3 files changed, 43 insertions(+), 20 deletions(-) diff --git a/spanner_orm/field.py b/spanner_orm/field.py index 1583727..77969b2 100644 --- a/spanner_orm/field.py +++ b/spanner_orm/field.py @@ -178,20 +178,20 @@ def validate_type(self, value: Any) -> None: class String(FieldType): """Represents a string type.""" - def __init__(self, size: Optional[int] = None): + def __init__(self, length: Optional[int] = None): """Initializer. Args: - size: Size of the String. MAX is used if not specified. + length: Length of the String. MAX is used if not specified. """ - if size is not None and size <= 0: - raise error.ValidationError('string size must be positive') - self._size = size + if length is not None and length <= 0: + raise error.ValidationError('String length must be positive') + self._length = length def ddl(self) -> str: """See base class.""" - if self._size is not None: - return f'STRING({self._size})' + if self._length is not None: + return f'STRING({self._length})' return 'STRING(MAX)' def grpc_type(self) -> spanner_v1.Type: @@ -229,20 +229,20 @@ def validate_type(self, value: Any) -> None: class BytesBase64(FieldType): """Represents a bytes type that must be base64 encoded.""" - def __init__(self, size: Optional[int] = None): + def __init__(self, length: Optional[int] = None): """Initializer. Args: - size: Size of the Byte64. MAX is used if not specified. + length: Length of the Byte64. MAX is used if not specified. """ - if size is not None and size <= 0: - raise error.ValidationError('Byte64 size must be positive') - self._size = size + if length is not None and length <= 0: + raise error.ValidationError('Bytes length must be positive') + self._length = length def ddl(self) -> str: """See base class.""" - if self._size is not None: - return f'BYTES({self._size})' + if self._length is not None: + return f'BYTES({self._length})' return 'BYTES(MAX)' def grpc_type(self) -> spanner_v1.Type: @@ -318,12 +318,16 @@ def field_type_from_ddl(ddl: str) -> FieldType: return Integer() elif ddl == 'FLOAT64': return Float() - elif re.fullmatch(r'STRING\((?:[0-9]+|MAX)\)', ddl) is not None: + elif ddl == 'STRING(MAX)': return String() + elif (match := re.fullmatch(r'STRING\(((?:[0-9]+))\)', ddl)) is not None: + return String(int(match.group(1))) elif ddl == 'TIMESTAMP': return Timestamp() - elif re.fullmatch(r'BYTES\((?:[0-9]+|MAX)\)', ddl) is not None: + elif ddl == 'BYTES(MAX)': return BytesBase64() + elif (match := re.fullmatch(r'BYTES\(((?:[0-9]+))\)', ddl)) is not None: + return BytesBase64(int(match.group(1))) elif (match := re.fullmatch(r'ARRAY<(.*)>', ddl)) is not None: return Array(field_type_from_ddl(match.group(1))) else: diff --git a/spanner_orm/tests/field_test.py b/spanner_orm/tests/field_test.py index ecc518c..246f3f6 100644 --- a/spanner_orm/tests/field_test.py +++ b/spanner_orm/tests/field_test.py @@ -33,10 +33,15 @@ class FieldTest(parameterized.TestCase): (field.Integer(), 'INT64'), (field.Float(), 'FLOAT64'), (field.String(), 'STRING(MAX)'), + (field.String(10), 'STRING(10)'), (field.Timestamp(), 'TIMESTAMP'), (field.BytesBase64(), 'BYTES(MAX)'), + (field.BytesBase64(10), 'BYTES(10)'), (field.Array(field.Boolean()), 'ARRAY'), (field.Array(field.String()), 'ARRAY'), + (field.Array(field.String(10)), 'ARRAY'), + (field.Array(field.BytesBase64()), 'ARRAY'), + (field.Array(field.BytesBase64(10)), 'ARRAY'), ) def test_field_type_ddl( self, @@ -50,12 +55,16 @@ def test_field_type_ddl( (field.Integer(), spanner.param_types.INT64), (field.Float(), spanner.param_types.FLOAT64), (field.String(), spanner.param_types.STRING), + (field.String(10), spanner.param_types.STRING), (field.Timestamp(), spanner.param_types.TIMESTAMP), (field.BytesBase64(), spanner.param_types.BYTES), + (field.BytesBase64(10), spanner.param_types.BYTES), (field.Array(field.Boolean()), spanner.param_types.Array(spanner.param_types.BOOL)), (field.Array(field.String()), spanner.param_types.Array(spanner.param_types.STRING)), + (field.Array(field.String(10)), + spanner.param_types.Array(spanner.param_types.STRING)), ) def test_field_type_grpc_type( self, @@ -70,8 +79,10 @@ def test_field_type_grpc_type( (field.Float(), 1), (field.Float(), 1.0), (field.String(), 'foo'), + (field.String(10), 'foo'), (field.Timestamp(), datetime.datetime(2022, 9, 21)), (field.BytesBase64(), base64.b64encode(b'\x00')), + (field.BytesBase64(10), base64.b64encode(b'\x00')), (field.Array(field.Boolean()), [True]), ) def test_field_type_validate_type_ok( @@ -86,9 +97,11 @@ def test_field_type_validate_type_ok( (field.Integer(), 1.0), (field.Float(), '1.0'), (field.String(), b'foo'), + (field.String(10), b'foo'), (field.Timestamp(), datetime.date(2022, 9, 21)), (field.BytesBase64(), base64.b64encode(b'\x00').decode('utf-8')), (field.BytesBase64(), b'!'), + (field.BytesBase64(10), b'!'), (field.Array(field.Boolean()), {True}), (field.Array(field.Boolean()), [1]), ) @@ -103,6 +116,7 @@ def test_field_type_validate_type_error( @parameterized.parameters( (field.Boolean(), field.Boolean(), True), (field.Boolean(), field.String(), False), + (field.String(10), field.String(20), True), (field.Array(field.Integer()), field.Array(field.Integer()), False), (field.Array(field.Integer()), field.Integer(), False), ) @@ -145,10 +159,13 @@ def test_string_array_is_deprecated_and_equivalent_to_array_of_string(self): 'INT64', 'FLOAT64', 'STRING(MAX)', + 'STRING(10)', 'TIMESTAMP', 'BYTES(MAX)', + 'BYTES(10)', 'ARRAY', 'ARRAY', + 'ARRAY', ) def test_ddl_to_field_type_to_ddl(self, ddl: str): self.assertEqual(field.field_type_from_ddl(ddl).ddl(), ddl) diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_custom_length_field_f959b767457d.py b/spanner_orm/tests/migrations_for_emulator_test/create_custom_length_field_f959b767457d.py index 04d81f9..47d698b 100644 --- a/spanner_orm/tests/migrations_for_emulator_test/create_custom_length_field_f959b767457d.py +++ b/spanner_orm/tests/migrations_for_emulator_test/create_custom_length_field_f959b767457d.py @@ -17,10 +17,12 @@ class OriginalTeeTable(spanner_orm.model.Model): __table__ = 'Tee' id = spanner_orm.Field(spanner_orm.String, primary_key=True) - cus_str = spanner_orm.Field(spanner_orm.String(20)) - cus_arr_str = spanner_orm.Field(spanner_orm.Array(spanner_orm.String(4))) - cus_byt = spanner_orm.Field(spanner_orm.BytesBase64(20)) - cus_arr_bye = spanner_orm.Field(spanner_orm.Array(spanner_orm.BytesBase64(4))) + custom_string_length = spanner_orm.Field(spanner_orm.String(20)) + custom_array_string_length = spanner_orm.Field( + spanner_orm.Array(spanner_orm.String(4))) + custom_bytes_length = spanner_orm.Field(spanner_orm.BytesBase64(20)) + custom_array_bytes_length = spanner_orm.Field( + spanner_orm.Array(spanner_orm.BytesBase64(4))) def upgrade() -> spanner_orm.CreateTable: From 7b3b5a67b07dea6de0d6a4600ff7230cba377bd1 Mon Sep 17 00:00:00 2001 From: wlliu Date: Mon, 3 Oct 2022 13:12:48 -0700 Subject: [PATCH 128/131] address comments --- spanner_orm/field.py | 6 +++--- spanner_orm/tests/field_test.py | 7 +++++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/spanner_orm/field.py b/spanner_orm/field.py index 77969b2..c8cdd6e 100644 --- a/spanner_orm/field.py +++ b/spanner_orm/field.py @@ -233,7 +233,7 @@ def __init__(self, length: Optional[int] = None): """Initializer. Args: - length: Length of the Byte64. MAX is used if not specified. + length: Length of the Bytes. MAX is used if not specified. """ if length is not None and length <= 0: raise error.ValidationError('Bytes length must be positive') @@ -320,13 +320,13 @@ def field_type_from_ddl(ddl: str) -> FieldType: return Float() elif ddl == 'STRING(MAX)': return String() - elif (match := re.fullmatch(r'STRING\(((?:[0-9]+))\)', ddl)) is not None: + elif (match := re.fullmatch(r'STRING\((([0-9]+))\)', ddl)) is not None: return String(int(match.group(1))) elif ddl == 'TIMESTAMP': return Timestamp() elif ddl == 'BYTES(MAX)': return BytesBase64() - elif (match := re.fullmatch(r'BYTES\(((?:[0-9]+))\)', ddl)) is not None: + elif (match := re.fullmatch(r'BYTES\((([0-9]+))\)', ddl)) is not None: return BytesBase64(int(match.group(1))) elif (match := re.fullmatch(r'ARRAY<(.*)>', ddl)) is not None: return Array(field_type_from_ddl(match.group(1))) diff --git a/spanner_orm/tests/field_test.py b/spanner_orm/tests/field_test.py index 246f3f6..553a081 100644 --- a/spanner_orm/tests/field_test.py +++ b/spanner_orm/tests/field_test.py @@ -117,6 +117,7 @@ def test_field_type_validate_type_error( (field.Boolean(), field.Boolean(), True), (field.Boolean(), field.String(), False), (field.String(10), field.String(20), True), + (field.String(), field.String(), True), (field.Array(field.Integer()), field.Array(field.Integer()), False), (field.Array(field.Integer()), field.Integer(), False), ) @@ -170,9 +171,11 @@ def test_string_array_is_deprecated_and_equivalent_to_array_of_string(self): def test_ddl_to_field_type_to_ddl(self, ddl: str): self.assertEqual(field.field_type_from_ddl(ddl).ddl(), ddl) - def test_field_type_from_ddl_invalid(self): + @parameterized.parameters('UNICORN(MAX)', 'STRING(MAX1)', 'STRING(MIN)', + 'ARRAY', 'BYTES(MAX1)', 'BYTES(MIN)') + def test_field_type_from_ddl_invalid(self, ddl: str): with self.assertRaisesRegex(error.SpannerError, 'DDL type'): - field.field_type_from_ddl('UNICORN(MAX)') + field.field_type_from_ddl(ddl) if __name__ == '__main__': From 0196241f8b403edd16e47ece163f92b0571751bf Mon Sep 17 00:00:00 2001 From: wlliu Date: Mon, 3 Oct 2022 14:14:50 -0700 Subject: [PATCH 129/131] Address comments --- spanner_orm/field.py | 4 ++-- spanner_orm/tests/field_test.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/spanner_orm/field.py b/spanner_orm/field.py index c8cdd6e..cbc83f0 100644 --- a/spanner_orm/field.py +++ b/spanner_orm/field.py @@ -320,13 +320,13 @@ def field_type_from_ddl(ddl: str) -> FieldType: return Float() elif ddl == 'STRING(MAX)': return String() - elif (match := re.fullmatch(r'STRING\((([0-9]+))\)', ddl)) is not None: + elif (match := re.fullmatch(r'STRING\(([0-9]+)\)', ddl)) is not None: return String(int(match.group(1))) elif ddl == 'TIMESTAMP': return Timestamp() elif ddl == 'BYTES(MAX)': return BytesBase64() - elif (match := re.fullmatch(r'BYTES\((([0-9]+))\)', ddl)) is not None: + elif (match := re.fullmatch(r'BYTES\(([0-9]+)\)', ddl)) is not None: return BytesBase64(int(match.group(1))) elif (match := re.fullmatch(r'ARRAY<(.*)>', ddl)) is not None: return Array(field_type_from_ddl(match.group(1))) diff --git a/spanner_orm/tests/field_test.py b/spanner_orm/tests/field_test.py index 553a081..36da5aa 100644 --- a/spanner_orm/tests/field_test.py +++ b/spanner_orm/tests/field_test.py @@ -117,7 +117,7 @@ def test_field_type_validate_type_error( (field.Boolean(), field.Boolean(), True), (field.Boolean(), field.String(), False), (field.String(10), field.String(20), True), - (field.String(), field.String(), True), + (field.String(), field.String(10), True), (field.Array(field.Integer()), field.Array(field.Integer()), False), (field.Array(field.Integer()), field.Integer(), False), ) From d04420f815b081107dc82808b1b212d078caf5af Mon Sep 17 00:00:00 2001 From: David Mandelberg Date: Thu, 12 Sep 2024 18:47:51 +0000 Subject: [PATCH 130/131] Fix tests and add support for newer python --- .github/workflows/test.yaml | 1 + spanner_orm/query.py | 14 ++++++++------ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index e25a9f7..b198738 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -26,6 +26,7 @@ jobs: - '3.8' - '3.9' - '3.10' + - '3.11' runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 diff --git a/spanner_orm/query.py b/spanner_orm/query.py index 85a35aa..ac09f9a 100644 --- a/spanner_orm/query.py +++ b/spanner_orm/query.py @@ -14,13 +14,15 @@ """Helps build SQL for complex Spanner queries.""" import abc -from typing import Any, Dict, Iterable, List, Sequence, Tuple, Type +from typing import Any, Dict, Generic, Iterable, List, Sequence, Tuple, Type, TypeVar from spanner_orm import condition from spanner_orm import error +ResultType = TypeVar('ResultType') -class SpannerQuery(abc.ABC): + +class SpannerQuery(abc.ABC, Generic[ResultType]): """Helps build SQL for complex Spanner queries.""" def __init__(self, model: Type[Any], @@ -46,7 +48,7 @@ def types(self) -> Dict[str, Any]: return self._types @abc.abstractmethod - def process_results(self, results: List[Sequence[Any]]) -> None: + def process_results(self, results: List[Sequence[Any]]) -> ResultType: pass def _segments(self, @@ -133,7 +135,7 @@ def _limit(self) -> Tuple[str, Dict[str, Any], Dict[str, Any]]: return (sql, parameters, types) -class CountQuery(SpannerQuery): +class CountQuery(SpannerQuery[int]): """Handles COUNT Spanner queries.""" def __init__(self, model: Type[Any], @@ -151,7 +153,7 @@ def process_results(self, results: List[Sequence[Any]]) -> int: return int(results[0][0]) -class SelectQuery(SpannerQuery): +class SelectQuery(SpannerQuery[List[Type[Any]]]): """Handles SELECT Spanner queries.""" def __init__(self, model: Type[Any], @@ -188,7 +190,7 @@ def _select(self) -> Tuple[str, Dict[str, Any], Dict[str, Any]]: def process_results(self, results: List[Sequence[Any]]) -> List[Type[Any]]: return [self._process_row(result) for result in results] - def _process_row(self, row: List[Any]) -> Type[Any]: + def _process_row(self, row: Sequence[Any]) -> Type[Any]: """Parses a row of results from a Spanner query based on the conditions.""" values = dict(zip(self._model.columns, row)) join_values = row[len(self._model.columns):] From b9b9191f292d8504045570df43344e108c722b7d Mon Sep 17 00:00:00 2001 From: Sabrina Gutierrez Date: Tue, 12 Aug 2025 09:31:48 -0400 Subject: [PATCH 131/131] Spanner orm upgrade (#195) --- .github/workflows/test.yaml | 2 +- setup.py | 4 ++-- spanner_orm/__init__.py | 5 ----- .../testlib/spanner_emulator/emulator.py | 2 ++ spanner_orm/tests/migrations_emulator_test.py | 20 +++++++++---------- ...create_custom_length_field_f959b767457d.py | 2 +- .../create_foreign_key_test_model.py | 8 ++++---- ..._null_filtered_index_model_760ec5fae5da.py | 6 +++--- .../create_small_test_model.py | 6 +++--- .../create_unittest_model.py | 14 ++++++------- 10 files changed, 33 insertions(+), 36 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index b198738..1adfa23 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -45,7 +45,7 @@ jobs: pip install \ absl-py \ google-api-core \ - 'google-cloud-spanner >= 2, <4' \ + 'google-cloud-spanner >= 3, <4' \ immutabledict \ portpicker \ pytest diff --git a/setup.py b/setup.py index 0bcdb5f..2aa2b71 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ setup( name='spanner-orm', - version='0.2.0', + version='0.3.0', description='Basic ORM for Spanner', maintainer='Python Spanner ORM developers', maintainer_email='python-spanner-orm@google.com', @@ -25,7 +25,7 @@ include_package_data=True, python_requires='~=3.8', install_requires=[ - 'google-cloud-spanner >= 2, <4', + 'google-cloud-spanner >= 3, <4', 'immutabledict', ], tests_require=['absl-py', 'google-api-core', 'portpicker'], diff --git a/spanner_orm/__init__.py b/spanner_orm/__init__.py index 5b27fc6..4da2f26 100644 --- a/spanner_orm/__init__.py +++ b/spanner_orm/__init__.py @@ -113,8 +113,3 @@ model_creation_ddl = update_module.model_creation_ddl MigrationExecutor = migration_executor.MigrationExecutor - -try: - __import__('pkg_resources').declare_namespace('spanner_orm') -except ImportError: - __path__ = __import__('pkgutil').extend_path(__path__, 'spanner_orm') diff --git a/spanner_orm/testlib/spanner_emulator/emulator.py b/spanner_orm/testlib/spanner_emulator/emulator.py index 2cafa8a..55f7a67 100644 --- a/spanner_orm/testlib/spanner_emulator/emulator.py +++ b/spanner_orm/testlib/spanner_emulator/emulator.py @@ -15,6 +15,7 @@ import os import subprocess +import time from typing import Mapping, Optional import portpicker @@ -65,6 +66,7 @@ def __init__(self, self._host_port = None self._start() + time.sleep(1) self._wait_for_ready() def get_client( diff --git a/spanner_orm/tests/migrations_emulator_test.py b/spanner_orm/tests/migrations_emulator_test.py index 4f936bb..96978bd 100644 --- a/spanner_orm/tests/migrations_emulator_test.py +++ b/spanner_orm/tests/migrations_emulator_test.py @@ -124,7 +124,7 @@ def test_drop_interleaved_table(self): class _Parent(spanner_orm.Model): __table__ = 'Parent' parent_key = spanner_orm.Field( - spanner_orm.String, primary_key=True) + spanner_orm.String(), primary_key=True) def upgrade(): return spanner_orm.CreateTable(_Parent) @@ -133,15 +133,15 @@ def upgrade(): class _Parent(spanner_orm.Model): __table__ = 'Parent' parent_key = spanner_orm.Field( - spanner_orm.String, primary_key=True) + spanner_orm.String(), primary_key=True) class _Child(spanner_orm.Model): __table__ = 'Child' __interleaved__ = _Parent parent_key = spanner_orm.Field( - spanner_orm.String, primary_key=True) + spanner_orm.String(), primary_key=True) child_key = spanner_orm.Field( - spanner_orm.String, primary_key=True) + spanner_orm.String(), primary_key=True) def upgrade(): return spanner_orm.CreateTable(_Child) @@ -170,8 +170,8 @@ def upgrade(): class _TableToDrop(spanner_orm.Model): __table__ = 'TableToDrop' key = spanner_orm.Field( - spanner_orm.String, primary_key=True) - value = spanner_orm.Field(spanner_orm.String) + spanner_orm.String(), primary_key=True) + value = spanner_orm.Field(spanner_orm.String()) def upgrade(): return spanner_orm.CreateTable(_TableToDrop) @@ -194,7 +194,7 @@ def upgrade(): class _TableToDrop(spanner_orm.Model): __table__ = 'TableToDrop' parent_key = spanner_orm.Field( - spanner_orm.String, primary_key=True) + spanner_orm.String(), primary_key=True) def upgrade(): return spanner_orm.CreateTable(_TableToDrop) @@ -203,15 +203,15 @@ def upgrade(): class _TableToDrop(spanner_orm.Model): __table__ = 'TableToDrop' parent_key = spanner_orm.Field( - spanner_orm.String, primary_key=True) + spanner_orm.String(), primary_key=True) class _Child(spanner_orm.Model): __table__ = 'Child' __interleaved__ = _TableToDrop parent_key = spanner_orm.Field( - spanner_orm.String, primary_key=True) + spanner_orm.String(), primary_key=True) child_key = spanner_orm.Field( - spanner_orm.String, primary_key=True) + spanner_orm.String(), primary_key=True) def upgrade(): return spanner_orm.CreateTable(_Child) diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_custom_length_field_f959b767457d.py b/spanner_orm/tests/migrations_for_emulator_test/create_custom_length_field_f959b767457d.py index 47d698b..915d1f3 100644 --- a/spanner_orm/tests/migrations_for_emulator_test/create_custom_length_field_f959b767457d.py +++ b/spanner_orm/tests/migrations_for_emulator_test/create_custom_length_field_f959b767457d.py @@ -16,7 +16,7 @@ class OriginalTeeTable(spanner_orm.model.Model): """ __table__ = 'Tee' - id = spanner_orm.Field(spanner_orm.String, primary_key=True) + id = spanner_orm.Field(spanner_orm.String(), primary_key=True) custom_string_length = spanner_orm.Field(spanner_orm.String(20)) custom_array_string_length = spanner_orm.Field( spanner_orm.Array(spanner_orm.String(4))) diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py b/spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py index bdd3935..02c947d 100644 --- a/spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py +++ b/spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py @@ -29,10 +29,10 @@ class OriginalForeignKeyTestModelTable(spanner_orm.model.Model): """ORM Model with the original schema for the ForeignKeyTestModel table.""" __table__ = 'ForeignKeyTestModel' - referencing_key_1 = field.Field(field.String, primary_key=True) - referencing_key_2 = field.Field(field.String, primary_key=True) - referencing_key_3 = field.Field(field.Integer, primary_key=True) - self_referencing_key = field.Field(field.String, nullable=True) + referencing_key_1 = field.Field(field.String(), primary_key=True) + referencing_key_2 = field.Field(field.String(), primary_key=True) + referencing_key_3 = field.Field(field.Integer(), primary_key=True) + self_referencing_key = field.Field(field.String(), nullable=True) foreign_key_1 = foreign_key_relationship.ForeignKeyRelationship( 'SmallTestModel', {'referencing_key_1': 'key'}) foreign_key_2 = foreign_key_relationship.ForeignKeyRelationship( diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_null_filtered_index_model_760ec5fae5da.py b/spanner_orm/tests/migrations_for_emulator_test/create_null_filtered_index_model_760ec5fae5da.py index 494a000..4dafccb 100644 --- a/spanner_orm/tests/migrations_for_emulator_test/create_null_filtered_index_model_760ec5fae5da.py +++ b/spanner_orm/tests/migrations_for_emulator_test/create_null_filtered_index_model_760ec5fae5da.py @@ -25,9 +25,9 @@ class _NullFilteredIndexModel(spanner_orm.Model): __table__ = 'NullFilteredIndexModel' - key = spanner_orm.Field(spanner_orm.String, primary_key=True) - value_1 = spanner_orm.Field(spanner_orm.String, nullable=True) - value_2 = spanner_orm.Field(spanner_orm.Integer) + key = spanner_orm.Field(spanner_orm.String(), primary_key=True) + value_1 = spanner_orm.Field(spanner_orm.String(), nullable=True) + value_2 = spanner_orm.Field(spanner_orm.Integer()) def upgrade() -> spanner_orm.MigrationUpdate: diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_small_test_model.py b/spanner_orm/tests/migrations_for_emulator_test/create_small_test_model.py index c4b019b..5878337 100644 --- a/spanner_orm/tests/migrations_for_emulator_test/create_small_test_model.py +++ b/spanner_orm/tests/migrations_for_emulator_test/create_small_test_model.py @@ -28,9 +28,9 @@ class OriginalSmallTestModelTable(spanner_orm.model.Model): """ORM Model with the original schema for the SmallTestModel table.""" __table__ = 'SmallTestModel' - key = field.Field(field.String, primary_key=True) - value_1 = field.Field(field.String) - value_2 = field.Field(field.String, nullable=True) + key = field.Field(field.String(), primary_key=True) + value_1 = field.Field(field.String()) + value_2 = field.Field(field.String(), nullable=True) def upgrade() -> spanner_orm.CreateTable: diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py b/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py index 315d972..8fcf25e 100644 --- a/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py +++ b/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py @@ -28,18 +28,18 @@ class OriginalUnittestModelTable(spanner_orm.model.Model): """ORM Model with the original schema for the UnittestModel table.""" __table__ = 'table' - int_ = field.Field(field.Integer, primary_key=True) - int_2 = field.Field(field.Integer, nullable=True) - float_ = field.Field(field.Float, primary_key=True) - float_2 = field.Field(field.Float, nullable=True) - string = field.Field(field.String, primary_key=True) - string_2 = field.Field(field.String, nullable=True) + int_ = field.Field(field.Integer(), primary_key=True) + int_2 = field.Field(field.Integer(), nullable=True) + float_ = field.Field(field.Float(), primary_key=True) + float_2 = field.Field(field.Float(), nullable=True) + string = field.Field(field.String(), primary_key=True) + string_2 = field.Field(field.String(), nullable=True) string_3 = field.Field(field.String(20), nullable=True) bytes_ = field.Field(field.BytesBase64, primary_key=True) bytes_2 = field.Field(field.BytesBase64, nullable=True) bytes_3 = field.Field(field.BytesBase64(20), nullable=True) timestamp = field.Field(field.Timestamp) - string_array = field.Field(field.StringArray, nullable=True) + string_array = field.Field(field.Array(field.String()), nullable=True) string_array_2 = field.Field(field.Array(field.String(20)), nullable=True)