diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml new file mode 100644 index 0000000..1adfa23 --- /dev/null +++ b/.github/workflows/test.yaml @@ -0,0 +1,64 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +on: + pull_request: {} + push: {} + schedule: + - cron: '50 13 * * *' + +jobs: + test: + strategy: + matrix: + python-version: + - '3.8' + - '3.9' + - '3.10' + - '3.11' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: Install cloud-spanner-emulator + run: | + # https://github.com/GoogleCloudPlatform/cloud-spanner-emulator#via-pre-built-linux-binaries + VERSION=1.2.0 + wget https://storage.googleapis.com/cloud-spanner-emulator/releases/${VERSION}/cloud-spanner-emulator_linux_amd64-${VERSION}.tar.gz + tar zxvf cloud-spanner-emulator_linux_amd64-${VERSION}.tar.gz + chmod u+x gateway_main emulator_main + - uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install python dependencies + run: | + pip install \ + absl-py \ + google-api-core \ + 'google-cloud-spanner >= 3, <4' \ + immutabledict \ + portpicker \ + pytest + - name: Check formatting + run: | + pip install yapf + yapf --diff --recursive --parallel . + - name: Check types + run: | + pip install pytype + pytype --jobs=auto --keep-going spanner_orm + - name: Test + env: + SPANNER_EMULATOR_BINARY_PATH: ${{ github.workspace }}/emulator_main + run: | + pytest diff --git a/.gitignore b/.gitignore index 7c47c20..6007829 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,13 @@ dist build __pycache__ +.eggs +*.egg-info +.pytype +env + +# Files that may or may not be added to the repo while acquiring the Spanner +# emulator. +cloud-spanner-emulator* +emulator_main +gateway_main diff --git a/README.md b/README.md index 75a7cbb..ee1bed2 100644 --- a/README.md +++ b/README.md @@ -1,20 +1,24 @@ +[![.github/workflows/test.yaml](https://github.com/google/python-spanner-orm/actions/workflows/test.yaml/badge.svg)](https://github.com/google/python-spanner-orm/actions/workflows/test.yaml) + # Google Cloud Spanner ORM This is a lightweight ORM written in Python and built on top of Cloud Spanner. +This is not an officially supported Google product. ## Getting started ### How to install -Make sure that Python 3.7 is the default version of python for your environment, -then run: +Make sure that Python 3.8 or higher is the default version of python for your +environment, then run: ```pip install git+https://github.com/google/python-spanner-orm#egg=spanner_orm``` ### Connecting To connect the Spanner ORM to an existing Spanner database: ``` python import spanner_orm -spanner_orm.connect(instance_name, database_name) +spanner_orm.from_connection( + spanner_orm.SpannerConnection(instance_name, database_name)) ``` `project` and `credentials` are optional parameters, and the standard Spanner @@ -41,18 +45,18 @@ class TestModel(spanner_orm.Model): # is the type of field. The primary key is constructed by the fields labeled # with primary_key=True in the order they appear in the class. # The name of the column is the same as the name of the class attribute - id = spanner_orm.Field(spanner_orm.String, primary_key=True) - value = spanner_orm.Field(spanner_orm.Integer, nullable=True) + id = spanner_orm.Field(spanner_orm.String(), primary_key=True) + value = spanner_orm.Field(spanner_orm.Integer(), nullable=True) + number = spanner_orm.Field(spanner_orm.Float(), nullable=True) # Secondary indexes are specified in a similar manner to fields: value_index = spanner_orm.Index(['value']) # To indicate that there is a foreign key relationship from this table to - # another one, use a Relationship. This has no impact on the representation - # of the table inside Spanner - fake_relationship = spanner_orm.Relationship( + # another one, use a ForeignKeyRelationship. + foreign_key = spanner_orm.ForeignKeyRelationship( 'OtherModel', - {'value': 'other_model_column'}) + {'referencing_key': 'referenced_key'}) ``` If the model does not refer to an existing table on Spanner, we can create @@ -65,7 +69,7 @@ admin_api = spanner_orm.connect_admin( 'instance_name', 'database_name', create_ddl=spanner_orm.model_creation_ddl(TestModel)) -admin_api.create() +admin_api.create_database() ``` If the database already exists, we can execute a Migration where the upgrade @@ -82,23 +86,24 @@ The two main ways of retrieving data through the ORM are ```where()``` and ```find()```/```find_multi()```: ``` python -# where() is invokes on a model class to retrieve models of that tyep. it takes a -# transaction and then a sequence of conditions. -# Most conditions that specify a Field, Index, Relationship, or Model can take -# either the name of the object or the object itself -test_objects = TestModel.where(None, spanner_orm.greater_than('value', '50')) +# where() is invokes on a model class to retrieve models of that type. it takes +# a sequence of conditions. Most conditions that specify a Field, Index, +# Relationship, or Model can take either the name of the object or the object +# itself +test_objects = TestModel.where(spanner_orm.greater_than('value', '50')) # To also retrieve related objects, the includes() condition should be used: -test_and_other_objects = TestModel.where(None, - spanner_orm.greater_than(TestModel.value, '50'), - spanner_orm.includes(TestModel.fake_relationship)) +test_and_other_objects = TestModel.where( + spanner_orm.greater_than(TestModel.value, '50'), + spanner_orm.includes(TestModel.fake_relationship), +) # To create a transaction, run_read_only() or run_write() are used with the # method to be run inside the transaction and any arguments to passs to the method. # The method is invoked with the transaction as the first argument and then the # rest of the provided arguments: def callback_1(transaction, argument): - return TestModel.find(transaction, id=argument) + return TestModel.find(id=argument, transaction=transaction) specific_object = spanner_orm.spanner_api().run_read_only(callback, 1) @@ -106,7 +111,7 @@ specific_object = spanner_orm.spanner_api().run_read_only(callback, 1) # call a bit: @transactional_read def finder(argument, transaction=None): - return TestModel.find(transaction, id=argument) + return TestModel.find(id=argument, transaction=transaction) specific_object = finder(1) ``` @@ -129,7 +134,7 @@ models = [] for i in range(10): key = 'test_{}'.format(i) models.append(TestModel({'key': key, 'value': value})) -TestModel.save_batch(None, models) +TestModel.save_batch(models) ``` ```spanner_orm.spanner_api().run_write()``` can be used for executing read-write @@ -151,8 +156,8 @@ Running ```spanner-orm generate ``` will generate a new migration file to be filled out in the directory specified (or 'migrations' by default). The ```upgrade``` function is executed when migrating, and the ```downgrade``` function is executed when rolling back the migration. Each of -these should return a single SchemaUpdate object (e.g., CreateTable, AddColumn, -etc.), as Spanner cannot execute multiple schema updates atomically. +these should return a single MigrationUpdate object (e.g., CreateTable, +AddColumn, etc.), as Spanner cannot execute multiple schema updates atomically. ### Executing migrations Running ```spanner-orm migrate ``` will @@ -175,3 +180,45 @@ multiple times, so try not to do that. If a migration needs to be rolled back, ```spanner-orm rollback ``` or the corresponding ```MigrationExecutor``` method should be used. + +## Tests + +Note: we suggest using a Python 3.8 +[virtualenv](https://docs.python.org/3/library/venv.html) +for running tests and type checking. + + +Before running any tests, you'll need to download the Cloud Spanner Emulator. +See https://github.com/GoogleCloudPlatform/cloud-spanner-emulator for several +options. If you're on Linux, we recommend: + +``` +VERSION=1.2.0 +wget https://storage.googleapis.com/cloud-spanner-emulator/releases/${VERSION}/cloud-spanner-emulator_linux_amd64-${VERSION}.tar.gz +tar zxvf cloud-spanner-emulator_linux_amd64-${VERSION}.tar.gz +chmod u+x gateway_main emulator_main +``` + +``` +git clone git@github.com:GoogleCloudPlatform/cloud-spanner-emulator.git +``` + +To check type annotations, run: + +``` +pip install pytype +pytype spanner_orm +``` + +To check formatting, run (change `--diff` to `--in-place` to fix formatting): + +``` +pip install yapf +yapf --diff --recursive --parallel . +``` + +Then run tests with: + +``` +SPANNER_EMULATOR_BINARY_PATH=$(pwd)/emulator_main pytest +``` diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..db5b80a --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[yapf] +based_on_style = yapf diff --git a/setup.py b/setup.py index 57a6355..2aa2b71 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,18 +13,22 @@ # limitations under the License. """spanner_orm setup file.""" from setuptools import setup + setup( name='spanner-orm', - version='0.1.10', + version='0.3.0', description='Basic ORM for Spanner', - maintainer='Derek Brandao', - maintainer_email='dbrandao@google.com', + maintainer='Python Spanner ORM developers', + maintainer_email='python-spanner-orm@google.com', url='https://github.com/google/python-spanner-orm', packages=['spanner_orm', 'spanner_orm.admin'], include_package_data=True, - python_requires='~=3.7', - install_requires=['google-cloud-spanner >= 1.6, <2.0.0dev'], - tests_require=['absl-py'], + python_requires='~=3.8', + install_requires=[ + 'google-cloud-spanner >= 3, <4', + 'immutabledict', + ], + tests_require=['absl-py', 'google-api-core', 'portpicker'], entry_points={ 'console_scripts': ['spanner-orm = spanner_orm.admin.scripts:main'] }) diff --git a/spanner_orm/__init__.py b/spanner_orm/__init__.py index 335b864..4da2f26 100644 --- a/spanner_orm/__init__.py +++ b/spanner_orm/__init__.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,6 +19,7 @@ from spanner_orm import decorator from spanner_orm import error from spanner_orm import field +from spanner_orm import foreign_key_relationship from spanner_orm import index from spanner_orm import model from spanner_orm import relationship @@ -60,18 +60,32 @@ Boolean = field.Boolean Field = field.Field Integer = field.Integer +Float = field.Float +ForeignKeyRelationship = foreign_key_relationship.ForeignKeyRelationship Index = index.Index Relationship = relationship.Relationship String = field.String StringArray = field.StringArray Timestamp = field.Timestamp +BytesBase64 = field.BytesBase64 +Array = field.Array +ArbitraryCondition = condition.ArbitraryCondition +Column = condition.Column +Condition = condition.Condition +ORDER_ASC = condition.OrderType.ASC +ORDER_DESC = condition.OrderType.DESC +Param = condition.Param +Segment = condition.Segment +columns_equal = condition.columns_equal +contains = condition.contains equal_to = condition.equal_to force_index = condition.force_index +force_null_filtered_index = condition.force_null_filtered_index greater_than = condition.greater_than greater_than_or_equal_to = condition.greater_than_or_equal_to -includes = condition.includes in_list = condition.in_list +includes = condition.includes less_than = condition.less_than less_than_or_equal_to = condition.less_than_or_equal_to limit = condition.limit @@ -79,13 +93,15 @@ not_greater_than = condition.not_greater_than not_in_list = condition.not_in_list not_less_than = condition.not_less_than +or_ = condition.or_ order_by = condition.order_by -ORDER_ASC = condition.OrderType.ASC -ORDER_DESC = condition.OrderType.DESC transactional_read = decorator.transactional_read transactional_write = decorator.transactional_write +MigrationUpdate = update_module.MigrationUpdate +NoUpdate = update_module.NoUpdate +SchemaUpdate = update_module.SchemaUpdate CreateTable = update_module.CreateTable DropTable = update_module.DropTable AddColumn = update_module.AddColumn @@ -93,7 +109,7 @@ AlterColumn = update_module.AlterColumn CreateIndex = update_module.CreateIndex DropIndex = update_module.DropIndex -NoUpdate = update_module.NoUpdate +ExecutePartitionedDml = update_module.ExecutePartitionedDml model_creation_ddl = update_module.model_creation_ddl MigrationExecutor = migration_executor.MigrationExecutor diff --git a/spanner_orm/admin/api.py b/spanner_orm/admin/api.py index 59e5ae0..85c3062 100644 --- a/spanner_orm/admin/api.py +++ b/spanner_orm/admin/api.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,15 +13,15 @@ # limitations under the License. """Class that handles API calls to Spanner that deal with table metadata.""" -from __future__ import annotations - from typing import Iterable, Optional +import warnings + from spanner_orm import api from spanner_orm import error from google.auth import credentials as auth_credentials -from google.cloud import spanner from google.cloud.spanner_v1 import database as spanner_database +from google.cloud.spanner_v1 import pool as spanner_pool class SpannerAdminApi(api.SpannerReadApi, api.SpannerWriteApi): @@ -32,7 +31,7 @@ def __init__(self, connection: api.SpannerConnection): self._spanner_connection = connection @property - def _connection(self) -> spanner_database.SpannerDatabase: + def _connection(self) -> spanner_database.Database: return self._spanner_connection.database def create_database(self) -> None: @@ -46,6 +45,10 @@ def update_schema(self, change: str) -> None: operation = self._connection.update_ddl([change]) operation.result() + def execute_partitioned_dml(self, dml: str) -> None: + """See spanner_database.Database.execute_partitioned_dml().""" + self._connection.execute_partitioned_dml(dml) + _admin_api = None @@ -54,9 +57,15 @@ def connect(instance: str, database: str, project: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, - pool: Optional[spanner.Pool] = None, + pool: Optional[spanner_pool.AbstractSessionPool] = None, create_ddl: Optional[Iterable[str]] = None) -> SpannerAdminApi: - """Connects the global Spanner admin API to a Spanner database.""" + """Connects the global Spanner admin API to a Spanner database. + + Deprecated in favor of from_connection(). + """ + warnings.warn( + DeprecationWarning('Please use spanner_orm.from_admin_connection(' + 'spanner_orm.SpannerConnection(...))')) connection = api.SpannerConnection( instance, database, diff --git a/spanner_orm/admin/column.py b/spanner_orm/admin/column.py index 7c22b9a..60f5017 100644 --- a/spanner_orm/admin/column.py +++ b/spanner_orm/admin/column.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,8 +13,7 @@ # limitations under the License. """Model for interacting with Spanner column schema table.""" -from __future__ import annotations - +import typing from typing import Type from spanner_orm import error @@ -27,21 +25,16 @@ class ColumnSchema(schema.InformationSchema): """Model for interacting with Spanner column schema table.""" __table__ = 'information_schema.columns' - table_catalog = field.Field(field.String, primary_key=True) - table_schema = field.Field(field.String, primary_key=True) - table_name = field.Field(field.String, primary_key=True) - column_name = field.Field(field.String, primary_key=True) - ordinal_position = field.Field(field.Integer) - is_nullable = field.Field(field.String) - spanner_type = field.Field(field.String) + table_catalog = typing.cast(str, field.Field(field.String, primary_key=True)) + table_schema = typing.cast(str, field.Field(field.String, primary_key=True)) + table_name = typing.cast(str, field.Field(field.String, primary_key=True)) + column_name = typing.cast(str, field.Field(field.String, primary_key=True)) + ordinal_position = typing.cast(int, field.Field(field.Integer)) + is_nullable = typing.cast(str, field.Field(field.String)) + spanner_type = typing.cast(str, field.Field(field.String)) def nullable(self) -> bool: return self.is_nullable == 'YES' - def field_type(self) -> Type[field.FieldType]: - for field_type in field.ALL_TYPES: - if self.spanner_type == field_type.ddl(): - return field_type - - raise error.SpannerError('No corresponding Type for {}'.format( - self.spanner_type)) + def field_type(self) -> field.FieldType: + return field.field_type_from_ddl(self.spanner_type) diff --git a/spanner_orm/admin/index.py b/spanner_orm/admin/index.py index f4b0ba4..baadd9d 100644 --- a/spanner_orm/admin/index.py +++ b/spanner_orm/admin/index.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,7 +13,8 @@ # limitations under the License. """Model for interacting with Spanner index schema table.""" -from __future__ import annotations +import typing +from typing import Optional from spanner_orm import field from spanner_orm.admin import schema @@ -24,12 +24,13 @@ class IndexSchema(schema.InformationSchema): """Model for interacting with Spanner index schema table.""" __table__ = 'information_schema.indexes' - table_catalog = field.Field(field.String, primary_key=True) - table_schema = field.Field(field.String, primary_key=True) - table_name = field.Field(field.String, primary_key=True) - index_name = field.Field(field.String, primary_key=True) - index_type = field.Field(field.String) - parent_table_name = field.Field(field.String, nullable=True) - is_unique = field.Field(field.Boolean) - is_null_filtered = field.Field(field.Boolean) - index_state = field.Field(field.String) + table_catalog = typing.cast(str, field.Field(field.String, primary_key=True)) + table_schema = typing.cast(str, field.Field(field.String, primary_key=True)) + table_name = typing.cast(str, field.Field(field.String, primary_key=True)) + index_name = typing.cast(str, field.Field(field.String, primary_key=True)) + index_type = typing.cast(str, field.Field(field.String)) + parent_table_name = typing.cast(Optional[str], + field.Field(field.String, nullable=True)) + is_unique = typing.cast(bool, field.Field(field.Boolean)) + is_null_filtered = typing.cast(bool, field.Field(field.Boolean)) + index_state = typing.cast(str, field.Field(field.String)) diff --git a/spanner_orm/admin/index_column.py b/spanner_orm/admin/index_column.py index b2fc736..5b06937 100644 --- a/spanner_orm/admin/index_column.py +++ b/spanner_orm/admin/index_column.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,7 +13,8 @@ # limitations under the License. """Model for interacting with Spanner index column schema table.""" -from __future__ import annotations +import typing +from typing import Optional from spanner_orm import field from spanner_orm.admin import schema @@ -24,12 +24,14 @@ class IndexColumnSchema(schema.InformationSchema): """Model for interacting with Spanner index column schema table.""" __table__ = 'information_schema.index_columns' - table_catalog = field.Field(field.String, primary_key=True) - table_schema = field.Field(field.String, primary_key=True) - table_name = field.Field(field.String, primary_key=True) - index_name = field.Field(field.String, primary_key=True) - column_name = field.Field(field.String, primary_key=True) - ordinal_position = field.Field(field.Integer, nullable=True) - column_ordering = field.Field(field.String, nullable=True) - is_nullable = field.Field(field.String) - spanner_type = field.Field(field.String) + table_catalog = typing.cast(str, field.Field(field.String, primary_key=True)) + table_schema = typing.cast(str, field.Field(field.String, primary_key=True)) + table_name = typing.cast(str, field.Field(field.String, primary_key=True)) + index_name = typing.cast(str, field.Field(field.String, primary_key=True)) + column_name = typing.cast(str, field.Field(field.String, primary_key=True)) + ordinal_position = typing.cast(Optional[int], + field.Field(field.Integer, nullable=True)) + column_ordering = typing.cast(Optional[str], + field.Field(field.String, nullable=True)) + is_nullable = typing.cast(str, field.Field(field.String)) + spanner_type = typing.cast(str, field.Field(field.String)) diff --git a/spanner_orm/admin/metadata.py b/spanner_orm/admin/metadata.py index 04bd89f..5777634 100644 --- a/spanner_orm/admin/metadata.py +++ b/spanner_orm/admin/metadata.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,8 +13,6 @@ # limitations under the License. """Retrieves table metadata from Spanner.""" -from __future__ import annotations - import collections from typing import Any, Dict, Optional, Type @@ -30,7 +27,7 @@ from spanner_orm.admin import table -class SpannerMetadata(object): +class SpannerMetadata: """Gathers information about a table from Spanner.""" @classmethod @@ -73,9 +70,10 @@ def model(cls, table_name) -> Optional[Type[model.Model]]: def tables(cls) -> Dict[str, Dict[str, Any]]: """Compiles table information from column schema.""" column_data = collections.defaultdict(dict) - columns = column.ColumnSchema.where(None, - condition.equal_to('table_catalog', ''), - condition.equal_to('table_schema', '')) + columns = column.ColumnSchema.where( + condition.equal_to('table_catalog', ''), + condition.equal_to('table_schema', ''), + ) for column_row in columns: new_field = field.Field( column_row.field_type(), nullable=column_row.nullable()) @@ -84,9 +82,10 @@ def tables(cls) -> Dict[str, Dict[str, Any]]: column_data[column_row.table_name][column_row.column_name] = new_field table_data = collections.defaultdict(dict) - tables = table.TableSchema.where(None, - condition.equal_to('table_catalog', ''), - condition.equal_to('table_schema', '')) + tables = table.TableSchema.where( + condition.equal_to('table_catalog', ''), + condition.equal_to('table_schema', ''), + ) for table_row in tables: name = table_row.table_name table_data[name]['parent_table'] = table_row.parent_table_name @@ -100,9 +99,10 @@ def indexes(cls) -> Dict[str, Dict[str, Any]]: # Results are ordered by that so the index columns are added in the # correct order. index_column_schemas = index_column.IndexColumnSchema.where( - None, condition.equal_to('table_catalog', ''), + condition.equal_to('table_catalog', ''), condition.equal_to('table_schema', ''), - condition.order_by(('ordinal_position', condition.OrderType.ASC))) + condition.order_by(('ordinal_position', condition.OrderType.ASC)), + ) index_columns = collections.defaultdict(list) storing_columns = collections.defaultdict(list) @@ -114,8 +114,9 @@ def indexes(cls) -> Dict[str, Dict[str, Any]]: storing_columns[key].append(schema.column_name) index_schemas = index_schema.IndexSchema.where( - None, condition.equal_to('table_catalog', ''), - condition.equal_to('table_schema', '')) + condition.equal_to('table_catalog', ''), + condition.equal_to('table_schema', ''), + ) indexes = collections.defaultdict(dict) for schema in index_schemas: key = (schema.table_name, schema.index_name) diff --git a/spanner_orm/admin/migration.py b/spanner_orm/admin/migration.py index e189304..7f63b06 100644 --- a/spanner_orm/admin/migration.py +++ b/spanner_orm/admin/migration.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,25 +13,25 @@ # limitations under the License. """Holds information about a specific migration.""" -from __future__ import annotations - from typing import Callable, Optional from spanner_orm.admin import update -def no_update_callable() -> update.SchemaUpdate: +def no_update_callable() -> update.MigrationUpdate: return update.NoUpdate() class Migration: """Holds information about a specific migration.""" - def __init__(self, - migration_id: str, - prev_migration_id: Optional[str], - upgrade: Optional[Callable[[], update.SchemaUpdate]] = None, - downgrade: Optional[Callable[[], update.SchemaUpdate]] = None): + def __init__( + self, + migration_id: str, + prev_migration_id: Optional[str], + upgrade: Optional[Callable[[], update.MigrationUpdate]] = None, + downgrade: Optional[Callable[[], update.MigrationUpdate]] = None, + ): self._id = migration_id self._prev = prev_migration_id self._upgrade = upgrade or no_update_callable @@ -47,9 +46,9 @@ def prev_migration_id(self) -> Optional[str]: return self._prev @property - def upgrade(self) -> Optional[Callable[[], update.SchemaUpdate]]: + def upgrade(self) -> Optional[Callable[[], update.MigrationUpdate]]: return self._upgrade @property - def downgrade(self) -> Optional[Callable[[], update.SchemaUpdate]]: + def downgrade(self) -> Optional[Callable[[], update.MigrationUpdate]]: return self._downgrade diff --git a/spanner_orm/admin/migration.skel b/spanner_orm/admin/migration.skel index de43076..049022c 100644 --- a/spanner_orm/admin/migration.skel +++ b/spanner_orm/admin/migration.skel @@ -1,4 +1,4 @@ -"""$migration_name +"""Spanner ORM migration: $migration_name. Migration ID: $migration_id Created: $current_date @@ -9,10 +9,12 @@ import spanner_orm migration_id = $migration_id prev_migration_id = $prev_migration_id -# Returns a SchemaUpdate object that tells what should be changed -def upgrade() -> spanner_orm.NoUpdate: + +def upgrade() -> spanner_orm.MigrationUpdate: + """See spanner_orm migrations interface.""" return spanner_orm.NoUpdate() -# Returns a SchemaUpdate object that tells how to roll back the changes -def downgrade() -> spanner_orm.NoUpdate: + +def downgrade() -> spanner_orm.MigrationUpdate: + """See spanner_orm migrations interface.""" return spanner_orm.NoUpdate() diff --git a/spanner_orm/admin/migration_executor.py b/spanner_orm/admin/migration_executor.py index 774fc16..da5d51a 100644 --- a/spanner_orm/admin/migration_executor.py +++ b/spanner_orm/admin/migration_executor.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,8 +13,6 @@ # limitations under the License. """Handles execution of migrations.""" -from __future__ import annotations - import datetime import logging from typing import Iterable, List, Dict, Optional @@ -68,12 +65,12 @@ def migrate(self, target_migration: Optional[str] = None) -> None: target_migration) for migration_ in migrations: _logger.info('Processing migration %s', migration_.migration_id) - schema_update = migration_.upgrade() - if not isinstance(schema_update, update.SchemaUpdate): + migration_update = migration_.upgrade() + if not isinstance(migration_update, update.MigrationUpdate): raise error.SpannerError( - 'Migration {} did not return a SchemaUpdate'.format( + 'Migration {} did not return a MigrationUpdate'.format( migration_.migration_id)) - schema_update.execute() + migration_update.execute() self._update_status(migration_.migration_id, True) self._hangup() @@ -99,12 +96,12 @@ def rollback(self, target_migration: str) -> None: reversed(self.migrations()), True, target_migration) for migration_ in migrations: _logger.info('Processing migration %s', migration_.migration_id) - schema_update = migration_.downgrade() - if not isinstance(schema_update, update.SchemaUpdate): + migration_update = migration_.downgrade() + if not isinstance(migration_update, update.MigrationUpdate): raise error.SpannerError( - 'Migration {} did not return a SchemaUpdate'.format( + 'Migration {} did not return a MigrationUpdate'.format( migration_.migration_id)) - schema_update.execute() + migration_update.execute() self._update_status(migration_.migration_id, False) self._hangup() @@ -167,8 +164,7 @@ def _update_status(self, migration_id: str, new_status: bool) -> None: 'migrated': new_status, 'update_time': datetime.datetime.utcnow(), }) - migration_status.MigrationStatus.save_batch( - None, [new_model], force_write=True) + migration_status.MigrationStatus.save_batch([new_model], force_write=True) self._migration_status()[migration_id] = new_status def _validate_migrations(self) -> None: diff --git a/spanner_orm/admin/migration_manager.py b/spanner_orm/admin/migration_manager.py index 046d6d9..3d883bd 100644 --- a/spanner_orm/admin/migration_manager.py +++ b/spanner_orm/admin/migration_manager.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,8 +13,6 @@ # limitations under the License. """Handles reading and writing of migration files.""" -from __future__ import annotations - import datetime import importlib import os @@ -43,7 +40,8 @@ def generate(self, migration_name: str) -> str: """Creates a new migration that is the last migration to be executed.""" migration_id = uuid.uuid4().hex[-12:] prev_id = self.migrations[-1].migration_id if self.migrations else None - now = datetime.datetime.now().strftime('%Y-%m-%d %H:%M') + now = datetime.datetime.now().astimezone().isoformat( + sep=' ', timespec='seconds') skeleton_directory = os.path.dirname(os.path.abspath(__file__)) skeleton_file = os.path.join(skeleton_directory, 'migration.skel') @@ -74,9 +72,9 @@ def _migration_from_file(self, filename: str) -> migration.Migration: """Loads a single migration from the given filename in the base dir.""" module_name = re.sub(r'\W', '_', filename) path = os.path.join(self.basedir, filename) - spec = importlib.util.spec_from_file_location(module_name, path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) + module = importlib.util.module_from_spec( + importlib.util.spec_from_file_location(module_name, path)) + importlib.machinery.SourceFileLoader(module_name, path).exec_module(module) try: result = migration.Migration(module.migration_id, module.prev_migration_id, @@ -91,7 +89,7 @@ def _all_migrations(self) -> List[migration.Migration]: migrations = [] for filename in os.listdir(self.basedir): _, ext = os.path.splitext(filename) - if ext == '.py': + if ext == '.py' and filename != '__init__.py': migrations.append(self._migration_from_file(filename)) return migrations diff --git a/spanner_orm/admin/migration_status.py b/spanner_orm/admin/migration_status.py index e580192..f128739 100644 --- a/spanner_orm/admin/migration_status.py +++ b/spanner_orm/admin/migration_status.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,7 +13,8 @@ # limitations under the License. """Indicates whether a migration has been applied to the current database.""" -from __future__ import annotations +import datetime +import typing from spanner_orm import field from spanner_orm import model @@ -28,6 +28,6 @@ def spanner_api(cls) -> api.SpannerAdminApi: return api.spanner_admin_api() __table__ = 'spanner_orm_migrations' - id = field.Field(field.String, primary_key=True) - migrated = field.Field(field.Boolean) - update_time = field.Field(field.Timestamp) + id = typing.cast(str, field.Field(field.String, primary_key=True)) + migrated = typing.cast(bool, field.Field(field.Boolean)) + update_time = typing.cast(datetime.datetime, field.Field(field.Timestamp)) diff --git a/spanner_orm/admin/schema.py b/spanner_orm/admin/schema.py index 9887d55..5a5a1d8 100644 --- a/spanner_orm/admin/schema.py +++ b/spanner_orm/admin/schema.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,8 +13,6 @@ # limitations under the License. """Base model for schemas.""" -from __future__ import annotations - from typing import NoReturn from spanner_orm import error diff --git a/spanner_orm/admin/scripts.py b/spanner_orm/admin/scripts.py index 4dc4557..2b9857f 100644 --- a/spanner_orm/admin/scripts.py +++ b/spanner_orm/admin/scripts.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,8 +13,6 @@ # limitations under the License. """Entry point for spanner_orm scripts.""" -from __future__ import annotations - import argparse from typing import Any @@ -44,11 +41,10 @@ def rollback(args: Any) -> None: def main(as_module: bool = False) -> None: prog = 'spanner-orm' if as_module else None parser = argparse.ArgumentParser(prog=prog) + # 'subcommand' is actually required, but required subparsers are not supported + # for Python < 3.7. subparsers = parser.add_subparsers( - dest='subcommand', - title='subcommands', - description='valid subcommands', - required=True) + dest='subcommand', title='subcommands', description='valid subcommands') generate_parser = subparsers.add_parser( 'generate', help='Generate a new migration') @@ -75,7 +71,10 @@ def main(as_module: bool = False) -> None: rollback_parser.set_defaults(execute=rollback) args = parser.parse_args() - args.execute(args) + if args.subcommand is None: + parser.print_help() + else: + args.execute(args) if __name__ == '__main__': diff --git a/spanner_orm/admin/table.py b/spanner_orm/admin/table.py index 0c5838c..ac98e40 100644 --- a/spanner_orm/admin/table.py +++ b/spanner_orm/admin/table.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,7 +13,8 @@ # limitations under the License. """Model for interacting with Spanner column schema table.""" -from __future__ import annotations +import typing +from typing import Optional from spanner_orm import field from spanner_orm.admin import schema @@ -24,8 +24,10 @@ class TableSchema(schema.InformationSchema): """Model for interacting with Spanner column schema table.""" __table__ = 'information_schema.tables' - table_catalog = field.Field(field.String, primary_key=True) - table_schema = field.Field(field.String, primary_key=True) - table_name = field.Field(field.String, primary_key=True) - parent_table_name = field.Field(field.String, nullable=True) - on_delete_action = field.Field(field.String, nullable=True) + table_catalog = typing.cast(str, field.Field(field.String, primary_key=True)) + table_schema = typing.cast(str, field.Field(field.String, primary_key=True)) + table_name = typing.cast(str, field.Field(field.String, primary_key=True)) + parent_table_name = typing.cast(Optional[str], + field.Field(field.String, nullable=True)) + on_delete_action = typing.cast(Optional[str], + field.Field(field.String, nullable=True)) diff --git a/spanner_orm/admin/update.py b/spanner_orm/admin/update.py index e2a73e8..f716e08 100644 --- a/spanner_orm/admin/update.py +++ b/spanner_orm/admin/update.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,21 +13,37 @@ # limitations under the License. """Used with SpannerAdminApi to manage Spanner schema updates.""" -from __future__ import annotations - import abc from typing import Iterable, List, Optional, Type from spanner_orm import condition from spanner_orm import error from spanner_orm import field +from spanner_orm import foreign_key_relationship +from spanner_orm import index from spanner_orm import model from spanner_orm.admin import api from spanner_orm.admin import index_column from spanner_orm.admin import metadata -class SchemaUpdate(abc.ABC): +class MigrationUpdate(abc.ABC): + """Base class for all updates that can happen in a migration.""" + + @abc.abstractmethod + def execute(self) -> None: + """Executes the update.""" + raise NotImplementedError + + +class NoUpdate(MigrationUpdate): + """Update that does nothing, for migrations that don't update db schemas.""" + + def execute(self) -> None: + """See base class.""" + + +class SchemaUpdate(MigrationUpdate, abc.ABC): """Base class for specifying schema updates.""" @abc.abstractmethod @@ -39,9 +54,8 @@ def execute(self) -> None: self.validate() api.spanner_admin_api().update_schema(self.ddl()) - @abc.abstractmethod def validate(self) -> None: - raise NotImplementedError + pass # TODO(dseomn): Remove this method. class CreateTable(SchemaUpdate): @@ -51,13 +65,16 @@ def __init__(self, model_: Type[model.Model]): self._model = model_ def ddl(self) -> str: - fields = [ + key_fields = [ '{} {}'.format(name, field.ddl()) for name, field in self._model.fields.items() ] + key_fields_ddl = ', '.join(key_fields) + for relation in self._model.foreign_key_relations.values(): + key_fields_ddl += f', {relation.ddl}' index_ddl = 'PRIMARY KEY ({})'.format(', '.join(self._model.primary_keys)) statement = 'CREATE TABLE {} ({}) {}'.format(self._model.table, - ', '.join(fields), index_ddl) + key_fields_ddl, index_ddl) if self._model.interleaved: statement += ', INTERLEAVE IN PARENT {parent} ON DELETE CASCADE'.format( @@ -78,6 +95,11 @@ def validate(self) -> None: self._validate_primary_keys() + if self._model.indexes.keys() - {index.Index.PRIMARY_INDEX}: + raise error.SpannerError( + 'Secondary indexes cannot be created by CreateTable; use CreateIndex ' + 'in a separate migration.') + def _validate_parent(self) -> None: """Verifies that the parent table information is valid.""" parent_primary_keys = self._model.interleaved.primary_keys @@ -113,29 +135,6 @@ def __init__(self, table_name: str): def ddl(self) -> str: return 'DROP TABLE {}'.format(self._table) - def validate(self) -> None: - existing_model = metadata.SpannerMetadata.model(self._table) - if not existing_model: - raise error.SpannerError('Table {} does not exist'.format(self._table)) - - # Model indexes include the primary index - if len(existing_model.indexes) > 1: - raise error.SpannerError('Table {} has a secondary index'.format( - self._table)) - - self._validate_not_interleaved(existing_model) - - def _validate_not_interleaved(self, - existing_model: Type[model.Model]) -> None: - for model_ in metadata.SpannerMetadata.models().values(): - if model_.interleaved == existing_model: - raise error.SpannerError('Table {} has interleaved table {}'.format( - self._table, model_.table)) - for index in model_.indexes.values(): - if index.parent == self._table: - raise error.SpannerError('Table {} has interleaved index {}'.format( - self._table, index.name)) - class AddColumn(SchemaUpdate): """Update for adding a column to an existing table. @@ -184,8 +183,9 @@ def validate(self) -> None: # Verify no indices exist on the column we're trying to drop num_indexed_columns = index_column.IndexColumnSchema.count( - None, condition.equal_to('column_name', self._column), - condition.equal_to('table_name', self._table)) + condition.equal_to('column_name', self._column), + condition.equal_to('table_name', self._table), + ) if num_indexed_columns > 0: raise error.SpannerError('Column {} is indexed'.format(self._column)) @@ -220,7 +220,7 @@ def validate(self) -> None: old_field = model_.fields[self._column] # Validate that the only alteration is to change column nullability - if self._field.field_type() != old_field.field_type(): + if self._field.field_type().ddl() != old_field.field_type().ddl(): raise error.SpannerError('Column {} is changing type'.format( self._column)) if self._field.nullable() == old_field.nullable(): @@ -235,16 +235,25 @@ def __init__(self, index_name: str, columns: Iterable[str], interleaved: Optional[str] = None, + null_filtered: bool = False, + unique: bool = False, storing_columns: Optional[Iterable[str]] = None): self._table = table_name self._index = index_name self._columns = columns self._parent_table = interleaved + self._null_filtered = null_filtered + self._unique = unique self._storing_columns = storing_columns or [] def ddl(self) -> str: - statement = 'CREATE INDEX {} ON {} ({})'.format(self._index, self._table, - ', '.join(self._columns)) + statement = 'CREATE' + if self._unique: + statement += ' UNIQUE' + if self._null_filtered: + statement += ' NULL_FILTERED' + statement += (f' INDEX {self._index} ' + f'ON {self._table} ({", ".join(self._columns)})') if self._storing_columns: statement += 'STORING ({})'.format(', '.join(self._storing_columns)) if self._parent_table: @@ -318,17 +327,20 @@ def validate(self) -> None: self._index)) -class NoUpdate(SchemaUpdate): - """Update that does nothing, for migrations that don't update db schemas.""" +class ExecutePartitionedDml(MigrationUpdate): + """Update for running arbitrary partitioned DML. - def ddl(self) -> str: - return '' + NOTE: Partitioned DML queries should be idempotent. See + https://cloud.google.com/spanner/docs/dml-partitioned for details, and more + information about partitioned DML. + """ - def execute(self) -> None: - pass + def __init__(self, dml: str): + self._dml = dml - def validate(self) -> None: - pass + def execute(self) -> None: + """See base class.""" + api.spanner_admin_api().execute_partitioned_dml(self._dml) def model_creation_ddl(model_: Type[model.Model]) -> List[str]: diff --git a/spanner_orm/api.py b/spanner_orm/api.py index 3254dc0..d049617 100644 --- a/spanner_orm/api.py +++ b/spanner_orm/api.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,21 +13,37 @@ # limitations under the License. """Class that handles API calls to Spanner.""" -from __future__ import annotations - import abc -from typing import Any, Callable, Iterable, Optional, TypeVar - -from spanner_orm import error +from typing import Any, Callable, Dict, Iterable, Optional, TypeVar, Union +import warnings +from google.api_core import client_options as api_client_options +from google.api_core import exceptions from google.auth import credentials as auth_credentials from google.cloud import spanner from google.cloud.spanner_v1 import database as spanner_database +from google.cloud.spanner_v1 import pool as spanner_pool +from spanner_orm import error CallableReturn = TypeVar('CallableReturn') -class SpannerReadApi(abc.ABC): +class SpannerRetryableApi(abc.ABC): + + def _ensure_session(self, api_method, *args, **kwargs): + try: + return api_method(*args, **kwargs) + except exceptions.NotFound as e: + # https://cloud.google.com/spanner/docs/sessions#handle_deleted_sessions + # states that Spanner may delete existing sessions for various reasons. + if not 'Session not found' in e.message: + raise + + spanner_api().spanner_connection.connect() + return api_method(*args, **kwargs) + + +class SpannerReadApi(SpannerRetryableApi): """Handles sending read requests to Spanner.""" @property @@ -51,16 +66,19 @@ def run_read_only(self, method: Callable[..., CallableReturn], *args: Any, Returns: The return value from `method` will be returned from this method """ + return self._ensure_session(self._run_read_only, method, *args, **kwargs) + + def _run_read_only(self, method, *args, **kwargs): with self._connection.snapshot(multi_use=True) as snapshot: return method(snapshot, *args, **kwargs) -class SpannerWriteApi(abc.ABC): +class SpannerWriteApi(SpannerRetryableApi): """Handles sending write requests to Spanner.""" @property @abc.abstractmethod - def _connection(self) -> spanner_database.SpannerDatabase: + def _connection(self) -> spanner_database.Database: raise NotImplementedError def run_write(self, method: Callable[..., CallableReturn], *args: Any, @@ -81,24 +99,45 @@ def run_write(self, method: Callable[..., CallableReturn], *args: Any, Returns: The return value from `method` will be returned from this method """ - return self._connection.run_in_transaction(method, *args, **kwargs) + return self._ensure_session(self._connection.run_in_transaction, method, + *args, **kwargs) class SpannerConnection: """Class that handles connecting to a Spanner database.""" - def __init__(self, - instance: str, - database: str, - project: Optional[str] = None, - credentials: Optional[auth_credentials.Credentials] = None, - pool: Optional[spanner.Pool] = None, - create_ddl: Optional[Iterable[str]] = None): + def __init__( + self, + instance: str, + database: str, + project: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + pool: Optional[spanner_pool.AbstractSessionPool] = None, + create_ddl: Optional[Iterable[str]] = None, + *, + client_options: Union[api_client_options.ClientOptions, Dict[Any, Any], + None] = None, + ): """Connects to the specified Spanner database.""" - client = spanner.Client(project=project, credentials=credentials) - instance = client.instance(instance) + self._instance = instance + self._database = database + self._project = project + self._credentials = credentials + self._pool = pool + self._create_ddl = create_ddl + self._client_options = client_options + self.connect() + + def connect(self): + """Establish a new connection to the specified Spanner database.""" + client = spanner.Client( + project=self._project, + credentials=self._credentials, + client_options=self._client_options, + ) + instance = client.instance(self._instance) self.database = instance.database( - database, pool=pool, ddl_statements=create_ddl or ()) + self._database, pool=self._pool, ddl_statements=self._create_ddl or ()) class SpannerApi(SpannerReadApi, SpannerWriteApi): @@ -107,6 +146,11 @@ class SpannerApi(SpannerReadApi, SpannerWriteApi): def __init__(self, connection: SpannerConnection): self._spanner_connection = connection + @property + def spanner_connection(self) -> SpannerConnection: + """Connection to the database.""" + return self._spanner_connection + @property def _connection(self): return self._spanner_connection.database @@ -115,12 +159,20 @@ def _connection(self): _api = None # type: Optional[SpannerApi] -def connect(instance: str, - database: str, - project: Optional[str] = None, - credentials: Optional[auth_credentials.Credentials] = None, - pool: Optional[spanner.Pool] = None) -> SpannerApi: - """Connects to the Spanner database and sets the global spanner_api.""" +def connect( + instance: str, + database: str, + project: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + pool: Optional[spanner_pool.AbstractSessionPool] = None) -> SpannerApi: + """Connects to the Spanner database and sets the global spanner_api. + + Deprecated in favor of from_connection(). + """ + warnings.warn( + DeprecationWarning( + 'Please use ' + 'spanner_orm.from_connection(spanner_orm.SpannerConnection(...))')) connection = SpannerConnection( instance, database, project=project, credentials=credentials, pool=pool) return from_connection(connection) diff --git a/spanner_orm/condition.py b/spanner_orm/condition.py index 2c9e43b..0ddbbaa 100644 --- a/spanner_orm/condition.py +++ b/spanner_orm/condition.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,18 +13,27 @@ # limitations under the License. """Used with Model#where and Model#count to help create Spanner queries.""" -from __future__ import annotations - import abc +import base64 +import dataclasses +import datetime +import decimal import enum -from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union +import string +from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union from spanner_orm import error from spanner_orm import field +from spanner_orm import foreign_key_relationship from spanner_orm import index from spanner_orm import relationship -from google.cloud.spanner_v1.proto import type_pb2 +from google.api_core import datetime_helpers +from google.cloud import spanner +from google.cloud import spanner_v1 +import immutabledict + +T = TypeVar('T') class Segment(enum.Enum): @@ -97,7 +105,7 @@ def sql(self) -> str: def _sql(self) -> str: pass - def types(self) -> Dict[str, type_pb2.Type]: + def types(self) -> Dict[str, spanner_v1.Type]: """Returns parameter types to be used in the SQL query. Returns: @@ -109,7 +117,7 @@ def types(self) -> Dict[str, type_pb2.Type]: return self._types() @abc.abstractmethod - def _types(self) -> Dict[str, type_pb2.Type]: + def _types(self) -> Dict[str, spanner_v1.Type]: raise NotImplementedError @abc.abstractmethod @@ -117,6 +125,199 @@ def _validate(self, model_class: Type[Any]) -> None: raise NotImplementedError +GuessableParamType = Union[ + bool, # + int, # + float, # + datetime_helpers.DatetimeWithNanoseconds, # + datetime.datetime, # + datetime.date, # + bytes, # + str, # + decimal.Decimal, # + # These types technically include List[None] and Tuple[None, ...], but + # those can't be guessed. + List[Optional[bool]], # + List[Optional[int]], # + List[Optional[float]], # + List[Optional[datetime_helpers.DatetimeWithNanoseconds]], # + List[Optional[datetime.datetime]], # + List[Optional[datetime.date]], # + List[Optional[bytes]], # + List[Optional[str]], # + List[Optional[decimal.Decimal]], # + Tuple[Optional[bool], ...], # + Tuple[Optional[int], ...], # + Tuple[Optional[float], ...], # + Tuple[Optional[datetime_helpers.DatetimeWithNanoseconds], ...], # + Tuple[Optional[datetime.datetime], ...], # + Tuple[Optional[datetime.date], ...], # + Tuple[Optional[bytes], ...], # + Tuple[Optional[str], ...], # + Tuple[Optional[decimal.Decimal], ...], # +] + + +def _spanner_type_of_python_object( + value: GuessableParamType) -> spanner_v1.Type: + """Returns the Cloud Spanner type of the given object. + + Args: + value: Object to guess the type of. + """ + # See + # https://github.com/googleapis/python-spanner/blob/master/google/cloud/spanner_v1/proto/type.proto + # for the Cloud Spanner types, and + # https://github.com/googleapis/python-spanner/blob/e981adb3157bb06e4cb466ca81d74d85da976754/google/cloud/spanner_v1/_helpers.py#L91-L133 + # for Python types. + if value is None: + raise TypeError( + 'Cannot infer type of None, because any SQL type can be NULL.') + simple_type_code = { + bool: spanner_v1.TypeCode.BOOL, + int: spanner_v1.TypeCode.INT64, + float: spanner_v1.TypeCode.FLOAT64, + datetime_helpers.DatetimeWithNanoseconds: spanner_v1.TypeCode.TIMESTAMP, + datetime.datetime: spanner_v1.TypeCode.TIMESTAMP, + datetime.date: spanner_v1.TypeCode.DATE, + bytes: spanner_v1.TypeCode.BYTES, + str: spanner_v1.TypeCode.STRING, + decimal.Decimal: spanner_v1.TypeCode.NUMERIC, + }.get(type(value)) + if simple_type_code is not None: + return spanner_v1.Type(code=simple_type_code) + elif isinstance(value, (list, tuple)): + element_types = tuple( + _spanner_type_of_python_object(item) + for item in value + if item is not None) + if element_types and all( + a == b for a, b in zip(element_types, element_types[1:])): + return spanner_v1.Type( + code=spanner_v1.TypeCode.ARRAY, + array_element_type=element_types[0], + ) + else: + raise TypeError( + f'Array does not have elements of exactly one type: {value!r}') + else: + raise TypeError('Unknown type: {value!r}') + + +@dataclasses.dataclass +class Param: + """Parameter for substitution into a SQL query.""" + value: Any + type: spanner_v1.Type + + @classmethod + def from_value(cls: Type[T], value: GuessableParamType) -> T: + """Returns a Param with the type guessed from a Python value.""" + guessed_type = _spanner_type_of_python_object(value) + + # BYTES must be base64-encoded, see + # https://github.com/googleapis/python-spanner/blob/87789c939990794bfd91f5300bedc449fd74bd7e/google/cloud/spanner_v1/proto/type.proto#L108-L110 + if (isinstance(value, bytes) and guessed_type == spanner.param_types.BYTES): + encoded_value = base64.b64encode(value).decode() + elif (isinstance(value, (list, tuple)) and + all(isinstance(x, bytes) for x in value if x is not None) and + guessed_type == spanner_v1.Type( + code=spanner_v1.TypeCode.ARRAY, + array_element_type=spanner.param_types.BYTES, + )): + encoded_value = tuple( + None if item is None else base64.b64encode(item).decode() + for item in value) + else: + encoded_value = value + + return cls(value=encoded_value, type=guessed_type) + + +@dataclasses.dataclass +class Column: + """Named column; consider using field.Field instead.""" + name: str + + +# Something that can be substituted into a SQL query. +Substitution = Union[Param, field.Field, Column] + + +class ArbitraryCondition(Condition): + """Condition with support for arbitrary SQL.""" + + def __init__( + self, + sql_template: str, + substitutions: Mapping[str, Substitution] = immutabledict.immutabledict(), + *, + segment: Segment, + ): + """Initializer. + + Args: + sql_template: string.Template-compatible template string for the SQL. + substitutions: Substitutions to make in sql_template. + segment: Segment for this Condition. + """ + super().__init__() + self._sql_template = string.Template(sql_template) + self._substitutions = substitutions + self._segment = segment + + # This validates the template. + self._sql_template.substitute({k: '' for k in self._substitutions}) + + def segment(self) -> Segment: + """See base class.""" + return self._segment + + def _validate(self, model_class: Type[Any]) -> None: + """See base class.""" + for substitution in self._substitutions.values(): + if isinstance(substitution, field.Field): + if substitution not in model_class.fields.values(): + raise error.ValidationError( + f'Field {substitution.name!r} does not belong to the Model for ' + f'table {model_class.table!r}.') + elif isinstance(substitution, Column): + if substitution.name not in model_class.fields: + raise error.ValidationError( + f'Column {substitution.name!r} does not exist in the Model for ' + f'table {model_class.table!r}.') + + def _params(self) -> Dict[str, Any]: + """See base class.""" + return { + self.key(k): v.value + for k, v in self._substitutions.items() + if isinstance(v, Param) + } + + def _types(self) -> Dict[str, spanner_v1.Type]: + """See base class.""" + return { + self.key(k): v.type + for k, v in self._substitutions.items() + if isinstance(v, Param) + } + + def _sql_for_substitution(self, key: str, substitution: Substitution) -> str: + if isinstance(substitution, Param): + return f'@{self.key(key)}' + else: + assert isinstance(substitution, (field.Field, Column)) + return f'{self.model_class.column_prefix}.{substitution.name}' + + def _sql(self) -> str: + """See base class.""" + return self._sql_template.substitute({ + k: self._sql_for_substitution(k, v) + for k, v in self._substitutions.items() + }) + + class ColumnsEqualCondition(Condition): """Used to join records by matching column values.""" @@ -140,7 +341,7 @@ def _sql(self) -> str: other_table=self.destination_model_class.table, other_column=self.destination_column) - def _types(self) -> Dict[str, type_pb2.Type]: + def _types(self) -> Dict[str, spanner_v1.Type]: return {} def _validate(self, model_class: Type[Any]) -> None: @@ -153,16 +354,16 @@ def _validate(self, model_class: Type[Any]) -> None: self.destination_column, self.destination_model_class.table)) dest = self.destination_model_class.fields[self.destination_column] - if (origin.field_type() != dest.field_type() or + if (not origin.field_type().comparable_with(dest.field_type()) or origin.nullable() != dest.nullable()): raise error.ValidationError('Types of {} and {} do not match'.format( origin.name, dest.name)) -class ForceIndexCondition(Condition): - """Used to indicate which index should be used in a Spanner query.""" +class _IndexCondition(Condition): + """Base class for conditions based on an Index.""" - def __init__(self, index_or_name: Union[Type[index.Index], str]): + def __init__(self, index_or_name: Union[index.Index, str]): super().__init__() if isinstance(index_or_name, index.Index): self.name = index_or_name.name @@ -175,6 +376,27 @@ def bind(self, model_class: Type[Any]) -> None: super().bind(model_class) self.index = self.model_class.indexes[self.name] + def _validate(self, model_class: Type[Any]) -> None: + if self.name not in model_class.indexes: + raise error.ValidationError('{} is not an index on {}'.format( + self.name, model_class.table)) + if self.index and self.index != model_class.indexes[self.name]: + raise error.ValidationError('{} does not belong to {}'.format( + self.index.name, model_class.table)) + + +class ForceIndexCondition(_IndexCondition): + """Used to indicate which index should be used in a Spanner query.""" + + def __init__( + self, + index_or_name: Union[index.Index, str], + *, + extra_hints: Sequence[str] = (), + ): + super().__init__(index_or_name) + self._extra_hints = extra_hints + def _params(self) -> Dict[str, Any]: return {} @@ -182,31 +404,71 @@ def segment(self) -> Segment: return Segment.FROM def _sql(self) -> str: - return '@{{FORCE_INDEX={}}}'.format(self.name) + hints = (f'FORCE_INDEX={self.name}', *self._extra_hints) + return f'@{{{",".join(hints)}}}' - def _types(self) -> Dict[str, type_pb2.Type]: + def _types(self) -> Dict[str, spanner_v1.Type]: return {} def _validate(self, model_class: Type[Any]) -> None: - if self.name not in model_class.indexes: - raise error.ValidationError('{} is not an index on {}'.format( - self.name, model_class.table)) - if self.index and self.index != model_class.indexes[self.name]: - raise error.ValidationError('{} does not belong to {}'.format( - self.index.name, model_class.table)) - + super()._validate(model_class) if model_class.indexes[self.name].primary: raise error.ValidationError('Cannot force query using primary index') +class _IndexIgnoreNullsCondition(_IndexCondition): + """Condition to filter NULL values in any column of an index.""" + + def _params(self) -> Dict[str, Any]: + return {} + + def segment(self) -> Segment: + return Segment.WHERE + + def _sql(self) -> str: + return '({})'.format(' AND '.join( + f'{column} IS NOT NULL' for column in self.index.columns)) + + def _types(self) -> Dict[str, spanner_v1.Type]: + return {} + + class IncludesCondition(Condition): """Used to include related model_classs via a relation in a Spanner query.""" - def __init__(self, - relation_or_name: Union[relationship.Relationship, str], - conditions: List[Condition] = None): + def __init__( + self, + relation_or_name: Union[relationship.Relationship, + foreign_key_relationship.ForeignKeyRelationship, + str], + conditions: Optional[List[Condition]] = None, + # Default argument is `False` for backwards-compatability. + foreign_key_relation=False, + ): + """Initializer. + + + Args: + relation: Name of the relationship on the origin model or the Relationship/ + ForeignKeyRelationship on the origin model class used to retrieve + associated objects + conditions: Conditions to apply on the subquery + foreign_key_relation: True if the relation is a foreign key relation, + False if it is a legacy relation (eg not enforced in Spanner) + """ super().__init__() + self.foreign_key_relation = foreign_key_relation if isinstance(relation_or_name, relationship.Relationship): + if foreign_key_relation: + raise ValueError('Must pass foreign key relation if ' + '`foreign_key_relation=True`.') + self.name = relation_or_name.name + self.relation = relation_or_name + elif isinstance(relation_or_name, + foreign_key_relationship.ForeignKeyRelationship): + if not foreign_key_relation: + raise ValueError( + 'Must pass legacy relation if `foreign_key_relation=False`.') self.name = relation_or_name.name self.relation = relation_or_name else: @@ -216,30 +478,45 @@ def __init__(self, def bind(self, model_class: Type[Any]) -> None: super().bind(model_class) - self.relation = self.model_class.relations[self.name] + if self.foreign_key_relation: + self.relation = self.model_class.foreign_key_relations[self.name] + else: + self.relation = self.model_class.relations[self.name] @property def conditions(self) -> List[Condition]: """Generate the child conditions based on the relationship constraints.""" - if not self.relation: + relation_conditions = [] + if isinstance(self.relation, + foreign_key_relationship.ForeignKeyRelationship): + for pair in self.relation.constraint.columns.items(): + referencing_column, referenced_column = pair + relation_conditions.append( + ColumnsEqualCondition(referenced_column, self.model_class, + referencing_column)) + elif isinstance(self.relation, relationship.Relationship): + for constraint in self.relation.constraints: + # This is backward from what you might imagine because the condition + # will be processed from the context of the destination model. + relation_conditions.append( + ColumnsEqualCondition(constraint.destination_column, + constraint.origin_class, + constraint.origin_column)) + else: raise error.SpannerError( 'Condition must be bound before conditions is called') - relation_conditions = [] - for constraint in self.relation.constraints: - # This is backward from what you might imagine because the condition will - # be processed from the context of the destination model - relation_conditions.append( - ColumnsEqualCondition(constraint.destination_column, - constraint.origin_class, - constraint.origin_column)) return relation_conditions + self._conditions @property def destination(self) -> Type[Any]: - if not self.relation: + if isinstance(self.relation, + foreign_key_relationship.ForeignKeyRelationship): + return self.relation.constraint.referenced_table + elif isinstance(self.relation, relationship.Relationship): + return self.relation.destination + else: raise error.SpannerError( 'Condition must be bound before destination is called') - return self.relation.destination @property def relation_name(self) -> str: @@ -261,18 +538,25 @@ def segment(self) -> Segment: def _sql(self) -> str: return '' - def _types(self) -> Dict[str, type_pb2.Type]: + def _types(self) -> Dict[str, spanner_v1.Type]: return {} def _validate(self, model_class: Type[Any]) -> None: - if self.name not in model_class.relations: + if self.foreign_key_relation: + model_class_relations = model_class.foreign_key_relations + referenced_table_fn = lambda x: x.constraint.referenced_table + else: + model_class_relations = model_class.relations + referenced_table_fn = lambda x: x.destination + + if self.name not in model_class_relations: raise error.ValidationError('{} is not a relation on {}'.format( self.name, model_class.table)) - if self.relation and self.relation != model_class.relations[self.name]: + if self.relation and self.relation != model_class_relations[self.name]: raise error.ValidationError('{} does not belong to {}'.format( self.relation.name, model_class.table)) + other_model_class = referenced_table_fn(model_class_relations[self.name]) - other_model_class = model_class.relations[self.name].destination for condition in self._conditions: condition._validate(other_model_class) # pylint: disable=protected-access @@ -313,10 +597,10 @@ def _sql(self) -> str: limit_key=self._limit_key, offset_key=self._offset_key) return 'LIMIT @{limit_key}'.format(limit_key=self._limit_key) - def _types(self) -> Dict[str, type_pb2.Type]: - types = {self._limit_key: type_pb2.Type(code=type_pb2.INT64)} + def _types(self) -> Dict[str, spanner_v1.Type]: + types = {self._limit_key: spanner.param_types.INT64} if self.offset: - types[self._offset_key] = type_pb2.Type(code=type_pb2.INT64) + types[self._offset_key] = spanner.param_types.INT64 return types def _validate(self, model_class: Type[Any]) -> None: @@ -329,9 +613,6 @@ class OrCondition(Condition): def __init__(self, *condition_lists: List[Condition]): super().__init__() - if len(condition_lists) < 2: - raise error.SpannerError( - 'OrCondition requires at least two lists of conditions') self.condition_lists = condition_lists self.all_conditions = [] for conditions in condition_lists: @@ -358,14 +639,21 @@ def _sql(self) -> str: params += len(condition.params()) for conditions in self.condition_lists: - new_segment = ' AND '.join([condition.sql() for condition in conditions]) - segments.append('({new_segment})'.format(new_segment=new_segment)) - return '({segments})'.format(segments=' OR '.join(segments)) + if conditions: + new_segment = ' AND '.join( + [condition.sql() for condition in conditions]) + segments.append('({new_segment})'.format(new_segment=new_segment)) + else: + segments.append('TRUE') + if segments: + return '({segments})'.format(segments=' OR '.join(segments)) + else: + return 'FALSE' def segment(self) -> Segment: return Segment.WHERE - def _types(self) -> type_pb2.Type: + def _types(self) -> spanner_v1.Type: result = {} for condition in self.all_conditions: condition.suffix = str(int(self.suffix or 0) + len(result)) @@ -410,7 +698,7 @@ def _sql(self) -> str: def segment(self) -> Segment: return Segment.ORDER_BY - def _types(self) -> type_pb2.Type: + def _types(self) -> spanner_v1.Type: return {} def _validate(self, model_class: Type[Any]) -> None: @@ -455,7 +743,7 @@ def _sql(self) -> str: operator=self.operator, column_key=self._column_key) - def _types(self) -> type_pb2.Type: + def _types(self) -> spanner_v1.Type: return {self._column_key: self.model_class.fields[self.column].grpc_type()} def _validate(self, model_class: Type[Any]) -> None: @@ -481,14 +769,15 @@ def _sql(self) -> str: operator=self.operator, column_key=self._column_key) - def _types(self) -> type_pb2.Type: + def _types(self) -> spanner_v1.Type: grpc_type = self.model_class.fields[self.column].grpc_type() - list_type = type_pb2.Type(code=type_pb2.ARRAY, array_element_type=grpc_type) + list_type = spanner_v1.Type( + code=spanner_v1.TypeCode.ARRAY, array_element_type=grpc_type) return {self._column_key: list_type} def _validate(self, model_class: Type[Any]) -> None: - if not isinstance(self.value, list): - raise error.ValidationError('{} is not a list'.format(self.value)) + if not isinstance(self.value, Iterable): + raise error.ValidationError('{} is not iterable'.format(self.value)) if self.column not in model_class.fields: raise error.ValidationError('{} is not a column on {}'.format( self.column, model_class.table)) @@ -523,7 +812,7 @@ def _sql(self) -> str: operator=self.nullable_operator) return super()._sql() - def _types(self) -> type_pb2.Type: + def _types(self) -> spanner_v1.Type: if self.is_null(): return {} return super()._types() @@ -573,6 +862,35 @@ def columns_equal(origin_column: str, dest_model_class: Type[Any], return ColumnsEqualCondition(origin_column, dest_model_class, dest_column) +def contains( + haystack: Substitution, + needle: Substitution, + *, + case_sensitive: bool = True, +) -> Condition: + """Condition where the specified haystack contains the given needle. + + Args: + haystack: String or bytes to search. + needle: String or bytes to search for. Must be the same type as haystack. + case_sensitive: Whether comparison should be case sensitive or not. See + https://cloud.google.com/spanner/docs/functions-and-operators#lower for + caveats on how the case conversion works. + + Returns: + A Condition subclass that will be used in the query + """ + return ArbitraryCondition( + ('STRPOS($haystack, $needle) > 0' + if case_sensitive else 'STRPOS(LOWER($haystack), LOWER($needle)) > 0'), + dict( + haystack=haystack, + needle=needle, + ), + segment=Segment.WHERE, + ) + + def equal_to(column: Union[field.Field, str], value: Any) -> EqualityCondition: """Condition where the specified column is equal to the given value. @@ -600,6 +918,35 @@ def force_index(forced_index: Union[index.Index, str]) -> ForceIndexCondition: return ForceIndexCondition(forced_index) +def force_null_filtered_index( + forced_index: Union[index.Index, str]) -> Sequence[Condition]: + """Returns conditions to force the query to use the given NULL_FILTERED index. + + In Cloud Spanner, a query against a NULL_FILTERED index is tested to see if it + can use safely use that index. If using the index would result in incorrect + results (e.g., by ignoring NULL values that would be in the same query without + using the index), it's an error. However, the Cloud Spanner Emulator + doesn't support that check: + https://github.com/GoogleCloudPlatform/cloud-spanner-emulator/blob/e887ff5569684e6e45ce7c90d0fdfb7b1faa1491/common/errors.cc#L1790-L1800 + + For queries that can safely ignore any NULL values covered by the index, this + function returns conditions that both filter out all relevant NULLs (avoiding + the potential error in Cloud Spanner) and disable the check in Cloud Spanner + Emulator. + + Args: + forced_index: NULL_FILTERED index to use. + """ + return ( + ForceIndexCondition( + forced_index, + extra_hints=( + 'spanner_emulator.disable_query_null_filtered_index_check=true', + )), + _IndexIgnoreNullsCondition(forced_index), + ) + + def greater_than(column: Union[field.Field, str], value: Any) -> ComparisonCondition: """Condition where the specified column is greater than the given value. @@ -631,8 +978,11 @@ def greater_than_or_equal_to(column: Union[field.Field, str], return ComparisonCondition('>=', column, value) -def includes(relation: Union[relationship.Relationship, str], - conditions: List[Condition] = None) -> IncludesCondition: +def includes(relation: Union[relationship.Relationship, + foreign_key_relationship.ForeignKeyRelationship, + str], + conditions: Optional[List[Condition]] = None, + foreign_key_relation: bool = False) -> IncludesCondition: """Condition where the objects associated with a relationship are retrieved. Note that the query formed by this call is not a JOIN, but instead a @@ -641,14 +991,17 @@ def includes(relation: Union[relationship.Relationship, str], subquery may be included, but not all conditions may apply Args: - relation: Name of the relationship on the origin model or the Relationship - on the origin model class used to retrievec associated objects + relation: Name of the relationship on the origin model or the Relationship/ + ForeignKeyRelationship on the origin model class used to retrieve + associated objects conditions: Conditions to apply on the subquery + foreign_key_relation: True if the relation is a foreign key relation, + False if it is a legacy relation (ie not enforced in Spanner) Returns: A Condition subclass that will be used in the query """ - return IncludesCondition(relation, conditions) + return IncludesCondition(relation, conditions, foreign_key_relation) def in_list(column: Union[field.Field, str], diff --git a/spanner_orm/decorator.py b/spanner_orm/decorator.py index d492bc8..5e3b378 100644 --- a/spanner_orm/decorator.py +++ b/spanner_orm/decorator.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,8 +13,6 @@ # limitations under the License. """Transaction decorators.""" -from __future__ import annotations - from typing import Callable, TypeVar from spanner_orm import api @@ -93,18 +90,18 @@ def _transactional(spanner_api_method_lambda: Callable[[], Callable[..., T]], func: Callable[..., T]) -> Callable[..., T]: """Returns decorated function.""" - # Spanner library calls given function with transaction as first argument. - # It will call 'spanner_wrapper', and we will move transaction from first - # argument to 'transaction' kwarg and call actual 'func' + def wrapper(*args, **kwargs) -> T: - def spanner_wrapper(transaction, *args, **kwargs) -> T: - return func(*args, transaction=transaction, **kwargs) + def spanner_wrapper(transaction) -> T: + # Spanner library calls given function with transaction as first argument. + # It will call 'spanner_wrapper', and we will move transaction from first + # argument to 'transaction' kwarg and call actual 'func' + return func(*args, transaction=transaction, **kwargs) - def wrapper(*args, **kwargs) -> T: if 'transaction' in kwargs: return func(*args, **kwargs) spanner_api_method = spanner_api_method_lambda() - return spanner_api_method(spanner_wrapper, *args, **kwargs) + return spanner_api_method(spanner_wrapper) return wrapper diff --git a/spanner_orm/error.py b/spanner_orm/error.py index e1fe4d8..62ce862 100644 --- a/spanner_orm/error.py +++ b/spanner_orm/error.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/field.py b/spanner_orm/field.py index 1d6e5b2..cbc83f0 100644 --- a/spanner_orm/field.py +++ b/spanner_orm/field.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,47 +13,101 @@ # limitations under the License. """Helper to deal with field types in Spanner interactions.""" -from __future__ import annotations - import abc +import base64 +import binascii import datetime -from typing import Any, Type +import re +from typing import Any, Optional, Type, Union +import warnings +from google.cloud import spanner +from google.cloud import spanner_v1 from spanner_orm import error -from google.cloud.spanner_v1.proto import type_pb2 +class FieldType(abc.ABC): + """Base class for column types for Spanner interactions.""" -class Field(object): - """Represents a column in a table as a field in a model.""" + @abc.abstractmethod + def ddl(self) -> str: + """Returns the DDL for this type.""" + raise NotImplementedError - def __init__(self, - field_type: Type[FieldType], - nullable: bool = False, - primary_key: bool = False): + @abc.abstractmethod + def grpc_type(self) -> spanner_v1.Type: + """Returns the type as used in Cloud Spanner's gRPC API.""" + raise NotImplementedError + + @abc.abstractmethod + def validate_type(self, value: Any) -> None: + """Raises error.ValidationError if value doesn't match the type.""" + raise NotImplementedError + + def comparable_with(self, other: 'FieldType') -> bool: + """Returns whether two types are comparable.""" + # https://cloud.google.com/spanner/docs/reference/standard-sql/data-types#comparable_data_types + return type(self) == type(other) + + +class Field: + """Represents a column in a table as a field in a model. + + Attributes: + name: Name of the column, or None if this hasn't been bound to a column yet. + """ + name: Optional[str] + + def __init__( + self, + field_type: Union[FieldType, Type[FieldType]], + *, + nullable: bool = False, + primary_key: bool = False, + ): + """Initializer. + + Args: + field_type: Type of the field. Passing a class instead of an instance of + that class is deprecated. + nullable: Whether the field can be NULL. + primary_key: Whether the field is part of the table's primary key. + """ self.name = None - self._type = field_type + if isinstance(field_type, FieldType): + self._type = field_type + else: + warnings.warn( + DeprecationWarning( + 'Pass an instance of FieldType instead of a class.')) + self._type = field_type() self._nullable = nullable self._primary_key = primary_key def ddl(self) -> str: + """Returns DDL for the column.""" if self._nullable: return self._type.ddl() - return '{field_type} NOT NULL'.format(field_type=self._type.ddl()) + return f'{self._type.ddl()} NOT NULL' - def field_type(self) -> Type[FieldType]: + def field_type(self) -> FieldType: + """Returns the type of the field.""" return self._type - def grpc_type(self) -> str: + def grpc_type(self) -> spanner_v1.Type: + """Returns the type as used in Cloud Spanner's gRPC API.""" return self._type.grpc_type() def nullable(self) -> bool: + """Returns whether the field can be NULL.""" return self._nullable def primary_key(self) -> bool: + """Returns whether the field is part of the table's primary key.""" return self._primary_key - def validate(self, value) -> None: + def validate(self, value: Any) -> None: + """Raises error.ValidationError if value isn't compatible with the field.""" if value is None: if not self._nullable: raise error.ValidationError('None set for non-nullable field') @@ -62,111 +115,220 @@ def validate(self, value) -> None: self._type.validate_type(value) -class FieldType(abc.ABC): - """Base class for column types for Spanner interactions.""" - - @staticmethod - @abc.abstractmethod - def ddl() -> str: - raise NotImplementedError - - @staticmethod - @abc.abstractmethod - def grpc_type() -> type_pb2.Type: - raise NotImplementedError - - @staticmethod - @abc.abstractmethod - def validate_type(value: Any) -> None: - raise NotImplementedError - - class Boolean(FieldType): """Represents a boolean type.""" - @staticmethod - def ddl() -> str: + def ddl(self) -> str: + """See base class.""" + del self # Unused. return 'BOOL' - @staticmethod - def grpc_type() -> type_pb2.Type: - return type_pb2.Type(code=type_pb2.BOOL) + def grpc_type(self) -> spanner_v1.Type: + """See base class.""" + del self # Unused. + return spanner.param_types.BOOL - @staticmethod - def validate_type(value: Any) -> None: + def validate_type(self, value: Any) -> None: + """See base class.""" + del self # Unused. if not isinstance(value, bool): - raise error.ValidationError('{} is not of type bool'.format(value)) + raise error.ValidationError(f'{value!r} is not of type bool') class Integer(FieldType): """Represents an integer type.""" - @staticmethod - def ddl() -> str: + def ddl(self) -> str: + """See base class.""" + del self # Unused. return 'INT64' - @staticmethod - def grpc_type() -> type_pb2.Type: - return type_pb2.Type(code=type_pb2.INT64) + def grpc_type(self) -> spanner_v1.Type: + """See base class.""" + del self # Unused. + return spanner.param_types.INT64 - @staticmethod - def validate_type(value: Any) -> None: + def validate_type(self, value: Any) -> None: + """See base class.""" + del self # Unused. if not isinstance(value, int): - raise error.ValidationError('{} is not of type int'.format(value)) + raise error.ValidationError(f'{value!r} is not of type int') -class String(FieldType): - """Represents a string type.""" +class Float(FieldType): + """Represents a float type.""" - @staticmethod - def ddl() -> str: - return 'STRING(MAX)' + def ddl(self) -> str: + """See base class.""" + del self # Unused. + return 'FLOAT64' - @staticmethod - def grpc_type() -> type_pb2.Type: - return type_pb2.Type(code=type_pb2.STRING) + def grpc_type(self) -> spanner_v1.Type: + """See base class.""" + del self # Unused. + return spanner.param_types.FLOAT64 - @staticmethod - def validate_type(value) -> None: - if not isinstance(value, str): - raise error.ValidationError('{} is not of type str'.format(value)) + def validate_type(self, value: Any) -> None: + """See base class.""" + del self # Unused. + if not isinstance(value, (int, float)): + raise error.ValidationError(f'{value!r} is not of type float') -class StringArray(FieldType): - """Represents an array of strings type.""" +class String(FieldType): + """Represents a string type.""" - @staticmethod - def ddl() -> str: - return 'ARRAY' + def __init__(self, length: Optional[int] = None): + """Initializer. - @staticmethod - def grpc_type() -> type_pb2.Type: - return type_pb2.Type(code=type_pb2.ARRAY) + Args: + length: Length of the String. MAX is used if not specified. + """ + if length is not None and length <= 0: + raise error.ValidationError('String length must be positive') + self._length = length - @staticmethod - def validate_type(value: Any) -> None: - if not isinstance(value, list): - raise error.ValidationError('{} is not of type list'.format(value)) - for item in value: - if not isinstance(item, str): - raise error.ValidationError('{} is not of type str'.format(item)) + def ddl(self) -> str: + """See base class.""" + if self._length is not None: + return f'STRING({self._length})' + return 'STRING(MAX)' + + def grpc_type(self) -> spanner_v1.Type: + """See base class.""" + del self # Unused. + return spanner.param_types.STRING + + def validate_type(self, value: Any) -> None: + """See base class.""" + del self # Unused. + if not isinstance(value, str): + raise error.ValidationError(f'{value!r} is not of type str') class Timestamp(FieldType): """Represents a timestamp type.""" - @staticmethod - def ddl() -> str: + def ddl(self) -> str: + """See base class.""" + del self # Unused. return 'TIMESTAMP' - @staticmethod - def grpc_type() -> type_pb2.Type: - return type_pb2.Type(code=type_pb2.TIMESTAMP) + def grpc_type(self) -> spanner_v1.Type: + """See base class.""" + del self # Unused. + return spanner.param_types.TIMESTAMP - @staticmethod - def validate_type(value: Any) -> None: + def validate_type(self, value: Any) -> None: + """See base class.""" + del self # Unused. if not isinstance(value, datetime.datetime): - raise error.ValidationError('{} is not of type datetime'.format(value)) + raise error.ValidationError(f'{value!r} is not of type datetime') + + +class BytesBase64(FieldType): + """Represents a bytes type that must be base64 encoded.""" + + def __init__(self, length: Optional[int] = None): + """Initializer. + Args: + length: Length of the Bytes. MAX is used if not specified. + """ + if length is not None and length <= 0: + raise error.ValidationError('Bytes length must be positive') + self._length = length -ALL_TYPES = [Boolean, Integer, String, StringArray, Timestamp] + def ddl(self) -> str: + """See base class.""" + if self._length is not None: + return f'BYTES({self._length})' + return 'BYTES(MAX)' + + def grpc_type(self) -> spanner_v1.Type: + """See base class.""" + del self # Unused. + return spanner.param_types.BYTES + + def validate_type(self, value: Any) -> None: + """See base class.""" + del self # Unused. + if not isinstance(value, bytes): + raise error.ValidationError(f'{value!r} is not of type bytes') + # Rudimentary test to check for base64 encoding. + try: + base64.b64decode(value, altchars=None, validate=True) + except binascii.Error: + raise error.ValidationError(f'{value!r} must be base64-encoded bytes.') + + +class Array(FieldType): + """Represents an array type.""" + + def __init__(self, element_type: FieldType): + """Initializer. + + Args: + element_type: Type of the values in the array. Can't be an Array type + itself. + """ + if isinstance(element_type, Array): + # https://cloud.google.com/spanner/docs/reference/standard-sql/data-types#array_type + raise error.SpannerError( + 'Cloud Spanner does not support arrays of arrays.') + self._element_type = element_type + + def ddl(self) -> str: + """See base class.""" + return f'ARRAY<{self._element_type.ddl()}>' + + def grpc_type(self) -> spanner_v1.Type: + """See base class.""" + return spanner.param_types.Array(self._element_type.grpc_type()) + + def validate_type(self, value: Any) -> None: + """See base class.""" + if not isinstance(value, list): + raise error.ValidationError(f'{value!r} is not of type list') + for element in value: + self._element_type.validate_type(element) + + def comparable_with(self, other: FieldType) -> bool: + """See base class.""" + # Running `select [1, 2] = [1, 2];` in Cloud Spanner gives this error: Query + # failed: Equality is not defined for arguments of type ARRAY at line + # 3, column 8 + return False + + +class StringArray(Array): + """Deprecated way to represent an array of strings type.""" + + def __init__(self): + super().__init__(String()) + warnings.warn( + DeprecationWarning('Use Array(String()) instead of StringArray().')) + + +def field_type_from_ddl(ddl: str) -> FieldType: + """Returns the field type for the given DDL expression.""" + if ddl == 'BOOL': + return Boolean() + elif ddl == 'INT64': + return Integer() + elif ddl == 'FLOAT64': + return Float() + elif ddl == 'STRING(MAX)': + return String() + elif (match := re.fullmatch(r'STRING\(([0-9]+)\)', ddl)) is not None: + return String(int(match.group(1))) + elif ddl == 'TIMESTAMP': + return Timestamp() + elif ddl == 'BYTES(MAX)': + return BytesBase64() + elif (match := re.fullmatch(r'BYTES\(([0-9]+)\)', ddl)) is not None: + return BytesBase64(int(match.group(1))) + elif (match := re.fullmatch(r'ARRAY<(.*)>', ddl)) is not None: + return Array(field_type_from_ddl(match.group(1))) + else: + raise error.SpannerError(f'Invalid or unimplemented DDL type: {ddl!r}') diff --git a/spanner_orm/foreign_key_relationship.py b/spanner_orm/foreign_key_relationship.py new file mode 100644 index 0000000..9586a2b --- /dev/null +++ b/spanner_orm/foreign_key_relationship.py @@ -0,0 +1,78 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Helps define a foreign key relationship between two models.""" + +from typing import Any, Mapping, Type + +import dataclasses +from spanner_orm import registry + + +@dataclasses.dataclass +class ForeignKeyRelationshipConstraint: + columns: Mapping[str, str] + referenced_table_name: str + referenced_table: Type[Any] + + +class ForeignKeyRelationship: + """Helps define a foreign key relationship between two models.""" + + def __init__(self, referenced_table_name: str, columns: Mapping[str, str]): + """Creates a ForeignKeyRelationship. + + Args: + referenced_table_name: Name of the table which the foreign key references. + columns: Dictionary where the keys are names of columns from the + referencing table and the values are the names of the columns in the + referenced table. + """ + self.origin = None + self.name = None + self._referenced_table_name = referenced_table_name + self._columns = columns + + @property + def constraint(self) -> ForeignKeyRelationshipConstraint: + return self._parse_constraint() + + @property + def ddl(self) -> str: + referencing_columns_ddl = ', '.join(self.constraint.columns.keys()) + referenced_columns_ddl = ', '.join(self.constraint.columns.values()) + return ( + 'CONSTRAINT {fk_name} FOREIGN KEY ({referencing_columns}) REFERENCES' + ' {referenced_table} ({referenced_columns})').format( + fk_name=self.name, + referencing_columns=referencing_columns_ddl, + referenced_table=self.constraint.referenced_table_name, + referenced_columns=referenced_columns_ddl, + ) + + def _parse_constraint(self) -> ForeignKeyRelationshipConstraint: + """Return the relationship constraint.""" + referenced_table = registry.model_registry().get( + self._referenced_table_name) + return ForeignKeyRelationshipConstraint( + self._columns, + referenced_table.table, + referenced_table, + ) + + @property + def single(self) -> bool: + # Spanner enforces uniqueness for values of fields referenced by + # foreign keys, because it creates a unique index on the referenced + # key. + return True diff --git a/spanner_orm/index.py b/spanner_orm/index.py index 807a61d..7702b3a 100644 --- a/spanner_orm/index.py +++ b/spanner_orm/index.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,14 +13,12 @@ # limitations under the License. """Represents an index on a Model.""" -from __future__ import annotations - from typing import List, Optional from spanner_orm import error -class Index(object): +class Index: """Represents an index on a Model.""" PRIMARY_INDEX = 'PRIMARY_KEY' diff --git a/spanner_orm/metadata.py b/spanner_orm/metadata.py index 9caaf95..794a7c6 100644 --- a/spanner_orm/metadata.py +++ b/spanner_orm/metadata.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,24 +27,26 @@ # limitations under the License. """Hold information about a Model extracted from the class attributes.""" -from __future__ import annotations - from typing import Any, Dict, Type, Optional from spanner_orm import error from spanner_orm import field +from spanner_orm import foreign_key_relationship from spanner_orm import index from spanner_orm import registry from spanner_orm import relationship -class ModelMetadata(object): +class ModelMetadata: """Hold information about a Model extracted from the class attributes.""" def __init__(self, table: Optional[str] = None, fields: Optional[Dict[str, field.Field]] = None, relations: Optional[Dict[str, relationship.Relationship]] = None, + foreign_key_relations: Optional[Dict[ + str, + foreign_key_relationship.ForeignKeyRelationship]] = None, indexes: Optional[Dict[str, index.Index]] = None, interleaved: Optional[str] = None, model_class: Optional[Type[Any]] = None): @@ -57,6 +58,7 @@ def __init__(self, self.model_class = model_class self.primary_keys = [] self.relations = dict(relations or {}) + self.foreign_key_relations = dict(foreign_key_relations or {}) self.table = table or '' def finalize(self) -> None: @@ -86,7 +88,7 @@ def finalize(self) -> None: registry.model_registry().register(self.model_class) self._finalized = True - def add_metadata(self, metadata: ModelMetadata) -> None: + def add_metadata(self, metadata: 'ModelMetadata') -> None: self.table = metadata.table or self.table self.fields.update(metadata.fields) self.relations.update(metadata.relations) @@ -103,6 +105,14 @@ def add_relation(self, name: str, new_relation.name = name self.relations[name] = new_relation + def add_foreign_key_relation( + self, + name: str, + new_relation: foreign_key_relationship.ForeignKeyRelationship, + ) -> None: + new_relation.name = name + self.foreign_key_relations[name] = new_relation + def add_index(self, name: str, new_index: index.Index) -> None: new_index.name = name self.indexes[name] = new_index diff --git a/spanner_orm/model.py b/spanner_orm/model.py index 2dd428d..2269fe7 100644 --- a/spanner_orm/model.py +++ b/spanner_orm/model.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,8 +13,6 @@ # limitations under the License. """Holds table-specific information to make querying spanner eaiser.""" -from __future__ import annotations - import collections import copy from typing import Any, Callable, Dict, Iterable, List, Optional, Type, TypeVar, Union @@ -23,6 +20,7 @@ from spanner_orm import api from spanner_orm import condition from spanner_orm import error +from spanner_orm import foreign_key_relationship from spanner_orm import field from spanner_orm import index from spanner_orm import metadata @@ -31,17 +29,19 @@ from spanner_orm import relationship from spanner_orm import table_apis +from google.api_core import exceptions from google.cloud import spanner from google.cloud.spanner_v1 import transaction as spanner_transaction +T = TypeVar('T') + class ModelMetaclass(type): """Populates ModelMetadata based on class attributes.""" + meta: metadata.ModelMetadata def __new__(mcs, name: str, bases: Any, attrs: Dict[str, Any], **kwargs: Any): parents = [base for base in bases if isinstance(base, ModelMetaclass)] - if not parents: - return super().__new__(mcs, name, bases, attrs, **kwargs) model_metadata = metadata.ModelMetadata() for parent in parents: @@ -60,6 +60,11 @@ def __new__(mcs, name: str, bases: Any, attrs: Dict[str, Any], **kwargs: Any): model_metadata.add_index(key, value) elif isinstance(value, relationship.Relationship): model_metadata.add_relation(key, value) + elif isinstance( + value, + foreign_key_relationship.ForeignKeyRelationship, + ): + model_metadata.add_foreign_key_relation(key, value) else: non_model_attrs[key] = value @@ -74,14 +79,17 @@ def __new__(mcs, name: str, bases: Any, attrs: Dict[str, Any], **kwargs: Any): return cls def __getattr__( - cls, - name: str) -> Union[field.Field, relationship.Relationship, index.Index]: + cls, name: str + ) -> Union[field.Field, relationship.Relationship, + foreign_key_relationship.ForeignKeyRelationship, index.Index]: # Unclear why pylint doesn't like this # pylint: disable=unsupported-membership-test if name in cls.fields: return cls.fields[name] elif name in cls.relations: return cls.relations[name] + elif name in cls.foreign_key_relations: + return cls.foreign_key_relations[name] elif name in cls.indexes: return cls.indexes[name] # pylint: enable=unsupported-membership-test @@ -101,7 +109,7 @@ def indexes(cls) -> Dict[str, index.Index]: return cls.meta.indexes @property - def interleaved(cls) -> Optional[Type[Model]]: + def interleaved(cls) -> Optional[Type['Model']]: if cls.meta.interleaved: return registry.model_registry().get(cls.meta.interleaved) return None @@ -114,6 +122,11 @@ def primary_keys(cls) -> List[str]: def relations(cls) -> Dict[str, relationship.Relationship]: return cls.meta.relations + @property + def foreign_key_relations( + cls) -> Dict[str, foreign_key_relationship.ForeignKeyRelationship]: + return cls.meta.foreign_key_relations + @property def fields(cls) -> Dict[str, field.Field]: return cls.meta.fields @@ -126,20 +139,64 @@ def validate_value(cls, field_name, value, error_type=error.SpannerError): try: cls.fields[field_name].validate(value) except error.ValidationError as ex: - raise error_type(*ex.args) + context = f'Validation error for field {field_name!r}' + raise error_type((f'{context}: {ex.args[0]}' if ex.args else context), + *ex.args[1:]) CallableReturn = TypeVar('CallableReturn') -class ModelApi(metaclass=ModelMetaclass): - """Implements class-level Spanner queries on top of ModelMetaclass. +class Model(metaclass=ModelMetaclass): + """Maps to a table in spanner and has basic functions for querying tables. - Note: all methods in this class should only be called on subclasses of Model - that have associated tables. Violating this will cause an exception to be - raised. + Note: all methods in this class should only be called on subclasses that have + associated tables. Violating this will cause an exception to be raised. """ + def __init__(self, + values: Dict[str, Any], + persisted: bool = False, + skip_validation: bool = False): + start_values = {} + self.__dict__['start_values'] = start_values + self.__dict__['_persisted'] = persisted + + # If the values came from Spanner or validation is explicitly skipped, trust + # them and skip validation + if not persisted and not skip_validation: + # An object is invalid if primary key values are missing + missing_keys = set(self._primary_keys) - set(values.keys()) + if missing_keys: + raise error.SpannerError( + 'All primary keys must be specified. Missing: {keys}'.format( + keys=missing_keys)) + + for column in self._columns: + self._metaclass.validate_value(column, values.get(column), ValueError) + + for column in self._columns: + value = values.get(column) + start_values[column] = copy.copy(value) + self.__dict__[column] = value + + for relation in self._relations: + if relation in values: + self.__dict__[relation] = values[relation] + + for foreign_key_relation in self._foreign_key_relations: + if foreign_key_relation in values: + self.__dict__[foreign_key_relation] = values[foreign_key_relation] + + def __eq__(self, other: Any) -> Union[bool, type(NotImplemented)]: + """Compares objects by their type and attributes.""" + if type(self) != type(other): + return NotImplemented + return self.values == other.values + + def __repr__(self) -> str: + return f'{self.__class__.__qualname__}({self.values!r})' + @classmethod def spanner_api(cls) -> api.SpannerApi: if not cls.table: @@ -149,9 +206,10 @@ def spanner_api(cls) -> api.SpannerApi: # Table read methods @classmethod def all( - cls, - transaction: Optional[spanner_transaction.Transaction] = None - ) -> List[ModelObject]: + cls: Type[T], + *, + transaction: Optional[spanner_transaction.Transaction] = None, + ) -> List[T]: """Returns all objects of this type stored in Spanner. Note: this method should only be called on subclasses of Model that have @@ -170,16 +228,19 @@ def all( return cls._results_to_models(results) @classmethod - def count(cls, transaction: Optional[spanner_transaction.Transaction], - *conditions: condition.Condition) -> int: + def count( + cls, + *conditions: condition.Condition, + transaction: Optional[spanner_transaction.Transaction] = None, + ) -> int: """Returns the number of objects in Spanner that match the given conditions. Args: - transaction: The existing transaction to use, or None to start a new - transaction *conditions: Instances of subclasses of Condition that help specify which rows should be included in the count. The includes condition is not allowed here + transaction: The existing transaction to use, or None to start a new + transaction Returns: The integer result of the COUNT query @@ -190,9 +251,12 @@ def count(cls, transaction: Optional[spanner_transaction.Transaction], return builder.process_results(results) @classmethod - def count_equal(cls, - transaction: Optional[spanner_transaction.Transaction] = None, - **constraints: Any) -> int: + def count_equal( + cls, + *, + transaction: Optional[spanner_transaction.Transaction] = None, + **constraints: Any, + ) -> int: """Returns the number of objects in Spanner that match the given conditions. Convenience method that generates equality conditions based on the keyword @@ -214,12 +278,15 @@ def count_equal(cls, conditions.append(condition.in_list(column, value)) else: conditions.append(condition.equal_to(column, value)) - return cls.count(transaction, *conditions) + return cls.count(*conditions, transaction=transaction) @classmethod - def find(cls, - transaction: Optional[spanner_transaction.Transaction] = None, - **keys: Any) -> Optional[ModelObject]: + def find( + cls: Type[T], + *, + transaction: Optional[spanner_transaction.Transaction] = None, + **keys: Any, + ) -> Optional[T]: """Retrieves an object from Spanner based on the provided key. Args: @@ -232,20 +299,52 @@ def find(cls, Returns: The requested object or None if no such object exists """ - resources = cls.find_multi(transaction, [keys]) + resources = cls.find_multi([keys], transaction=transaction) return resources[0] if resources else None @classmethod - def find_multi(cls, transaction: Optional[spanner_transaction.Transaction], - keys: Iterable[Dict[str, Any]]) -> List[ModelObject]: - """Retrieves objects from Spanner based on the provided keys. + def find_required( + cls: Type[T], + *, + transaction: Optional[spanner_transaction.Transaction] = None, + **keys: Any, + ) -> T: + """Retrieves an object from Spanner based on the provided key. Args: transaction: The existing transaction to use, or None to start a new transaction + **keys: The keys provided are the complete set of primary keys for this + table and the corresponding values make up the unique identifier of the + object being retrieved + + Returns: + The requested object. + + Raises: + exceptions.NotFound: The object wasn't found. + """ + result = cls.find(**keys, transaction=transaction) + if result is None: + raise exceptions.NotFound( + f'{cls.__qualname__} has no object with primary key {keys}') + return result + + @classmethod + def find_multi( + cls: Type[T], + keys: Iterable[Dict[str, Any]], + *, + transaction: Optional[spanner_transaction.Transaction] = None, + ) -> List[T]: + """Retrieves objects from Spanner based on the provided keys. + + Args: keys: An iterable of dictionaries, each dictionary representing the set of primary key values necessary to uniquely identify an object in this table. + transaction: The existing transaction to use, or None to start a new + transaction Returns: A list containing all requested objects that exist in the table (can be @@ -261,15 +360,18 @@ def find_multi(cls, transaction: Optional[spanner_transaction.Transaction], return cls._results_to_models(results) @classmethod - def where(cls, transaction: Optional[spanner_transaction.Transaction], - *conditions: condition.Condition) -> List[ModelObject]: + def where( + cls: Type[T], + *conditions: condition.Condition, + transaction: Optional[spanner_transaction.Transaction] = None, + ) -> List[T]: """Retrieves objects from Spanner based on the provided conditions. Args: - transaction: The existing transaction to use, or None to start a new - transaction *conditions: Instances of subclasses of Condition that help specify which objects should be retrieved + transaction: The existing transaction to use, or None to start a new + transaction Returns: A list containing all requested objects that exist in the table (can be @@ -281,9 +383,12 @@ def where(cls, transaction: Optional[spanner_transaction.Transaction], return builder.process_results(results) @classmethod - def where_equal(cls, - transaction: Optional[spanner_transaction.Transaction] = None, - **constraints: Any) -> List[ModelObject]: + def where_equal( + cls: Type[T], + *, + transaction: Optional[spanner_transaction.Transaction] = None, + **constraints: Any, + ) -> List[T]: """Retrieves objects from Spanner based on the provided constraints. Args: @@ -303,18 +408,23 @@ def where_equal(cls, conditions.append(condition.in_list(column, value)) else: conditions.append(condition.equal_to(column, value)) - return cls.where(transaction, *conditions) + return cls.where(*conditions, transaction=transaction) @classmethod - def _results_to_models(cls, - results: Iterable[Iterable[Any]]) -> List[ModelObject]: + def _results_to_models( + cls: Type[T], + results: Iterable[Iterable[Any]], + ) -> List[T]: items = [dict(zip(cls.columns, result)) for result in results] return [cls(item, persisted=True) for item in items] @classmethod - def _execute_read(cls, db_api: Callable[..., CallableReturn], - transaction: Optional[spanner_transaction.Transaction], - args: List[Any]) -> CallableReturn: + def _execute_read( + cls, + db_api: Callable[..., CallableReturn], + transaction: Optional[spanner_transaction.Transaction], + args: List[Any], + ) -> CallableReturn: if transaction is not None: return db_api(transaction, *args) else: @@ -322,9 +432,12 @@ def _execute_read(cls, db_api: Callable[..., CallableReturn], # Table write methods @classmethod - def create(cls, - transaction: Optional[spanner_transaction.Transaction] = None, - **kwargs: Any) -> None: + def create( + cls, + *, + transaction: Optional[spanner_transaction.Transaction] = None, + **kwargs: Any, + ) -> None: """Creates a row in Spanner based on the provided data. Note: may throw an exception if bad values are provided. @@ -339,49 +452,89 @@ def create(cls, cls._execute_write(table_apis.insert, transaction, [kwargs]) @classmethod - def create_or_update(cls, - transaction: Optional[ - spanner_transaction.Transaction] = None, - **kwargs: Any) -> None: + def create_or_update( + cls, + *, + transaction: Optional[spanner_transaction.Transaction] = None, + **kwargs: Any, + ) -> None: cls._execute_write(table_apis.upsert, transaction, [kwargs]) @classmethod - def delete_batch(cls, transaction: Optional[spanner_transaction.Transaction], - models: List[ModelObject]) -> None: + def _delete_by_keyset( + cls, + transaction: Optional[spanner_transaction.Transaction], + keyset: spanner.KeySet, + ) -> None: + db_api = table_apis.delete + args = [cls.table, keyset] + if transaction is not None: + db_api(transaction, *args) + else: + cls.spanner_api().run_write(db_api, *args) + + @classmethod + def delete_batch( + cls: Type[T], + models: List[T], + *, + transaction: Optional[spanner_transaction.Transaction] = None, + ) -> None: """Deletes rows from Spanner based on the provided models' primary keys. Args: + models: A list of models to be deleted from Spanner. transaction: The existing transaction to use, or None to start a new transaction - models: A list of models to be deleted from Spanner. """ key_list = [] for model in models: key_list.append([getattr(model, column) for column in cls.primary_keys]) - keyset = spanner.KeySet(keys=key_list) + cls._delete_by_keyset( + transaction=transaction, + keyset=spanner.KeySet(keys=key_list), + ) - db_api = table_apis.delete - args = [cls.table, keyset] - if transaction is not None: - db_api(transaction, *args) - else: - cls.spanner_api().run_write(db_api, *args) + @classmethod + def delete_by_key( + cls, + *, + transaction: Optional[spanner_transaction.Transaction] = None, + **keys: Any, + ) -> None: + """Deletes rows from Spanner based on the provided primary key. + + Args: + transaction: The existing transaction to use, or None to start a new + transaction. + **keys: The keys provided are the complete set of primary keys for this + table and the corresponding values make up the unique identifier of the + object being deleted. + """ + cls._delete_by_keyset( + transaction=transaction, + keyset=spanner.KeySet( + keys=[[keys[column] for column in cls.primary_keys]]), + ) @classmethod - def save_batch(cls, - transaction: Optional[spanner_transaction.Transaction], - models: List[ModelObject], - force_write: bool = False) -> None: + def save_batch( + cls: Type[T], + models: List[T], + *, + transaction: Optional[spanner_transaction.Transaction] = None, + force_write: bool = False, + ) -> None: """Writes rows to Spanner based on the provided model data. Args: - transaction: The existing transaction to use, or None to start a new - transaction models: A list of models to be written to Spanner. If the _persisted flag is set, by default we try to issue an UPDATE with values set for all columns in the table. Otherwise, we try to issue an INSERT for all columns in the table. If we try to INSERTa row that already exists (or update one that is missing), an exception will be thrown. + transaction: The existing transaction to use, or None to start a new + transaction force_write: If true, we use UPSERT instead of UPDATE/INSERT, so no exceptions are thrown based on the presence or absence of data in Spanner @@ -401,9 +554,12 @@ def save_batch(cls, cls._execute_write(api_method, transaction, values) @classmethod - def update(cls, - transaction: Optional[spanner_transaction.Transaction] = None, - **kwargs: Any) -> None: + def update( + cls, + *, + transaction: Optional[spanner_transaction.Transaction] = None, + **kwargs: Any, + ) -> None: """Updates a row in Spanner based on the provided data. Args: @@ -417,9 +573,12 @@ def update(cls, cls._execute_write(table_apis.update, transaction, [kwargs]) @classmethod - def _execute_write(cls, db_api: Callable[..., Any], - transaction: Optional[spanner_transaction.Transaction], - dictionaries: Iterable[Dict[str, Any]]) -> None: + def _execute_write( + cls, + db_api: Callable[..., Any], + transaction: Optional[spanner_transaction.Transaction], + dictionaries: Iterable[Dict[str, Any]], + ) -> None: """Validates all write value types and commits write to Spanner.""" columns, values = None, [] for dictionary in dictionaries: @@ -444,36 +603,6 @@ def _execute_write(cls, db_api: Callable[..., Any], else: return cls.spanner_api().run_write(db_api, *args) - -class Model(ModelApi): - """Maps to a table in spanner and has basic functions for querying tables.""" - - def __init__(self, values: Dict[str, Any], persisted: bool = False): - start_values = {} - self.__dict__['start_values'] = start_values - self.__dict__['_persisted'] = persisted - - # If the values came from Spanner, trust them and skip validation - if not persisted: - # An object is invalid if primary key values are missing - missing_keys = set(self._primary_keys) - set(values.keys()) - if missing_keys: - raise error.SpannerError( - 'All primary keys must be specified. Missing: {keys}'.format( - keys=missing_keys)) - - for column in self._columns: - self._metaclass.validate_value(column, values.get(column), ValueError) - - for column in self._columns: - value = values.get(column) - start_values[column] = copy.copy(value) - self.__dict__[column] = value - - for relation in self._relations: - if relation in values: - self.__dict__[relation] = values[relation] - def __setattr__(self, name: str, value: Any) -> None: if name in self._relations: raise AttributeError(name) @@ -484,7 +613,7 @@ def __setattr__(self, name: str, value: Any) -> None: super().__setattr__(name, value) @property - def _metaclass(self) -> Type[Model]: + def _metaclass(self) -> Type['Model']: return type(self) @property @@ -503,6 +632,11 @@ def _primary_keys(self) -> List[str]: def _relations(self) -> Dict[str, relationship.Relationship]: return self._metaclass.relations + @property + def _foreign_key_relations( + self) -> Dict[str, foreign_key_relationship.ForeignKeyRelationship]: + return self._metaclass.foreign_key_relations + @property def _table(self) -> str: return self._metaclass.table @@ -529,7 +663,11 @@ def changes(self) -> Dict[str, Any]: if values[key] != self.start_values.get(key) } - def delete(self, transaction: spanner_transaction.Transaction = None) -> None: + def delete( + self, + *, + transaction: Optional[spanner_transaction.Transaction] = None, + ) -> None: """Deletes this object from the Spanner database. Args: @@ -558,7 +696,9 @@ def id(self) -> Dict[str, Any]: def reload( self, - transaction: spanner_transaction.Transaction = None) -> Optional[Model]: + *, + transaction: Optional[spanner_transaction.Transaction] = None, + ) -> Optional['Model']: """Refreshes this object with information from Spanner. Args: @@ -570,7 +710,7 @@ def reload( in Spanner, or None if no information was found (object was deleted or never was persisted) """ - updated_object = self._metaclass.find(transaction, **self.id()) + updated_object = self._metaclass.find(transaction=transaction, **self.id()) if updated_object is None: return None start_values = {} @@ -585,7 +725,11 @@ def reload( self._persisted = True return self - def save(self, transaction: spanner_transaction.Transaction = None) -> Model: + def save( + self, + *, + transaction: Optional[spanner_transaction.Transaction] = None, + ) -> 'Model': """Persists this object to Spanner. Note: if the _persisted flag doesn't match whether this object is actually @@ -603,11 +747,8 @@ def save(self, transaction: spanner_transaction.Transaction = None) -> Model: changed_values = self.changes() if changed_values: changed_values.update(self.id()) - self._metaclass.update(transaction, **changed_values) + self._metaclass.update(transaction=transaction, **changed_values) else: - self._metaclass.create(transaction, **self.values) + self._metaclass.create(transaction=transaction, **self.values) self._persisted = True return self - - -ModelObject = TypeVar('ModelObject', bound=Model) diff --git a/spanner_orm/query.py b/spanner_orm/query.py index 75c461e..ac09f9a 100644 --- a/spanner_orm/query.py +++ b/spanner_orm/query.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,13 +14,15 @@ """Helps build SQL for complex Spanner queries.""" import abc -from typing import Any, Dict, Iterable, List, Tuple, Type +from typing import Any, Dict, Generic, Iterable, List, Sequence, Tuple, Type, TypeVar from spanner_orm import condition from spanner_orm import error +ResultType = TypeVar('ResultType') -class SpannerQuery(abc.ABC): + +class SpannerQuery(abc.ABC, Generic[ResultType]): """Helps build SQL for complex Spanner queries.""" def __init__(self, model: Type[Any], @@ -47,7 +48,7 @@ def types(self) -> Dict[str, Any]: return self._types @abc.abstractmethod - def process_results(self, results: List[List[Any]]) -> None: + def process_results(self, results: List[Sequence[Any]]) -> ResultType: pass def _segments(self, @@ -134,7 +135,7 @@ def _limit(self) -> Tuple[str, Dict[str, Any], Dict[str, Any]]: return (sql, parameters, types) -class CountQuery(SpannerQuery): +class CountQuery(SpannerQuery[int]): """Handles COUNT Spanner queries.""" def __init__(self, model: Type[Any], @@ -148,11 +149,11 @@ def __init__(self, model: Type[Any], def _select(self) -> Tuple[str, Dict[str, Any], Dict[str, Any]]: return ('SELECT COUNT(*)', {}, {}) - def process_results(self, results: List[List[Any]]) -> int: + def process_results(self, results: List[Sequence[Any]]) -> int: return int(results[0][0]) -class SelectQuery(SpannerQuery): +class SelectQuery(SpannerQuery[List[Type[Any]]]): """Handles SELECT Spanner queries.""" def __init__(self, model: Type[Any], @@ -186,10 +187,10 @@ def _select(self) -> Tuple[str, Dict[str, Any], Dict[str, Any]]: prefix=self._select_prefix(), columns=', '.join(columns)), parameters, types) - def process_results(self, results: List[List[Any]]) -> List[Type[Any]]: + def process_results(self, results: List[Sequence[Any]]) -> List[Type[Any]]: return [self._process_row(result) for result in results] - def _process_row(self, row: List[Any]) -> Type[Any]: + def _process_row(self, row: Sequence[Any]) -> Type[Any]: """Parses a row of results from a Spanner query based on the conditions.""" values = dict(zip(self._model.columns, row)) join_values = row[len(self._model.columns):] diff --git a/spanner_orm/registry.py b/spanner_orm/registry.py index 1c596cc..a1544a3 100644 --- a/spanner_orm/registry.py +++ b/spanner_orm/registry.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,8 +13,6 @@ # limitations under the License. """Registers Model classes so they can be referenced elsewhere.""" -from __future__ import annotations - from typing import Any, Dict, List, Type, Union import dataclasses @@ -30,7 +27,7 @@ def add(self, reference: Type[Any]) -> None: self.references.append(reference) -class Registry(object): +class Registry: def __init__(self): self._registered = {} # type: Dict[str, RegistryComponent] diff --git a/spanner_orm/relationship.py b/spanner_orm/relationship.py index 54c320c..ca3d1c2 100644 --- a/spanner_orm/relationship.py +++ b/spanner_orm/relationship.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,8 +13,6 @@ # limitations under the License. """Helps define a foreign key relationship between two models.""" -from __future__ import annotations - from typing import Any, Dict, List, Type, Union import dataclasses @@ -31,7 +28,7 @@ class RelationshipConstraint: origin_column: str -class Relationship(object): +class Relationship: """Helps define a foreign key relationship between two models.""" def __init__(self, @@ -90,9 +87,8 @@ def _parse_constraints(self) -> List[RelationshipConstraint]: raise error.ValidationError( 'Destination column must be present in destination model') - # TODO(dbrandao): remove when pytype #234 is fixed constraints.append( RelationshipConstraint(self.destination, destination_column, - self.origin, origin_column)) # type: ignore + self.origin, origin_column)) return constraints diff --git a/spanner_orm/table_apis.py b/spanner_orm/table_apis.py index b289f74..ea98410 100644 --- a/spanner_orm/table_apis.py +++ b/spanner_orm/table_apis.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,21 +13,19 @@ # limitations under the License. """Table-level API lambdas for Spanner transactions.""" -from __future__ import annotations - import logging -from typing import Any, Dict, Iterable, List +from typing import Any, Dict, Iterable, List, Sequence from google.cloud import spanner +from google.cloud import spanner_v1 from google.cloud.spanner_v1 import transaction as spanner_transaction -from google.cloud.spanner_v1.proto import type_pb2 _logger = logging.getLogger(__name__) # Read methods def find(transaction: spanner_transaction.Transaction, table_name: str, - columns: Iterable[str], keyset: spanner.KeySet) -> List[Iterable[Any]]: + columns: Iterable[str], keyset: spanner.KeySet) -> List[Sequence[Any]]: """Retrieves rows from the given table based on the provided KeySet. Args: @@ -51,9 +48,12 @@ def find(transaction: spanner_transaction.Transaction, table_name: str, return list(stream_results) -def sql_query(transaction: spanner_transaction.Transaction, query: str, - parameters: Dict[str, Any], - parameter_types: Dict[str, type_pb2.Type]) -> List[Iterable[Any]]: +def sql_query( + transaction: spanner_transaction.Transaction, + query: str, + parameters: Dict[str, Any], + parameter_types: Dict[str, spanner_v1.Type], +) -> List[Sequence[Any]]: """Executes a given SQL query against the Spanner database. This isn't technically read-only, but it's necessary to implement the read- diff --git a/spanner_orm/testlib/__init__.py b/spanner_orm/testlib/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/spanner_orm/testlib/spanner_emulator/__init__.py b/spanner_orm/testlib/spanner_emulator/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/spanner_orm/testlib/spanner_emulator/emulator.py b/spanner_orm/testlib/spanner_emulator/emulator.py new file mode 100644 index 0000000..55f7a67 --- /dev/null +++ b/spanner_orm/testlib/spanner_emulator/emulator.py @@ -0,0 +1,125 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Python test wrapper for the cloud spanner emulator binary.""" + +import os +import subprocess +import time +from typing import Mapping, Optional + +import portpicker + +from google.auth import credentials +from google.cloud.spanner_v1 import client + +# Environment variable with path to Spanner Emulator binary. +_EMULATOR_BINARY_PATH_ENV_VAR = "SPANNER_EMULATOR_BINARY_PATH" +# Environment variable used by the client library to set the correct URL. +_CLIENT_EMULATOR_ENV_VAR = "SPANNER_EMULATOR_HOST" + + +class Emulator: + """Spanner emulator python wrapper. + + Below is an example of how this wrapper can be used in a test class. + + class SpannerTest(googletest.TestCase): + + def setUp(self): + super().setUp() + self._spanner_emulator = emulator.Emulator() + self.addCleanup(self._spanner_emulator.stop) + + def test_something(self): + client = self._spanner_emulator.get_client() + # Create tables, add data, and retrieve it + """ + + def __init__(self, + *, + spanner_emulator_port: Optional[int] = None, + log_emulator_requests: bool = False) -> None: + """Initializer. + + Args: + spanner_emulator_port: The port to start the emulator on. A random unused + port is picked if this value is None. + log_emulator_requests: If true, the emulator subprocess will log each + request and response message. + """ + + self._spanner_emulator_port = spanner_emulator_port + self._log_emulator_requests = log_emulator_requests + + self._process = None + self._host_port = None + + self._start() + time.sleep(1) + self._wait_for_ready() + + def get_client( + self, + project: str = "test-project", + client_options: Optional[Mapping[str, str]] = None) -> client.Client: + """Returns a spanner client for interacting with the emulator. + + Args: + project: Name of the project that the client should point to. + client_options: Any client options that the client should be created with. + """ + return client.Client( + project=project, + credentials=credentials.AnonymousCredentials(), + client_options=client_options) + + def _start(self) -> None: + """Starts the emulator as a subprocess.""" + port = self._spanner_emulator_port or portpicker.pick_unused_port() + self._host_port = f"localhost:{port}" + + # Used by the client library to point to the correct spanner endpoint. + os.environ[_CLIENT_EMULATOR_ENV_VAR] = self._host_port + + try: + emulator_binary_path = os.environ[_EMULATOR_BINARY_PATH_ENV_VAR] + except KeyError as key_error: + raise ValueError( + f'Please set the environment variable {_EMULATOR_BINARY_PATH_ENV_VAR} ' + 'to a binary with the Cloud Spanner Emulator. For more info, see ' + 'https://github.com/GoogleCloudPlatform/cloud-spanner-emulator.' + ) from key_error + + self._process = subprocess.Popen([ + emulator_binary_path, + "--log_requests" if self._log_emulator_requests else "--nolog_requests", + "--host_port", + self._host_port, + ]) + + def _wait_for_ready(self) -> None: + """Waits for the emulator to become ready.""" + emulator_client = self.get_client() + + # This will not return until the emulator is running. + for _ in emulator_client.list_instance_configs(): + return + + def stop(self) -> None: + """If there is an emulator process, stops it and waits for it to stop.""" + if self._process is not None: + self._process.terminate() + self._process.wait() + self._process = None + self._host_port = None diff --git a/spanner_orm/testlib/spanner_emulator/testlib.py b/spanner_orm/testlib/spanner_emulator/testlib.py new file mode 100644 index 0000000..1f7ca10 --- /dev/null +++ b/spanner_orm/testlib/spanner_emulator/testlib.py @@ -0,0 +1,125 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Superclass and helpers for tests that use the spanner emulator.""" + +import os +import unittest +import uuid + +import spanner_orm + +from google.cloud.spanner_v1 import client +from google.cloud.spanner_v1 import database +from google.cloud.spanner_v1 import instance +from spanner_orm.testlib.spanner_emulator import emulator + + +def _make_emulator_spanner_orm_connection( + db: database.Database, inst: instance.Instance, + spanner_client: client.Client) -> spanner_orm.SpannerConnection: + """Returns an spanner_orm.connection to a spanner database. + + Args: + db: database that already exists in spanner + inst: instance that already exists in spanner + spanner_client: client with access to the database and instance provided + """ + # project/project-name -> project-name. + project_name = spanner_client.project_name.split('/')[1] + return spanner_orm.SpannerConnection( + inst.instance_id, + db.database_id, + project=project_name, + credentials=spanner_client.credentials) + + +def _get_instance(spanner_client: client.Client) -> instance.Instance: + """Returns a spanner instance from the client. + + First, checks if there is an existing instance that can be re-used, returning + it if one exists. Otherwise, create a new instance, waits for it to be created + and then returns it. + + Args: + spanner_client: An initialized spanner client. + """ + existing_instances_pb = list(spanner_client.list_instances()) + if existing_instances_pb: + return instance.Instance.from_pb(existing_instances_pb[0], spanner_client) + + # The emulator has one default config. + config = list(spanner_client.list_instance_configs())[0] + inst = spanner_client.instance( + 'spanner-instance-name', configuration_name=config.name) + inst.create().result() + return inst + + +def _migrate_database_at_connection(connection: spanner_orm.SpannerConnection, + migrations_dir: str) -> None: + """Applies the migrations to the provided connection.""" + spanner_orm.from_connection(connection) + executor = spanner_orm.MigrationExecutor(connection, basedir=migrations_dir) + executor.migrate() + + +def _database_id() -> str: + """Returns a new database ID that's unlikely to conflict with any other.""" + random_string = str(uuid.uuid4()).split('-')[0] + return 'spanner-db-' + random_string + + +class TestCase(unittest.TestCase): + """Sets up a spanner emulator database for each test case. + + Any test class that subclasses this class will have a spanner database + setup for it. That database is empty by default; most subclasses will want to + call run_orm_migrations() in their setUp() method. + + Attributes: + spanner_emulator_client: Client, for use by non-ORM tests. + spanner_emulator_instance: Instance, for use by non-ORM tests. + spanner_emulator_database: Database, for use by non-ORM tests. + """ + + _spanner_emulator: emulator.Emulator + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._spanner_emulator = emulator.Emulator() + + def setUp(self): + super().setUp() + self.spanner_emulator_client = self._spanner_emulator.get_client() + self.spanner_emulator_instance = _get_instance(self.spanner_emulator_client) + self.spanner_emulator_database = self.spanner_emulator_instance.database( + _database_id()) + self.spanner_emulator_database.create().result() + + @classmethod + def tearDownClass(cls): + cls._spanner_emulator.stop() + super().tearDownClass() + + def run_orm_migrations(self, migrations_folder: str) -> None: + """Runs ORM migrations in the given directory and connects the ORM.""" + connection = _make_emulator_spanner_orm_connection( + self.spanner_emulator_database, self.spanner_emulator_instance, + self.spanner_emulator_client) + _migrate_database_at_connection(connection, migrations_folder) + # spanner_orm closes the connection to Spanner after migrating so we need to + # reconnect before making other Spanner calls. + spanner_orm.from_connection(connection) + spanner_orm.from_admin_connection(connection) diff --git a/spanner_orm/tests/admin_test.py b/spanner_orm/tests/admin_test.py index 9700634..b28ce7c 100644 --- a/spanner_orm/tests/admin_test.py +++ b/spanner_orm/tests/admin_test.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -109,8 +108,8 @@ def test_metadata(self, tables, columns, index_columns, indexes): self.assertEqual(meta.table, model.table) self.assertEqual(meta.columns, model.columns) for row in model.columns: - self.assertEqual(meta.fields[row].field_type(), - model.fields[row].field_type()) + self.assertEqual(meta.fields[row].field_type().ddl(), + model.fields[row].field_type().ddl()) self.assertEqual(meta.fields[row].nullable(), model.fields[row].nullable()) self.assertEqual(meta.primary_keys, model.primary_keys) diff --git a/spanner_orm/tests/api_test.py b/spanner_orm/tests/api_test.py index 86bdc91..bf2d9a1 100644 --- a/spanner_orm/tests/api_test.py +++ b/spanner_orm/tests/api_test.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,19 +14,76 @@ import logging import unittest from unittest import mock +import warnings + +from absl.testing import parameterized +from google.api_core import exceptions +from google.cloud import spanner from spanner_orm import api from spanner_orm import error from spanner_orm.admin import api as admin_api -class ApiTest(unittest.TestCase): +def _mock_run_in_transaction(method, *args, **kwargs): + return method(*args, **kwargs) + + +class MockSpannerApi(api.SpannerReadApi, api.SpannerWriteApi): + + def __init__(self): + self.connection_mock = mock.MagicMock() + self.connection_mock.run_in_transaction.side_effect = _mock_run_in_transaction + + @property + def _connection(self): + return self.connection_mock + + +class ApiTest(parameterized.TestCase): + + @mock.patch.object(spanner, 'Client', autospec=True, spec_set=True) + def test_connection_args(self, client): + client.return_value.instance.return_value.database.return_value = ( + 'fake-database') + connection = api.SpannerConnection( + instance='some-instance', + database='some-database', + project='some-project', + credentials='fake-credentials', + pool='fake-pool', + create_ddl=('fake-ddl',), + client_options=dict(fake='options'), + ) + self.assertEqual('fake-database', connection.database) + self.assertSequenceEqual( + ( + mock.call( + project='some-project', + credentials='fake-credentials', + client_options=dict(fake='options'), + ), + mock.call().instance('some-instance'), + mock.call().instance().database( + 'some-database', + pool='fake-pool', + ddl_statements=('fake-ddl',), + ), + ), + client.mock_calls, + ) @mock.patch('google.cloud.spanner.Client') def test_api_connection(self, client): connection = self.mock_connection(client) - api.connect('', '', '') + with warnings.catch_warnings(record=True) as connect_warnings: + api.connect('', '', '') self.assertEqual(api.spanner_api()._connection, connection) + self.assertLen(connect_warnings, 1) + connect_warning, = connect_warnings + self.assertIn('spanner_orm.from_connection', str(connect_warning.message)) + self.assertIs(DeprecationWarning, connect_warning.category) + self.assertEqual(api.__file__, connect_warning.filename) api.hangup() with self.assertRaises(error.SpannerError): @@ -40,8 +96,15 @@ def test_api_error_when_not_connected(self): @mock.patch('google.cloud.spanner.Client') def test_admin_api_connection(self, client): connection = self.mock_connection(client) - admin_api.connect('', '', '') + with warnings.catch_warnings(record=True) as connect_warnings: + admin_api.connect('', '', '') self.assertEqual(admin_api.spanner_admin_api()._connection, connection) + self.assertLen(connect_warnings, 1) + connect_warning, = connect_warnings + self.assertIn('spanner_orm.from_admin_connection', + str(connect_warning.message)) + self.assertIs(DeprecationWarning, connect_warning.category) + self.assertEqual(admin_api.__file__, connect_warning.filename) admin_api.hangup() with self.assertRaises(error.SpannerError): @@ -53,6 +116,36 @@ def test_admin_api_create_ddl_connection(self, client): admin_api.connect('', '', '', create_ddl=['create ddl']) self.assertEqual(admin_api.spanner_admin_api()._connection, connection) + @parameterized.parameters('run_read_only', 'run_write') + @mock.patch('spanner_orm.api.spanner_api') + def test_reconnect_on_expected_error(self, api_method, mock_spanner_api): + mock_api = MockSpannerApi() + + mock_method = mock.Mock() + mock_method.side_effect = [ + exceptions.NotFound('Session not found'), + 'Anything other than an exception' + ] + mock_connect = mock_spanner_api.return_value.spanner_connection.connect + + getattr(mock_api, api_method)(mock_method) + + mock_connect.assert_called_once() + mock_method.assert_called() + + @parameterized.parameters('run_read_only', 'run_write') + @mock.patch('spanner_orm.api.spanner_api') + def test_raise_on_expected_error(self, api_method, mock_spanner_api): + mock_api = MockSpannerApi() + + mock_method = mock.Mock() + mock_method.side_effect = exceptions.NotFound('Database not found') + + with self.assertRaises(exceptions.NotFound): + getattr(mock_api, api_method)(mock_method) + + mock_method.assert_called() + def mock_connection(self, client): connection = mock.Mock() client().instance().database.return_value = connection diff --git a/spanner_orm/tests/condition_test.py b/spanner_orm/tests/condition_test.py new file mode 100644 index 0000000..5fdd5bb --- /dev/null +++ b/spanner_orm/tests/condition_test.py @@ -0,0 +1,349 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for spanner_orm.condition.""" + +import datetime +import decimal +import logging +import os +import unittest + +from absl.testing import parameterized +from google.api_core import datetime_helpers +from google.cloud import spanner +from google.cloud import spanner_v1 + +import spanner_orm +from spanner_orm import condition +from spanner_orm import error +from spanner_orm.testlib.spanner_emulator import testlib as spanner_emulator_testlib +from spanner_orm.tests import models + + +class ConditionTest( + spanner_emulator_testlib.TestCase, + parameterized.TestCase, +): + + def setUp(self): + super().setUp() + self.run_orm_migrations( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), + 'migrations_for_emulator_test', + )) + + @parameterized.parameters( + (True, spanner_v1.param_types.BOOL), + (0, spanner_v1.param_types.INT64), + (0.0, spanner_v1.param_types.FLOAT64), + ( + datetime_helpers.DatetimeWithNanoseconds(2021, 1, 5), + spanner_v1.param_types.TIMESTAMP, + ), + (datetime.datetime(2021, 1, 5), spanner_v1.param_types.TIMESTAMP), + (datetime.date(2021, 1, 5), spanner_v1.param_types.DATE), + (b'\x01', spanner_v1.param_types.BYTES), + ('foo', spanner_v1.param_types.STRING), + (decimal.Decimal('1.23'), spanner_v1.param_types.NUMERIC), + ( + (0, 1), + spanner_v1.Type( + code=spanner_v1.TypeCode.ARRAY, + array_element_type=spanner_v1.param_types.INT64, + ), + ), + ( + ['a', None, 'b'], + spanner_v1.Type( + code=spanner_v1.TypeCode.ARRAY, + array_element_type=spanner_v1.param_types.STRING, + ), + ), + ) + def test_param_from_value(self, value, expected_type): + param = condition.Param.from_value(value) + self.assertEqual(expected_type, param.type) + # Test that the value and inferred type are compatible. This will raise an + # exception if they're not. + self.assertEmpty( + models.SmallTestModel.where( + condition.ArbitraryCondition( + '$param IS NULL', + dict(param=param), + segment=condition.Segment.WHERE, + ))) + + @parameterized.parameters( + (None, 'Cannot infer type of None'), + ((0, 'some-string'), 'elements of exactly one type'), + ((0, 'some-string', None), 'elements of exactly one type'), + (object(), 'Unknown type'), + ) + def test_param_from_value_error(self, value, error_regex): + with self.assertRaisesRegex(TypeError, error_regex): + condition.Param.from_value(value) + + @parameterized.named_parameters( + ( + 'bytes', + condition.ArbitraryCondition( + '$param = b"\x01\x02"', + dict(param=condition.Param.from_value(b'\x01\x02')), + segment=condition.Segment.WHERE, + ), + ), + ( + 'array_of_bytes', + condition.ArbitraryCondition( + '${param}[OFFSET(0)] = b"\x01\x02"', + dict(param=condition.Param.from_value([b'\x01\x02'])), + segment=condition.Segment.WHERE, + ), + ), + ( + 'array_of_bytes_and_null', + condition.ArbitraryCondition( + '${param}[OFFSET(0)] IS NULL', + dict(param=condition.Param.from_value((None, b'\x01\x02'))), + segment=condition.Segment.WHERE, + ), + ), + ) + def test_param_from_value_correctly_encodes(self, tautology): + test_model = models.SmallTestModel( + dict( + key='some-key', + value_1='some-value', + value_2='other-value', + )) + test_model.save() + self.assertCountEqual((test_model,), models.SmallTestModel.where(tautology)) + + @parameterized.named_parameters( + ( + 'minimal', + condition.ArbitraryCondition( + 'FALSE', + segment=condition.Segment.WHERE, + ), + {}, + {}, + 'FALSE', + (), + ), + ( + 'full', + condition.ArbitraryCondition( + '$key = IF($true_param, ${key_param}, $value_1)', + dict( + key=models.SmallTestModel.key, + true_param=condition.Param.from_value(True), + key_param=condition.Param.from_value('some-key'), + value_1=condition.Column('value_1'), + ), + segment=condition.Segment.WHERE, + ), + dict( + true_param0=True, + key_param0='some-key', + ), + dict( + true_param0=spanner_v1.param_types.BOOL, + key_param0=spanner_v1.param_types.STRING, + ), + ('SmallTestModel.key = ' + 'IF(@true_param0, @key_param0, SmallTestModel.value_1)'), + ('some-key',), + ), + ) + def test_arbitrary_condition( + self, + condition_, + expected_params, + expected_types, + expected_sql, + expected_row_keys, + ): + models.SmallTestModel( + dict( + key='some-key', + value_1='some-value', + value_2='other-value', + )).save() + rows = models.SmallTestModel.where(condition_) + self.assertEqual(expected_params, condition_.params()) + self.assertEqual(expected_types, condition_.types()) + self.assertEqual(expected_sql, condition_.sql()) + self.assertCountEqual(expected_row_keys, tuple(row.key for row in rows)) + + @parameterized.named_parameters( + ('key_not_found', '$not_found', KeyError, 'not_found'), + ('invalid_template', '$', ValueError, 'Invalid placeholder'), + ) + def test_arbitrary_condition_template_error( + self, + template, + error_class, + error_regex, + ): + with self.assertRaisesRegex(error_class, error_regex): + condition.ArbitraryCondition(template, segment=condition.Segment.WHERE) + + @parameterized.named_parameters( + ( + 'field_from_wrong_model', + models.ChildTestModel.key, + 'does not belong to the Model', + ), + ( + 'column_not_found', + condition.Column('not_found'), + 'does not exist in the Model', + ), + ) + def test_arbitrary_condition_validation_error( + self, + substitution, + error_regex, + ): + condition_ = condition.ArbitraryCondition( + '$substitution', + dict(substitution=substitution), + segment=condition.Segment.WHERE, + ) + with self.assertRaisesRegex(error.ValidationError, error_regex): + models.SmallTestModel.where(condition_) + + @parameterized.named_parameters( + ( + 'empty_or', + condition.OrCondition(), + {}, + {}, + 'FALSE', + '', + ), + ( + 'empty_and', + condition.OrCondition([]), + {}, + {}, + '(TRUE)', + 'ab', + ), + ( + 'single', + condition.OrCondition( + [condition.equal_to(models.SmallTestModel.key, 'a')]), + dict(key0='a'), + dict(key0=spanner_v1.param_types.STRING), + '((SmallTestModel.key = @key0))', + 'a', + ), + ( + 'multiple', + condition.OrCondition( + [ + condition.equal_to(models.SmallTestModel.key, 'a'), + condition.equal_to(models.SmallTestModel.value_1, 'a'), + ], + [ + condition.equal_to(models.SmallTestModel.key, 'b'), + condition.equal_to(models.SmallTestModel.value_1, 'b'), + ], + ), + dict( + key0='a', + value_11='a', + key2='b', + value_13='b', + ), + dict( + key0=spanner_v1.param_types.STRING, + value_11=spanner_v1.param_types.STRING, + key2=spanner_v1.param_types.STRING, + value_13=spanner_v1.param_types.STRING, + ), + ('(' + '(SmallTestModel.key = @key0 AND SmallTestModel.value_1 = @value_11)' + ' OR ' + '(SmallTestModel.key = @key2 AND SmallTestModel.value_1 = @value_13)' + ')'), + 'ab', + ), + ) + def test_or_condition( + self, + condition_, + expected_params, + expected_types, + expected_sql, + expected_row_keys, + ): + models.SmallTestModel(dict(key='a', value_1='a', value_2='a')).save() + models.SmallTestModel(dict(key='b', value_1='b', value_2='b')).save() + rows = models.SmallTestModel.where(condition_) + self.assertEqual(expected_params, condition_.params()) + self.assertEqual(expected_types, condition_.types()) + self.assertEqual(expected_sql, condition_.sql()) + self.assertCountEqual(expected_row_keys, tuple(row.key for row in rows)) + + @parameterized.parameters( + ('ABCD', 'BC', True), + ('ABCD', 'bc', False), + ('ABCD', 'CB', False), + (b'ABCD', b'BC', True), + (b'ABCD', b'bc', False), + (b'ABCD', b'CB', False), + ('ABCD', 'BC', True, dict(case_sensitive=False)), + ('ABCD', 'bc', True, dict(case_sensitive=False)), + ('ABCD', 'CB', False, dict(case_sensitive=False)), + (b'ABCD', b'BC', True, dict(case_sensitive=False)), + (b'ABCD', b'bc', True, dict(case_sensitive=False)), + (b'ABCD', b'CB', False, dict(case_sensitive=False)), + ) + def test_contains( + self, + haystack, + needle, + expect_results, + kwargs={}, + ): + test_model = models.SmallTestModel(dict(key='a', value_1='a', value_2='a')) + test_model.save() + self.assertCountEqual( + ((test_model,) if expect_results else ()), + models.SmallTestModel.where( + spanner_orm.contains( + condition.Param.from_value(haystack), + condition.Param.from_value(needle), + **kwargs, + )), + ) + + def test_force_null_filtered_index(self): + non_null_model = models.NullFilteredIndexModel( + dict(key='a', value_1='a', value_2=1)) + non_null_model.save() + models.NullFilteredIndexModel(dict(key='b', value_1=None, value_2=2)).save() + self.assertCountEqual((non_null_model,), + models.NullFilteredIndexModel.where( + *spanner_orm.force_null_filtered_index( + models.NullFilteredIndexModel.value_index))) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main() diff --git a/spanner_orm/tests/conftest.py b/spanner_orm/tests/conftest.py new file mode 100644 index 0000000..8b8ed63 --- /dev/null +++ b/spanner_orm/tests/conftest.py @@ -0,0 +1,26 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pytest fixture.""" + +import sys + +from absl import flags +import pytest + + +@pytest.fixture(scope='session', autouse=True) +def parse_flags(): + # Only pass the first item, because pytest flags shouldn't be parsed as absl + # flags. + flags.FLAGS(sys.argv[:1]) diff --git a/spanner_orm/tests/decorator_test.py b/spanner_orm/tests/decorator_test.py index ba30d25..d41cee7 100644 --- a/spanner_orm/tests/decorator_test.py +++ b/spanner_orm/tests/decorator_test.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -36,16 +35,16 @@ def test_transactional_injects_new_transaction(self, decorator_in_test, mock_api_method.side_effect = mock_spanner_method(mock_tx) @decorator_in_test - def get_book(book_id, genre=None, transaction=None): + def get_book(book_id, method=None, transaction=None): self.assertEqual(mock_tx, transaction) self.assertEqual(123, book_id) - self.assertEqual('horror', genre) + self.assertEqual('library', method) return 200 - result = get_book(123, genre='horror') + result = get_book(123, method='library') self.assertEqual(200, result) - mock_api_method.assert_called_once_with(mock.ANY, 123, genre='horror') + mock_api_method.assert_called_once() @parameterized.parameters(decorator.transactional_read, decorator.transactional_write) @@ -53,14 +52,14 @@ def test_transactional_uses_given_transaction(self, decorator_in_test): mock_tx = mock.Mock() @decorator_in_test - def get_book(book_id, genre=None, transaction=None): + def get_book(book_id, method=None, transaction=None): self.assertEqual(mock_tx, transaction) self.assertEqual(123, book_id) - self.assertEqual('horror', genre) + self.assertEqual('library', method) return 200 - result = get_book(123, genre='horror', transaction=mock_tx) + result = get_book(123, method='library', transaction=mock_tx) self.assertEqual(200, result) diff --git a/spanner_orm/tests/field_test.py b/spanner_orm/tests/field_test.py new file mode 100644 index 0000000..36da5aa --- /dev/null +++ b/spanner_orm/tests/field_test.py @@ -0,0 +1,182 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for field.""" + +import base64 +import datetime +from typing import Any +import warnings + +from absl.testing import absltest +from absl.testing import parameterized +from google.cloud import spanner +from google.cloud import spanner_v1 +from spanner_orm import error +from spanner_orm import field + + +class FieldTest(parameterized.TestCase): + + @parameterized.parameters( + (field.Boolean(), 'BOOL'), + (field.Integer(), 'INT64'), + (field.Float(), 'FLOAT64'), + (field.String(), 'STRING(MAX)'), + (field.String(10), 'STRING(10)'), + (field.Timestamp(), 'TIMESTAMP'), + (field.BytesBase64(), 'BYTES(MAX)'), + (field.BytesBase64(10), 'BYTES(10)'), + (field.Array(field.Boolean()), 'ARRAY'), + (field.Array(field.String()), 'ARRAY'), + (field.Array(field.String(10)), 'ARRAY'), + (field.Array(field.BytesBase64()), 'ARRAY'), + (field.Array(field.BytesBase64(10)), 'ARRAY'), + ) + def test_field_type_ddl( + self, + field_type: field.FieldType, + ddl: str, + ): + self.assertEqual(field_type.ddl(), ddl) + + @parameterized.parameters( + (field.Boolean(), spanner.param_types.BOOL), + (field.Integer(), spanner.param_types.INT64), + (field.Float(), spanner.param_types.FLOAT64), + (field.String(), spanner.param_types.STRING), + (field.String(10), spanner.param_types.STRING), + (field.Timestamp(), spanner.param_types.TIMESTAMP), + (field.BytesBase64(), spanner.param_types.BYTES), + (field.BytesBase64(10), spanner.param_types.BYTES), + (field.Array(field.Boolean()), + spanner.param_types.Array(spanner.param_types.BOOL)), + (field.Array(field.String()), + spanner.param_types.Array(spanner.param_types.STRING)), + (field.Array(field.String(10)), + spanner.param_types.Array(spanner.param_types.STRING)), + ) + def test_field_type_grpc_type( + self, + field_type: field.FieldType, + grpc_type: spanner_v1.Type, + ): + self.assertEqual(field_type.grpc_type(), grpc_type) + + @parameterized.parameters( + (field.Boolean(), True), + (field.Integer(), 1), + (field.Float(), 1), + (field.Float(), 1.0), + (field.String(), 'foo'), + (field.String(10), 'foo'), + (field.Timestamp(), datetime.datetime(2022, 9, 21)), + (field.BytesBase64(), base64.b64encode(b'\x00')), + (field.BytesBase64(10), base64.b64encode(b'\x00')), + (field.Array(field.Boolean()), [True]), + ) + def test_field_type_validate_type_ok( + self, + field_type: field.FieldType, + value: Any, + ): + field_type.validate_type(value) + + @parameterized.parameters( + (field.Boolean(), 1), + (field.Integer(), 1.0), + (field.Float(), '1.0'), + (field.String(), b'foo'), + (field.String(10), b'foo'), + (field.Timestamp(), datetime.date(2022, 9, 21)), + (field.BytesBase64(), base64.b64encode(b'\x00').decode('utf-8')), + (field.BytesBase64(), b'!'), + (field.BytesBase64(10), b'!'), + (field.Array(field.Boolean()), {True}), + (field.Array(field.Boolean()), [1]), + ) + def test_field_type_validate_type_error( + self, + field_type: field.FieldType, + value: Any, + ): + with self.assertRaises(error.ValidationError): + field_type.validate_type(value) + + @parameterized.parameters( + (field.Boolean(), field.Boolean(), True), + (field.Boolean(), field.String(), False), + (field.String(10), field.String(20), True), + (field.String(), field.String(10), True), + (field.Array(field.Integer()), field.Array(field.Integer()), False), + (field.Array(field.Integer()), field.Integer(), False), + ) + def test_field_type_comparable_with( + self, + field_type_1: field.FieldType, + field_type_2: field.FieldType, + expected_comparable: bool, + ): + self.assertEqual( + field_type_1.comparable_with(field_type_2), expected_comparable) + self.assertEqual( + field_type_2.comparable_with(field_type_1), expected_comparable) + + def test_field_field_type_is_class(self): + with warnings.catch_warnings(record=True) as actual_warnings: + self.assertIsInstance( + field.Field(field.String).field_type(), field.String) + self.assertLen(actual_warnings, 1) + self.assertIn('instance of FieldType', str(actual_warnings[0].message)) + self.assertIs(actual_warnings[0].category, DeprecationWarning) + + def test_array_of_array_is_invalid(self): + with self.assertRaisesRegex(error.SpannerError, 'arrays of arrays'): + field.Array(field.Array(field.String())) + + def test_string_array_is_deprecated_and_equivalent_to_array_of_string(self): + with warnings.catch_warnings(record=True) as actual_warnings: + string_array = field.StringArray() + array_of_string = field.Array(field.String()) + self.assertLen(actual_warnings, 1) + self.assertIn('Use Array(String()) instead', + str(actual_warnings[0].message)) + self.assertIs(actual_warnings[0].category, DeprecationWarning) + self.assertEqual(string_array.ddl(), array_of_string.ddl()) + self.assertEqual(string_array.grpc_type(), array_of_string.grpc_type()) + + @parameterized.parameters( + 'BOOL', + 'INT64', + 'FLOAT64', + 'STRING(MAX)', + 'STRING(10)', + 'TIMESTAMP', + 'BYTES(MAX)', + 'BYTES(10)', + 'ARRAY', + 'ARRAY', + 'ARRAY', + ) + def test_ddl_to_field_type_to_ddl(self, ddl: str): + self.assertEqual(field.field_type_from_ddl(ddl).ddl(), ddl) + + @parameterized.parameters('UNICORN(MAX)', 'STRING(MAX1)', 'STRING(MIN)', + 'ARRAY', 'BYTES(MAX1)', 'BYTES(MIN)') + def test_field_type_from_ddl_invalid(self, ddl: str): + with self.assertRaisesRegex(error.SpannerError, 'DDL type'): + field.field_type_from_ddl(ddl) + + +if __name__ == '__main__': + absltest.main() diff --git a/spanner_orm/tests/metadata_test.py b/spanner_orm/tests/metadata_test.py index 8d6b46f..cd1f66d 100644 --- a/spanner_orm/tests/metadata_test.py +++ b/spanner_orm/tests/metadata_test.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import logging import unittest diff --git a/spanner_orm/tests/migrations/test_1_4a7a7dee0718.py b/spanner_orm/tests/migrations/test_1_4a7a7dee0718.py index cc39fea..29078a8 100644 --- a/spanner_orm/tests/migrations/test_1_4a7a7dee0718.py +++ b/spanner_orm/tests/migrations/test_1_4a7a7dee0718.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/tests/migrations/test_2_5c078bbb4d43.py b/spanner_orm/tests/migrations/test_2_5c078bbb4d43.py index 6a347b0..31f6a4d 100644 --- a/spanner_orm/tests/migrations/test_2_5c078bbb4d43.py +++ b/spanner_orm/tests/migrations/test_2_5c078bbb4d43.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/tests/migrations/test_3_eceb25f170dd.py b/spanner_orm/tests/migrations/test_3_eceb25f170dd.py index 44f7152..8510398 100644 --- a/spanner_orm/tests/migrations/test_3_eceb25f170dd.py +++ b/spanner_orm/tests/migrations/test_3_eceb25f170dd.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/spanner_orm/tests/migrations_emulator_test.py b/spanner_orm/tests/migrations_emulator_test.py new file mode 100644 index 0000000..96978bd --- /dev/null +++ b/spanner_orm/tests/migrations_emulator_test.py @@ -0,0 +1,241 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import datetime +import logging +import os +import textwrap +from typing import Iterable, Type +import unittest + +from absl.testing import absltest +from absl.testing import parameterized +import spanner_orm +from spanner_orm.admin import metadata +from spanner_orm.tests import models +from spanner_orm.testlib.spanner_emulator import testlib as spanner_emulator_testlib + +from google.api_core import exceptions as google_api_exceptions + + +class MigrationsEmulatorTest(spanner_emulator_testlib.TestCase): + """Basic tests using generic migrations.""" + + TEST_MIGRATIONS_DIR = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + 'migrations_for_emulator_test', + ) + + def setUp(self): + super().setUp() + self.run_orm_migrations(self.TEST_MIGRATIONS_DIR) + + def test_basic(self): + models.SmallTestModel({'key': 'key', 'value_1': 'value'}).save() + self.assertEqual( + [x.values for x in models.SmallTestModel.all()], + [{ + 'key': 'key', + 'value_1': 'value', + 'value_2': None, + }], + ) + + def test_error_with_missing_referencing_key(self): + with self.assertRaisesRegex( + google_api_exceptions.FailedPrecondition, + 'Cannot find referenced key', + ): + models.ForeignKeyTestModel({ + 'referencing_key_1': 'key', + 'referencing_key_2': 'key', + 'referencing_key_3': 42, + 'value': 'value' + }).save() + + def test_key(self): + models.SmallTestModel({'key': 'key', 'value_1': 'value'}).save() + models.UnittestModel({ + 'string': 'string', + 'int_': 42, + 'float_': 4.2, + 'bytes_': b'A1A1', + 'timestamp': datetime.datetime.now(tz=datetime.timezone.utc), + }).save() + models.ForeignKeyTestModel({ + 'referencing_key_1': 'key', + 'referencing_key_2': 'string', + 'referencing_key_3': 42, + 'value': 'value' + }).save() + + +class SpecificMigrationsEmulatorTest( + parameterized.TestCase, + spanner_emulator_testlib.TestCase, +): + """Tests of specific migrations.""" + + def setUp(self): + super().setUp() + self._migrations_dir = self.create_tempdir() + self._migration_index = None + + def _append_migrations(self, *migrations: str) -> None: + """Appends migrations to the sequence of migrations in self._migrations_dir. + + Args: + *migrations: Each string is the python code to define a single upgrade() + function. Leading indentation is stripped and migration boilerplate is + added. + """ + for migration in migrations: + if self._migration_index is None: + prev_migration_id = None + self._migration_index = 0 + else: + prev_migration_id = str(self._migration_index) + self._migration_index += 1 + migration_id = str(self._migration_index) + self._migrations_dir.create_file( + f'migration_{migration_id}.py', + '\n'.join(( + 'import spanner_orm', + f'migration_id = {migration_id!r}', + f'prev_migration_id = {prev_migration_id!r}', + textwrap.dedent(migration), + 'def downgrade(): raise NotImplementedError()', + )), + ) + + def test_drop_interleaved_table(self): + self._append_migrations( + """ + class _Parent(spanner_orm.Model): + __table__ = 'Parent' + parent_key = spanner_orm.Field( + spanner_orm.String(), primary_key=True) + + def upgrade(): + return spanner_orm.CreateTable(_Parent) + """, + """ + class _Parent(spanner_orm.Model): + __table__ = 'Parent' + parent_key = spanner_orm.Field( + spanner_orm.String(), primary_key=True) + + class _Child(spanner_orm.Model): + __table__ = 'Child' + __interleaved__ = _Parent + parent_key = spanner_orm.Field( + spanner_orm.String(), primary_key=True) + child_key = spanner_orm.Field( + spanner_orm.String(), primary_key=True) + + def upgrade(): + return spanner_orm.CreateTable(_Child) + """, + """ + def upgrade(): + return spanner_orm.DropTable('Child') + """, + ) + self.run_orm_migrations(self._migrations_dir) + self.assertCountEqual( + ('Parent',), + metadata.SpannerMetadata.tables().keys() - {'spanner_orm_migrations'}, + ) + + @parameterized.named_parameters( + dict( + testcase_name='does_not_exist', + create_migrations=(), + error_class=google_api_exceptions.NotFound, + ), + dict( + testcase_name='has_secondary_index', + create_migrations=( + """ + class _TableToDrop(spanner_orm.Model): + __table__ = 'TableToDrop' + key = spanner_orm.Field( + spanner_orm.String(), primary_key=True) + value = spanner_orm.Field(spanner_orm.String()) + + def upgrade(): + return spanner_orm.CreateTable(_TableToDrop) + """, + """ + def upgrade(): + return spanner_orm.CreateIndex( + table_name='TableToDrop', + index_name='value_index', + columns=['value'], + ) + """, + ), + error_class=google_api_exceptions.FailedPrecondition, + ), + dict( + testcase_name='has_interleaved_child', + create_migrations=( + """ + class _TableToDrop(spanner_orm.Model): + __table__ = 'TableToDrop' + parent_key = spanner_orm.Field( + spanner_orm.String(), primary_key=True) + + def upgrade(): + return spanner_orm.CreateTable(_TableToDrop) + """, + """ + class _TableToDrop(spanner_orm.Model): + __table__ = 'TableToDrop' + parent_key = spanner_orm.Field( + spanner_orm.String(), primary_key=True) + + class _Child(spanner_orm.Model): + __table__ = 'Child' + __interleaved__ = _TableToDrop + parent_key = spanner_orm.Field( + spanner_orm.String(), primary_key=True) + child_key = spanner_orm.Field( + spanner_orm.String(), primary_key=True) + + def upgrade(): + return spanner_orm.CreateTable(_Child) + """, + ), + error_class=google_api_exceptions.FailedPrecondition, + ), + ) + def test_drop_table_error( + self, + *, + create_migrations: Iterable[str], + error_class: Type[Exception], + ): + self._append_migrations(*create_migrations) + self.run_orm_migrations(self._migrations_dir) + self._append_migrations(""" + def upgrade(): + return spanner_orm.DropTable('TableToDrop') + """) + with self.assertRaises(error_class): + self.run_orm_migrations(self._migrations_dir) + + +if __name__ == '__main__': + logging.basicConfig() + absltest.main() diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_custom_length_field_f959b767457d.py b/spanner_orm/tests/migrations_for_emulator_test/create_custom_length_field_f959b767457d.py new file mode 100644 index 0000000..915d1f3 --- /dev/null +++ b/spanner_orm/tests/migrations_for_emulator_test/create_custom_length_field_f959b767457d.py @@ -0,0 +1,35 @@ +"""Spanner ORM migration: create_custom_length_field. + +Migration ID: 'f959b767457d' +Created: 2022-09-13 13:28:34-07:00 +""" + +import spanner_orm + +migration_id = 'f959b767457d' +prev_migration_id = '69a8f072dacf' + + +class OriginalTeeTable(spanner_orm.model.Model): + """ORM Model with the original schema for the Commands table. + Don't update this model, create new migrations instead. + """ + + __table__ = 'Tee' + id = spanner_orm.Field(spanner_orm.String(), primary_key=True) + custom_string_length = spanner_orm.Field(spanner_orm.String(20)) + custom_array_string_length = spanner_orm.Field( + spanner_orm.Array(spanner_orm.String(4))) + custom_bytes_length = spanner_orm.Field(spanner_orm.BytesBase64(20)) + custom_array_bytes_length = spanner_orm.Field( + spanner_orm.Array(spanner_orm.BytesBase64(4))) + + +def upgrade() -> spanner_orm.CreateTable: + """Creates the original Commands table.""" + return spanner_orm.CreateTable(OriginalTeeTable) + + +def downgrade() -> spanner_orm.DropTable: + """Drops the original Commands table.""" + return spanner_orm.DropTable(OriginalTeeTable.__table__) diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py b/spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py new file mode 100644 index 0000000..02c947d --- /dev/null +++ b/spanner_orm/tests/migrations_for_emulator_test/create_foreign_key_test_model.py @@ -0,0 +1,56 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Creates table with ForeignKeyTestModel. + +Migration ID: 'f735d6b706d3' +Created: 2020-07-10 16:24 +""" + +import spanner_orm +from spanner_orm import field +from spanner_orm import foreign_key_relationship + +migration_id = 'f735d6b706d4' +prev_migration_id = 'f735d6b706d3' + + +class OriginalForeignKeyTestModelTable(spanner_orm.model.Model): + """ORM Model with the original schema for the ForeignKeyTestModel table.""" + + __table__ = 'ForeignKeyTestModel' + referencing_key_1 = field.Field(field.String(), primary_key=True) + referencing_key_2 = field.Field(field.String(), primary_key=True) + referencing_key_3 = field.Field(field.Integer(), primary_key=True) + self_referencing_key = field.Field(field.String(), nullable=True) + foreign_key_1 = foreign_key_relationship.ForeignKeyRelationship( + 'SmallTestModel', {'referencing_key_1': 'key'}) + foreign_key_2 = foreign_key_relationship.ForeignKeyRelationship( + 'UnittestModel', + { + 'referencing_key_2': 'string', + 'referencing_key_3': 'int_' + }, + ) + foreign_key_3 = foreign_key_relationship.ForeignKeyRelationship( + 'ForeignKeyTestModel', {'self_referencing_key': 'referencing_key_1'}) + + +def upgrade() -> spanner_orm.CreateTable: + """See ORM migrations interface.""" + return spanner_orm.CreateTable(OriginalForeignKeyTestModelTable) + + +def downgrade() -> spanner_orm.DropTable: + """See ORM migrations interface.""" + return spanner_orm.DropTable(OriginalForeignKeyTestModelTable.__table__) diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_null_filtered_index_model_760ec5fae5da.py b/spanner_orm/tests/migrations_for_emulator_test/create_null_filtered_index_model_760ec5fae5da.py new file mode 100644 index 0000000..4dafccb --- /dev/null +++ b/spanner_orm/tests/migrations_for_emulator_test/create_null_filtered_index_model_760ec5fae5da.py @@ -0,0 +1,40 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Spanner ORM migration: create_null_filtered_index_model. + +Migration ID: '760ec5fae5da' +Created: 2022-03-01 16:50:32-05:00 +""" + +import spanner_orm + +migration_id = '760ec5fae5da' +prev_migration_id = 'f735d6b706d4' + + +class _NullFilteredIndexModel(spanner_orm.Model): + __table__ = 'NullFilteredIndexModel' + key = spanner_orm.Field(spanner_orm.String(), primary_key=True) + value_1 = spanner_orm.Field(spanner_orm.String(), nullable=True) + value_2 = spanner_orm.Field(spanner_orm.Integer()) + + +def upgrade() -> spanner_orm.MigrationUpdate: + """See spanner_orm migrations interface.""" + return spanner_orm.CreateTable(_NullFilteredIndexModel) + + +def downgrade() -> spanner_orm.MigrationUpdate: + """See spanner_orm migrations interface.""" + return spanner_orm.DropTable(_NullFilteredIndexModel.__table__) diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_null_filtered_index_model_value_index_69a8f072dacf.py b/spanner_orm/tests/migrations_for_emulator_test/create_null_filtered_index_model_value_index_69a8f072dacf.py new file mode 100644 index 0000000..e0c16bf --- /dev/null +++ b/spanner_orm/tests/migrations_for_emulator_test/create_null_filtered_index_model_value_index_69a8f072dacf.py @@ -0,0 +1,28 @@ +"""Spanner ORM migration: create_null_filtered_index_model_value_index. + +Migration ID: '69a8f072dacf' +Created: 2022-03-01 16:53:59-05:00 +""" + +import spanner_orm + +migration_id = '69a8f072dacf' +prev_migration_id = '760ec5fae5da' + + +def upgrade() -> spanner_orm.MigrationUpdate: + """See spanner_orm migrations interface.""" + return spanner_orm.CreateIndex( + table_name='NullFilteredIndexModel', + index_name='value_index', + columns=['value_1', 'value_2'], + null_filtered=True, + ) + + +def downgrade() -> spanner_orm.MigrationUpdate: + """See spanner_orm migrations interface.""" + return spanner_orm.DropIndex( + table_name='NullFilteredIndexModel', + index_name='value_index', + ) diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_small_test_model.py b/spanner_orm/tests/migrations_for_emulator_test/create_small_test_model.py new file mode 100644 index 0000000..5878337 --- /dev/null +++ b/spanner_orm/tests/migrations_for_emulator_test/create_small_test_model.py @@ -0,0 +1,43 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Creates table with SmallTestModel. + +Migration ID: 'f735d6b706d2' +Created: 2020-07-10 16:24 +""" + +import spanner_orm +from spanner_orm import field + +migration_id = 'f735d6b706d2' +prev_migration_id = None + + +class OriginalSmallTestModelTable(spanner_orm.model.Model): + """ORM Model with the original schema for the SmallTestModel table.""" + + __table__ = 'SmallTestModel' + key = field.Field(field.String(), primary_key=True) + value_1 = field.Field(field.String()) + value_2 = field.Field(field.String(), nullable=True) + + +def upgrade() -> spanner_orm.CreateTable: + """See ORM migrations interface.""" + return spanner_orm.CreateTable(OriginalSmallTestModelTable) + + +def downgrade() -> spanner_orm.DropTable: + """See ORM migrations interface.""" + return spanner_orm.DropTable(OriginalSmallTestModelTable.__table__) diff --git a/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py b/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py new file mode 100644 index 0000000..8fcf25e --- /dev/null +++ b/spanner_orm/tests/migrations_for_emulator_test/create_unittest_model.py @@ -0,0 +1,53 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Creates table with UnittestModel. + +Migration ID: 'f735d6b706d3' +Created: 2020-07-10 16:24 +""" + +import spanner_orm +from spanner_orm import field + +migration_id = 'f735d6b706d3' +prev_migration_id = 'f735d6b706d2' + + +class OriginalUnittestModelTable(spanner_orm.model.Model): + """ORM Model with the original schema for the UnittestModel table.""" + + __table__ = 'table' + int_ = field.Field(field.Integer(), primary_key=True) + int_2 = field.Field(field.Integer(), nullable=True) + float_ = field.Field(field.Float(), primary_key=True) + float_2 = field.Field(field.Float(), nullable=True) + string = field.Field(field.String(), primary_key=True) + string_2 = field.Field(field.String(), nullable=True) + string_3 = field.Field(field.String(20), nullable=True) + bytes_ = field.Field(field.BytesBase64, primary_key=True) + bytes_2 = field.Field(field.BytesBase64, nullable=True) + bytes_3 = field.Field(field.BytesBase64(20), nullable=True) + timestamp = field.Field(field.Timestamp) + string_array = field.Field(field.Array(field.String()), nullable=True) + string_array_2 = field.Field(field.Array(field.String(20)), nullable=True) + + +def upgrade() -> spanner_orm.CreateTable: + """See ORM migrations interface.""" + return spanner_orm.CreateTable(OriginalUnittestModelTable) + + +def downgrade() -> spanner_orm.DropTable: + """See ORM migrations interface.""" + return spanner_orm.DropTable(OriginalUnittestModelTable.__table__) diff --git a/spanner_orm/tests/migrations_test.py b/spanner_orm/tests/migrations_test.py index 07c1610..e596cc4 100644 --- a/spanner_orm/tests/migrations_test.py +++ b/spanner_orm/tests/migrations_test.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,6 +13,9 @@ # limitations under the License. import logging import os +import shutil +import stat +import tempfile import unittest from unittest import mock @@ -25,11 +27,12 @@ class MigrationsTest(unittest.TestCase): - TEST_DIR = os.path.dirname(__file__) + TEST_DIR = tempfile.mkdtemp() TEST_MIGRATIONS_DIR = os.path.join(TEST_DIR, 'migrations') def test_retrieve(self): - manager = migration_manager.MigrationManager(self.TEST_MIGRATIONS_DIR) + testdata_filename = os.path.join(os.path.dirname(__file__), 'migrations') + manager = migration_manager.MigrationManager(testdata_filename) migrations = manager.migrations self.assertEqual(len(migrations), 3) self.assertEqual(migrations[2].prev_migration_id, @@ -38,6 +41,15 @@ def test_retrieve(self): migrations[0].migration_id) def test_generate(self): + testdata_filename = os.path.join(os.path.dirname(__file__), 'migrations') + shutil.rmtree(self.TEST_MIGRATIONS_DIR) + shutil.copytree(testdata_filename, self.TEST_MIGRATIONS_DIR) + os.chmod(self.TEST_MIGRATIONS_DIR, stat.S_IRWXO | stat.S_IRWXU) + for f in os.listdir(self.TEST_MIGRATIONS_DIR): + file_path = os.path.join(self.TEST_MIGRATIONS_DIR, f) + if not os.path.isdir(file_path): + os.chmod(file_path, + stat.S_IROTH | stat.S_IWOTH | stat.S_IRUSR | stat.S_IWUSR) manager = migration_manager.MigrationManager(self.TEST_MIGRATIONS_DIR) path = manager.generate('test migration') try: @@ -47,7 +59,7 @@ def test_generate(self): self.assertIsInstance(migration_.upgrade(), update.NoUpdate) self.assertIsInstance(migration_.downgrade(), update.NoUpdate) finally: - os.remove(path) + shutil.rmtree(self.TEST_MIGRATIONS_DIR) def test_order_migrations(self): first = migration.Migration('1', None) @@ -110,7 +122,8 @@ def test_order_migrations_error_on_no_successor(self): def test_filter_migrations(self): connection = mock.Mock() - executor = migration_executor.MigrationExecutor(connection) + executor = migration_executor.MigrationExecutor(connection, + self.TEST_MIGRATIONS_DIR) first = migration.Migration('1', None) second = migration.Migration('2', '1') @@ -130,7 +143,8 @@ def test_filter_migrations(self): def test_filter_migrations_error_on_bad_last_migration(self): connection = mock.Mock() - executor = migration_executor.MigrationExecutor(connection) + executor = migration_executor.MigrationExecutor(connection, + self.TEST_MIGRATIONS_DIR) first = migration.Migration('1', None) second = migration.Migration('2', '1') @@ -147,7 +161,8 @@ def test_filter_migrations_error_on_bad_last_migration(self): def test_validate_migrations(self): connection = mock.Mock() - executor = migration_executor.MigrationExecutor(connection) + executor = migration_executor.MigrationExecutor(connection, + self.TEST_MIGRATIONS_DIR) first = migration.Migration('1', None) second = migration.Migration('2', '1') @@ -165,7 +180,8 @@ def test_validate_migrations(self): def test_validate_migrations_error_on_unmigrated_after_migrated(self): connection = mock.Mock() - executor = migration_executor.MigrationExecutor(connection) + executor = migration_executor.MigrationExecutor(connection, + self.TEST_MIGRATIONS_DIR) first = migration.Migration('1', None) second = migration.Migration('2', '1') @@ -185,7 +201,8 @@ def test_validate_migrations_error_on_unmigrated_after_migrated(self): def test_validate_migrations_error_on_unmigrated_first(self): connection = mock.Mock() - executor = migration_executor.MigrationExecutor(connection) + executor = migration_executor.MigrationExecutor(connection, + self.TEST_MIGRATIONS_DIR) first = migration.Migration('2', '1') with mock.patch.object(executor, 'migrations') as migrations: @@ -203,7 +220,8 @@ def test_validate_migrations_error_on_unmigrated_first(self): def test_migrate(self): connection = mock.Mock() - executor = migration_executor.MigrationExecutor(connection) + executor = migration_executor.MigrationExecutor(connection, + self.TEST_MIGRATIONS_DIR) first = migration.Migration('1', None) second = migration.Migration('2', '1') @@ -217,7 +235,8 @@ def test_migrate(self): def test_rollback(self): connection = mock.Mock() - executor = migration_executor.MigrationExecutor(connection) + executor = migration_executor.MigrationExecutor(connection, + self.TEST_MIGRATIONS_DIR) first = migration.Migration('1', None) second = migration.Migration('2', '1') @@ -229,6 +248,11 @@ def test_rollback(self): executor.rollback('1') self.assertEqual(migrated, {'1': False, '2': False, '3': False}) + @classmethod + def tearDownClass(cls): + super().tearDownClass() + shutil.rmtree(MigrationsTest.TEST_DIR) + if __name__ == '__main__': logging.basicConfig() diff --git a/spanner_orm/tests/model_api_test.py b/spanner_orm/tests/model_api_test.py deleted file mode 100644 index 4ce0a52..0000000 --- a/spanner_orm/tests/model_api_test.py +++ /dev/null @@ -1,143 +0,0 @@ -# python3 -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -import unittest -from unittest import mock - -from spanner_orm import error -from spanner_orm.tests import models - - -class ModelApiTest(unittest.TestCase): - - @mock.patch('spanner_orm.table_apis.find') - def test_find_calls_api(self, find): - mock_transaction = mock.Mock() - models.UnittestModel.find(mock_transaction, string='string', int_=1) - - find.assert_called_once() - (transaction, table, columns, keyset), _ = find.call_args - self.assertEqual(transaction, mock_transaction) - self.assertEqual(table, models.UnittestModel.table) - self.assertEqual(columns, models.UnittestModel.columns) - self.assertEqual(keyset.keys, [[1, 'string']]) - - @mock.patch('spanner_orm.table_apis.find') - def test_find_result(self, find): - mock_transaction = mock.Mock() - - find.return_value = [['key', 'value_1', None]] - result = models.SmallTestModel.find(mock_transaction, key='key') - if result: - self.assertEqual(result.key, 'key') - self.assertEqual(result.value_1, 'value_1') - self.assertIsNone(result.value_2) - else: - self.fail('Failed to find result') - - @mock.patch('spanner_orm.table_apis.find') - def test_find_multi_calls_api(self, find): - mock_transaction = mock.Mock() - models.UnittestModel.find_multi(mock_transaction, [{ - 'string': 'string', - 'int_': 1 - }]) - - find.assert_called_once() - (transaction, table, columns, keyset), _ = find.call_args - self.assertEqual(transaction, mock_transaction) - self.assertEqual(table, models.UnittestModel.table) - self.assertEqual(columns, models.UnittestModel.columns) - self.assertEqual(keyset.keys, [[1, 'string']]) - - @mock.patch('spanner_orm.table_apis.find') - def test_find_multi_result(self, find): - mock_transaction = mock.Mock() - find.return_value = [['key', 'value_1', None]] - results = models.SmallTestModel.find_multi(mock_transaction, [{ - 'key': 'key' - }]) - - self.assertEqual(results[0].key, 'key') - self.assertEqual(results[0].value_1, 'value_1') - self.assertIsNone(results[0].value_2) - - @mock.patch('spanner_orm.table_apis.insert') - def test_create_calls_api(self, insert): - mock_transaction = mock.Mock() - models.SmallTestModel.create(mock_transaction, key='key', value_1='value') - - insert.assert_called_once() - (transaction, table, columns, values), _ = insert.call_args - self.assertEqual(transaction, mock_transaction) - self.assertEqual(table, models.SmallTestModel.table) - self.assertEqual(list(columns), ['key', 'value_1']) - self.assertEqual(list(values), [['key', 'value']]) - - def test_create_error_on_invalid_keys(self): - with self.assertRaises(error.SpannerError): - models.SmallTestModel.create(key_2='key') - - def assert_api_called(self, mock_api, mock_transaction): - mock_api.assert_called_once() - (transaction, table, columns, values), _ = mock_api.call_args - self.assertEqual(transaction, mock_transaction) - self.assertEqual(table, models.SmallTestModel.table) - self.assertEqual(list(columns), ['key', 'value_1', 'value_2']) - self.assertEqual(list(values), [['key', 'value', None]]) - - @mock.patch('spanner_orm.table_apis.insert') - def test_save_batch_inserts(self, insert): - mock_transaction = mock.Mock() - values = {'key': 'key', 'value_1': 'value'} - not_persisted = models.SmallTestModel(values) - models.SmallTestModel.save_batch(mock_transaction, [not_persisted]) - self.assert_api_called(insert, mock_transaction) - - @mock.patch('spanner_orm.table_apis.update') - def test_save_batch_updates(self, update): - mock_transaction = mock.Mock() - values = {'key': 'key', 'value_1': 'value'} - persisted = models.SmallTestModel(values, persisted=True) - models.SmallTestModel.save_batch(mock_transaction, [persisted]) - - self.assert_api_called(update, mock_transaction) - - @mock.patch('spanner_orm.table_apis.upsert') - def test_save_batch_force_write_upserts(self, upsert): - mock_transaction = mock.Mock() - values = {'key': 'key', 'value_1': 'value'} - not_persisted = models.SmallTestModel(values) - models.SmallTestModel.save_batch( - mock_transaction, [not_persisted], force_write=True) - self.assert_api_called(upsert, mock_transaction) - - @mock.patch('spanner_orm.table_apis.delete') - def test_delete_batch_deletes(self, delete): - mock_transaction = mock.Mock() - values = {'key': 'key', 'value_1': 'value'} - model = models.SmallTestModel(values) - models.SmallTestModel.delete_batch(mock_transaction, [model]) - - delete.assert_called_once() - (transaction, table, keyset), _ = delete.call_args - self.assertEqual(transaction, mock_transaction) - self.assertEqual(table, models.SmallTestModel.table) - self.assertEqual(keyset.keys, [[model.key]]) - - -if __name__ == '__main__': - logging.basicConfig() - unittest.main() diff --git a/spanner_orm/tests/model_test.py b/spanner_orm/tests/model_test.py index 804368a..1dc3c10 100644 --- a/spanner_orm/tests/model_test.py +++ b/spanner_orm/tests/model_test.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,16 +13,203 @@ # limitations under the License. import datetime import logging +import os +import typing from typing import List import unittest from unittest import mock from absl.testing import parameterized +from google.api_core import exceptions +from google.cloud import spanner +from spanner_orm import error from spanner_orm import field +from spanner_orm.testlib.spanner_emulator import testlib as spanner_emulator_testlib from spanner_orm.tests import models +_TIMESTAMP = datetime.datetime.now(tz=datetime.timezone.utc) -class ModelTest(parameterized.TestCase): + +class ModelTest( + spanner_emulator_testlib.TestCase, + parameterized.TestCase, +): + + def setUp(self): + super().setUp() + self.run_orm_migrations( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), + 'migrations_for_emulator_test', + )) + + @mock.patch('spanner_orm.table_apis.find') + def test_find_calls_api(self, find): + mock_transaction = mock.Mock() + models.UnittestModel.find( + string='string', + int_=1, + float_=2.3, + bytes_=b'A1A1', + transaction=mock_transaction, + ) + + find.assert_called_once() + (transaction, table, columns, keyset), _ = find.call_args + self.assertEqual(transaction, mock_transaction) + self.assertEqual(table, models.UnittestModel.table) + self.assertEqual(columns, models.UnittestModel.columns) + self.assertEqual(keyset.keys, [[1, 2.3, 'string', b'A1A1']]) + + @mock.patch('spanner_orm.table_apis.find') + def test_find_result(self, find): + mock_transaction = mock.Mock() + + find.return_value = [['key', 'value_1', None]] + result = models.SmallTestModel.find(key='key', transaction=mock_transaction) + if result: + self.assertEqual(result.key, 'key') + self.assertEqual(result.value_1, 'value_1') + self.assertIsNone(result.value_2) + else: + self.fail('Failed to find result') + + def test_find_required(self): + test_model = models.SmallTestModel( + dict( + key='some-key', + value_1='foo', + value_2='bar', + )) + test_model.save() + self.assertEqual( + test_model, + models.SmallTestModel.find_required(key='some-key'), + ) + + def test_find_required_not_found(self): + with self.assertRaisesRegex(exceptions.NotFound, + 'SmallTestModel has no object'): + models.SmallTestModel.find_required(key='some-key') + + @mock.patch('spanner_orm.table_apis.find') + def test_find_multi_calls_api(self, find): + mock_transaction = mock.Mock() + models.UnittestModel.find_multi( + [{ + 'string': 'string', + 'bytes_': b'bytes', + 'int_': 1, + 'float_': 2.3 + }], + transaction=mock_transaction, + ) + + find.assert_called_once() + (transaction, table, columns, keyset), _ = find.call_args + self.assertEqual(transaction, mock_transaction) + self.assertEqual(table, models.UnittestModel.table) + self.assertEqual(columns, models.UnittestModel.columns) + self.assertEqual(keyset.keys, [[1, 2.3, 'string', b'bytes']]) + + @mock.patch('spanner_orm.table_apis.find') + def test_find_multi_result(self, find): + mock_transaction = mock.Mock() + find.return_value = [['key', 'value_1', None]] + results = models.SmallTestModel.find_multi( + [{ + 'key': 'key' + }], + transaction=mock_transaction, + ) + + self.assertEqual(results[0].key, 'key') + self.assertEqual(results[0].value_1, 'value_1') + self.assertIsNone(results[0].value_2) + + @mock.patch('spanner_orm.table_apis.insert') + def test_create_calls_api(self, insert): + mock_transaction = mock.Mock() + models.SmallTestModel.create( + key='key', + value_1='value', + transaction=mock_transaction, + ) + + insert.assert_called_once() + (transaction, table, columns, values), _ = insert.call_args + self.assertEqual(transaction, mock_transaction) + self.assertEqual(table, models.SmallTestModel.table) + self.assertEqual(list(columns), ['key', 'value_1']) + self.assertEqual(list(values), [['key', 'value']]) + + def test_create_error_on_invalid_keys(self): + with self.assertRaises(error.SpannerError): + models.SmallTestModel.create(key_2='key') + + def assert_api_called(self, mock_api, mock_transaction): + mock_api.assert_called_once() + (transaction, table, columns, values), _ = mock_api.call_args + self.assertEqual(transaction, mock_transaction) + self.assertEqual(table, models.SmallTestModel.table) + self.assertEqual(list(columns), ['key', 'value_1', 'value_2']) + self.assertEqual(list(values), [['key', 'value', None]]) + + @mock.patch('spanner_orm.table_apis.insert') + def test_save_batch_inserts(self, insert): + mock_transaction = mock.Mock() + values = {'key': 'key', 'value_1': 'value'} + not_persisted = models.SmallTestModel(values) + models.SmallTestModel.save_batch([not_persisted], + transaction=mock_transaction) + self.assert_api_called(insert, mock_transaction) + + @mock.patch('spanner_orm.table_apis.update') + def test_save_batch_updates(self, update): + mock_transaction = mock.Mock() + values = {'key': 'key', 'value_1': 'value'} + persisted = models.SmallTestModel(values, persisted=True) + models.SmallTestModel.save_batch([persisted], transaction=mock_transaction) + + self.assert_api_called(update, mock_transaction) + + @mock.patch('spanner_orm.table_apis.upsert') + def test_save_batch_force_write_upserts(self, upsert): + mock_transaction = mock.Mock() + values = {'key': 'key', 'value_1': 'value'} + not_persisted = models.SmallTestModel(values) + models.SmallTestModel.save_batch( + [not_persisted], + force_write=True, + transaction=mock_transaction, + ) + self.assert_api_called(upsert, mock_transaction) + + @mock.patch('spanner_orm.table_apis.delete') + def test_delete_batch_deletes(self, delete): + mock_transaction = mock.Mock() + values = {'key': 'key', 'value_1': 'value'} + model = models.SmallTestModel(values) + models.SmallTestModel.delete_batch([model], transaction=mock_transaction) + + delete.assert_called_once() + (transaction, table, keyset), _ = delete.call_args + self.assertEqual(transaction, mock_transaction) + self.assertEqual(table, models.SmallTestModel.table) + self.assertEqual(keyset.keys, [[model.key]]) + + @mock.patch('spanner_orm.table_apis.delete') + def test_delete_by_key_deletes(self, delete): + mock_transaction = mock.Mock() + models.SmallTestModel.delete_by_key( + key='some-key', + transaction=mock_transaction, + ) + delete.assert_called_once_with( + mock_transaction, + models.SmallTestModel.table, + spanner.KeySet(keys=[['some-key']]), + ) def test_set_attr(self): test_model = models.SmallTestModel({'key': 'key', 'value_1': 'value'}) @@ -40,14 +226,17 @@ def test_set_error_on_primary_key(self): with self.assertRaises(AttributeError): test_model.key = 'error' - @parameterized.parameters(('int_2', 'foo'), ('string_2', 5), - ('string_array', 'foo'), ('timestamp', 5)) + @parameterized.parameters( + ('int_2', 'foo'), ('float_2', 'bar'), ('string_2', 5), ('string_3', 5), + ('bytes_2', 'string'), ('string_array', 'foo'), ('timestamp', 5)) def test_set_error_on_invalid_type(self, attribute, value): string_array = ['foo', 'bar'] timestamp = datetime.datetime.now(tz=datetime.timezone.utc) test_model = models.UnittestModel({ 'int_': 0, + 'float_': 0, 'string': '', + 'bytes_': b'', 'string_array': string_array, 'timestamp': timestamp }) @@ -60,8 +249,95 @@ def test_get_attr(self): self.assertEqual(test_model.value_1, 'value') self.assertEqual(test_model.value_2, None) + @parameterized.parameters( + (True, True), + (True, False), + (False, True), + ) + def test_skip_validation(self, persisted, skip_validation): + models.SmallTestModel( + {'value_1': 'value'}, + persisted=persisted, + skip_validation=skip_validation, + ) + + def test_validation(self): + with self.assertRaises(error.SpannerError): + models.SmallTestModel( + {'value_1': 'value'}, + persisted=False, + skip_validation=False, + ) + + def test_model_equates(self): + timestamp = datetime.datetime.now(tz=datetime.timezone.utc) + test_model1 = models.UnittestModel({ + 'int_': 0, + 'float_': 0, + 'string': '', + 'bytes_': b'', + 'string_array': ['foo', 'bar'], + 'timestamp': timestamp, + }) + test_model2 = models.UnittestModel({ + 'int_': 0, + 'float_': 0.0, + 'string': '', + 'bytes_': b'', + 'string_array': ['foo', 'bar'], + 'timestamp': timestamp, + }) + self.assertEqual(test_model1, test_model2) + + @parameterized.parameters( + (models.UnittestModel({ + 'int_': 0, + 'float_': 0, + 'string': '1', + 'bytes_': b'1111', + 'timestamp': _TIMESTAMP, + }), + models.UnittestModel({ + 'int_': 0, + 'float_': 0, + 'string': 'a', + 'bytes_': b'A1A1', + 'timestamp': _TIMESTAMP, + })), + (models.UnittestModel({ + 'int_': 0, + 'float_': 0, + 'string': '', + 'bytes_': b'A1A1', + 'string_array': ['foo', 'bar'], + 'timestamp': _TIMESTAMP, + }), + models.UnittestModel({ + 'int_': 0, + 'float_': 0, + 'string': '', + 'bytes_': b'A1A1', + 'string_array': ['bar', 'foo'], + 'timestamp': _TIMESTAMP, + })), + (models.SmallTestModel({ + 'key': 'key', + 'value_1': 'value' + }), models.InheritanceTestModel({ + 'key': 'key', + 'value_1': 'value' + })), + ) + def test_model_are_different(self, test_model1, test_model2): + self.assertNotEqual(test_model1, test_model2) + + def test_repr(self): + self.assertEqual( + "SmallTestModel({'key': 'a', 'value_1': 'b', 'value_2': None})", + repr(models.SmallTestModel(dict(key='a', value_1='b', value_2=None)))) + def test_id(self): - primary_key = {'string': 'foo', 'int_': 5} + primary_key = {'string': 'foo', 'int_': 5, 'float_': 2.3, 'bytes_': b'A1A1'} all_data = primary_key.copy() all_data.update({ 'timestamp': datetime.datetime.now(tz=datetime.timezone.utc), @@ -83,19 +359,22 @@ def test_object_changes(self): timestamp = datetime.datetime.now(tz=datetime.timezone.utc) test_model = models.UnittestModel({ 'int_': 0, + 'float_': 0, 'string': '', + 'bytes_': b'', 'string_array': array, 'timestamp': timestamp }) # Make sure that changing an object on the model shows up in changes() - string_array = test_model.string_array # type: List + string_array = typing.cast(List[str], test_model.string_array) string_array.append('bat') self.assertIn('string_array', test_model.changes()) def test_field_exists_on_model_class(self): self.assertIsInstance(models.SmallTestModel.key, field.Field) - self.assertEqual(models.SmallTestModel.key.field_type(), field.String) + self.assertEqual(models.SmallTestModel.key.field_type().ddl(), + 'STRING(MAX)') self.assertFalse(models.SmallTestModel.key.nullable()) self.assertEqual(models.SmallTestModel.key.name, 'key') @@ -126,19 +405,16 @@ def test_relation_get_error_on_unretrieved(self): def test_interleaved(self): self.assertEqual(models.ChildTestModel.interleaved, models.SmallTestModel) - @mock.patch('spanner_orm.model.ModelApi.find') + @mock.patch('spanner_orm.model.Model.find') def test_reload(self, find): values = {'key': 'key', 'value_1': 'value_1'} model = models.SmallTestModel(values, persisted=False) find.return_value = None self.assertIsNone(model.reload()) - find.assert_called_once() - (transaction,), kwargs = find.call_args - self.assertIsNone(transaction) - self.assertEqual(kwargs, model.id()) + find.assert_called_once_with(**model.id(), transaction=None) - @mock.patch('spanner_orm.model.ModelApi.find') + @mock.patch('spanner_orm.model.Model.find') def test_reload_reloads(self, find): values = {'key': 'key', 'value_1': 'value_1'} model = models.SmallTestModel(values, persisted=False) @@ -149,18 +425,15 @@ def test_reload_reloads(self, find): self.assertEqual(model.value_1, updated_values['value_1']) self.assertEqual(model.changes(), {}) - @mock.patch('spanner_orm.model.ModelApi.create') + @mock.patch('spanner_orm.model.Model.create') def test_save_creates(self, create): values = {'key': 'key', 'value_1': 'value_1'} model = models.SmallTestModel(values, persisted=False) model.save() - create.assert_called_once() - (transaction,), kwargs = create.call_args - self.assertIsNone(transaction) - self.assertEqual(kwargs, {**values, 'value_2': None}) + create.assert_called_once_with(**values, value_2=None, transaction=None) - @mock.patch('spanner_orm.model.ModelApi.update') + @mock.patch('spanner_orm.model.Model.update') def test_save_updates(self, update): values = {'key': 'key', 'value_1': 'value_1'} model = models.SmallTestModel(values, persisted=True) @@ -169,12 +442,9 @@ def test_save_updates(self, update): model.value_1 = values['value_1'] model.save() - update.assert_called_once() - (transaction,), kwargs = update.call_args - self.assertIsNone(transaction) - self.assertEqual(kwargs, values) + update.assert_called_once_with(**values, transaction=None) - @mock.patch('spanner_orm.model.ModelApi.update') + @mock.patch('spanner_orm.model.Model.update') def test_save_no_changes(self, update): values = {'key': 'key', 'value_1': 'value_1'} model = models.SmallTestModel(values, persisted=True) @@ -186,7 +456,7 @@ def test_delete_deletes(self, delete): mock_transaction = mock.Mock() values = {'key': 'key', 'value_1': 'value_1'} model = models.SmallTestModel(values) - model.delete(mock_transaction) + model.delete(transaction=mock_transaction) delete.assert_called_once() (transaction, table, keyset), _ = delete.call_args diff --git a/spanner_orm/tests/models.py b/spanner_orm/tests/models.py index 2a8218a..5f4a370 100644 --- a/spanner_orm/tests/models.py +++ b/spanner_orm/tests/models.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,6 +14,7 @@ """Models used by unit tests.""" from spanner_orm import field +from spanner_orm import foreign_key_relationship from spanner_orm import index from spanner_orm import model from spanner_orm import relationship @@ -30,6 +30,14 @@ class SmallTestModel(model.Model): index_1 = index.Index(['value_1']) +class SmallTestModelWithoutSecondaryIndexes(model.Model): + """Same as SmallTestModel, but with no secondary indexes.""" + __table__ = 'SmallTestModel' + key = field.Field(field.String, primary_key=True) + value_1 = field.Field(field.String) + value_2 = field.Field(field.String, nullable=True) + + class ChildTestModel(model.Model): """Model class for testing interleaved tables.""" @@ -62,6 +70,28 @@ class RelationshipTestModel(model.Model): {'parent_key': 'key'}) +class ForeignKeyTestModel(model.Model): + """Model class for testing foreign keys.""" + + __table__ = 'ForeignKeyTestModel' + referencing_key_1 = field.Field(field.String, primary_key=True) + referencing_key_2 = field.Field(field.String, primary_key=True) + referencing_key_3 = field.Field(field.Integer, primary_key=True) + self_referencing_key = field.Field(field.String, nullable=True) + foreign_key_1 = foreign_key_relationship.ForeignKeyRelationship( + 'spanner_orm.tests.models.SmallTestModel', {'referencing_key_1': 'key'}) + foreign_key_2 = foreign_key_relationship.ForeignKeyRelationship( + 'spanner_orm.tests.models.UnittestModel', + { + 'referencing_key_2': 'string', + 'referencing_key_3': 'int_' + }, + ) + foreign_key_3 = foreign_key_relationship.ForeignKeyRelationship( + 'spanner_orm.tests.models.ForeignKeyTestModel', + {'self_referencing_key': 'referencing_key_1'}) + + class InheritanceTestModel(SmallTestModel): """Model class used for testing model inheritance.""" value_3 = field.Field(field.String, nullable=True) @@ -73,9 +103,45 @@ class UnittestModel(model.Model): __table__ = 'table' int_ = field.Field(field.Integer, primary_key=True) int_2 = field.Field(field.Integer, nullable=True) + float_ = field.Field(field.Float, primary_key=True) + float_2 = field.Field(field.Float, nullable=True) string = field.Field(field.String, primary_key=True) string_2 = field.Field(field.String, nullable=True) + string_3 = field.Field(field.String(20), nullable=True) + bytes_ = field.Field(field.BytesBase64, primary_key=True) + bytes_2 = field.Field(field.BytesBase64, nullable=True) + bytes_3 = field.Field(field.BytesBase64(20), nullable=True) timestamp = field.Field(field.Timestamp) string_array = field.Field(field.StringArray, nullable=True) + string_array_2 = field.Field(field.Array(field.String(20)), nullable=True) test_index = index.Index(['string_2']) + + +class UnittestModelWithoutSecondaryIndexes(model.Model): + """Same as UnittestModel, but with no secondary indexes.""" + + __table__ = 'table' + int_ = field.Field(field.Integer, primary_key=True) + int_2 = field.Field(field.Integer, nullable=True) + float_ = field.Field(field.Float, primary_key=True) + float_2 = field.Field(field.Float, nullable=True) + string = field.Field(field.String, primary_key=True) + string_2 = field.Field(field.String, nullable=True) + string_3 = field.Field(field.String(20), nullable=True) + bytes_ = field.Field(field.BytesBase64, primary_key=True) + bytes_2 = field.Field(field.BytesBase64, nullable=True) + bytes_3 = field.Field(field.BytesBase64(20), nullable=True) + timestamp = field.Field(field.Timestamp) + string_array = field.Field(field.StringArray, nullable=True) + string_array_2 = field.Field(field.Array(field.String(20)), nullable=True) + + +class NullFilteredIndexModel(model.Model): + """Model class for testing NULL_FILTERED indexes.""" + + __table__ = 'NullFilteredIndexModel' + key = field.Field(field.String, primary_key=True) + value_1 = field.Field(field.String, nullable=True) + value_2 = field.Field(field.Integer) + value_index = index.Index(['value_1', 'value_2'], null_filtered=True) diff --git a/spanner_orm/tests/query_test.py b/spanner_orm/tests/query_test.py index 02ffa28..ff18627 100644 --- a/spanner_orm/tests/query_test.py +++ b/spanner_orm/tests/query_test.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -24,7 +23,7 @@ from spanner_orm import query from spanner_orm.tests import models -from google.cloud.spanner_v1.proto import type_pb2 +from google.cloud import spanner_v1 def now(): @@ -37,19 +36,19 @@ class QueryTest(parameterized.TestCase): def test_where(self, sql_query): sql_query.return_value = [] - models.UnittestModel.where_equal(True, int_=3) + models.UnittestModel.where_equal(int_=3, transaction=True) (_, sql, parameters, types), _ = sql_query.call_args expected_sql = 'SELECT .* FROM table WHERE table.int_ = @int_0' self.assertRegex(sql, expected_sql) self.assertEqual(parameters, {'int_0': 3}) - self.assertEqual(types, {'int_0': field.Integer.grpc_type()}) + self.assertEqual(types, {'int_0': field.Integer().grpc_type()}) @mock.patch('spanner_orm.table_apis.sql_query') def test_count(self, sql_query): sql_query.return_value = [[0]] column, value = 'int_', 3 - models.UnittestModel.count_equal(True, int_=3) + models.UnittestModel.count_equal(int_=3, transaction=True) (_, sql, parameters, types), _ = sql_query.call_args column_key = '{}0'.format(column) @@ -57,7 +56,7 @@ def test_count(self, sql_query): column, column_key) self.assertRegex(sql, expected_sql) self.assertEqual({column_key: value}, parameters) - self.assertEqual(types, {column_key: field.Integer.grpc_type()}) + self.assertEqual(types, {column_key: field.Integer().grpc_type()}) def test_count_allows_force_index(self): force_index = condition.force_index('test_index') @@ -67,7 +66,8 @@ def test_count_allows_force_index(self): self.assertEqual(expected_sql, sql) @parameterized.parameters( - condition.limit(1), condition.order_by(('int_', condition.OrderType.DESC))) + condition.limit(1), condition.order_by( + ('int_', condition.OrderType.DESC))) def test_count_only_allows_where_and_from_segment_conditions(self, condition): with self.assertRaises(error.SpannerError): query.CountQuery(models.UnittestModel, [condition]) @@ -81,7 +81,7 @@ def test_query_limit(self): self.assertEndsWith(select_query.sql(), ' LIMIT @{}'.format(key)) self.assertEqual(select_query.parameters(), {key: value}) - self.assertEqual(select_query.types(), {key: field.Integer.grpc_type()}) + self.assertEqual(select_query.types(), {key: field.Integer().grpc_type()}) select_query = self.select() self.assertNotRegex(select_query.sql(), 'LIMIT') @@ -97,10 +97,11 @@ def test_query_limit_offset(self): limit_key: limit, offset_key: offset }) - self.assertEqual(select_query.types(), { - limit_key: field.Integer.grpc_type(), - offset_key: field.Integer.grpc_type() - }) + self.assertEqual( + select_query.types(), { + limit_key: field.Integer().grpc_type(), + offset_key: field.Integer().grpc_type() + }) def test_query_order_by(self): order = ('int_', condition.OrderType.DESC) @@ -124,9 +125,9 @@ def test_query_order_by_with_object(self): select_query = self.select() self.assertNotRegex(select_query.sql(), 'ORDER BY') - @parameterized.parameters(('int_', 5, field.Integer.grpc_type()), - ('string', 'foo', field.String.grpc_type()), - ('timestamp', now(), field.Timestamp.grpc_type())) + @parameterized.parameters(('int_', 5, field.Integer().grpc_type()), + ('string', 'foo', field.String().grpc_type()), + ('timestamp', now(), field.Timestamp().grpc_type())) def test_query_where_comparison(self, column, value, grpc_type): condition_generators = [ condition.greater_than, condition.not_less_than, condition.less_than, @@ -144,9 +145,9 @@ def test_query_where_comparison(self, column, value, grpc_type): self.assertEqual(select_query.types(), {column_key: grpc_type}) @parameterized.parameters( - (models.UnittestModel.int_, 5, field.Integer.grpc_type()), - (models.UnittestModel.string, 'foo', field.String.grpc_type()), - (models.UnittestModel.timestamp, now(), field.Timestamp.grpc_type())) + (models.UnittestModel.int_, 5, field.Integer().grpc_type()), + (models.UnittestModel.string, 'foo', field.String().grpc_type()), + (models.UnittestModel.timestamp, now(), field.Timestamp().grpc_type())) def test_query_where_comparison_with_object(self, column, value, grpc_type): condition_generators = [ condition.greater_than, condition.not_less_than, condition.less_than, @@ -164,9 +165,10 @@ def test_query_where_comparison_with_object(self, column, value, grpc_type): self.assertEqual(select_query.types(), {column_key: grpc_type}) @parameterized.parameters( - ('int_', [1, 2, 3], field.Integer.grpc_type()), - ('string', ['a', 'b', 'c'], field.String.grpc_type()), - ('timestamp', [now()], field.Timestamp.grpc_type())) + ('int_', [1, 2, 3], field.Integer().grpc_type()), + ('int_', (4, 5, 6), field.Integer().grpc_type()), + ('string', ['a', 'b', 'c'], field.String().grpc_type()), + ('timestamp', [now()], field.Timestamp().grpc_type())) def test_query_where_list_comparison(self, column, values, grpc_type): condition_generators = [condition.in_list, condition.not_in_list] for condition_generator in condition_generators: @@ -176,8 +178,8 @@ def test_query_where_list_comparison(self, column, values, grpc_type): column_key = '{}0'.format(column) expected_sql = ' WHERE table.{} {} UNNEST(@{})'.format( column, current_condition.operator, column_key) - list_type = type_pb2.Type( - code=type_pb2.ARRAY, array_element_type=grpc_type) + list_type = spanner_v1.Type( + code=spanner_v1.TypeCode.ARRAY, array_element_type=grpc_type) self.assertEndsWith(select_query.sql(), expected_sql) self.assertEqual(select_query.parameters(), {column_key: values}) self.assertEqual(select_query.types(), {column_key: list_type}) @@ -207,35 +209,86 @@ def test_force_index_with_object(self): expected_sql = 'FROM table@{FORCE_INDEX=test_index}' self.assertEndsWith(select_query.sql(), expected_sql) - def includes(self, relation, *conditions): - include_condition = condition.includes(relation, list(conditions)) - return query.SelectQuery(models.RelationshipTestModel, [include_condition]) - - def test_includes(self): - select_query = self.includes('parent') + def includes(self, relation, *conditions, foreign_key_relation=False): + include_condition = condition.includes(relation, list(conditions), + foreign_key_relation) + return query.SelectQuery( + models.ForeignKeyTestModel + if foreign_key_relation else models.RelationshipTestModel, + [include_condition], + ) + + @parameterized.parameters((models.RelationshipTestModel.parent, True), + (models.ForeignKeyTestModel.foreign_key_1, False)) + def test_bad_includes_args(self, relation_key, foreign_key_relation): + with self.assertRaisesRegex(ValueError, 'Must pass'): + self.includes( + relation_key, + foreign_key_relation=foreign_key_relation, + ) + + @parameterized.named_parameters( + ( + 'legacy_relationship', + { + 'relation': 'parent' + }, + r'SELECT RelationshipTestModel\S* RelationshipTestModel\S* ' + r'ARRAY\(SELECT AS STRUCT SmallTestModel\S* SmallTestModel\S* ' + r'SmallTestModel\S* FROM SmallTestModel WHERE SmallTestModel.key = ' + r'RelationshipTestModel.parent_key\)', + ), + ( + 'legacy_relationship_with_object_arg', + { + 'relation': models.RelationshipTestModel.parent + }, + r'SELECT RelationshipTestModel\S* RelationshipTestModel\S* ' + r'ARRAY\(SELECT AS STRUCT SmallTestModel\S* SmallTestModel\S* ' + r'SmallTestModel\S* FROM SmallTestModel WHERE SmallTestModel.key = ' + r'RelationshipTestModel.parent_key\)', + ), + ( + 'foreign_key_relationship', + { + 'relation': 'foreign_key_1', + 'foreign_key_relation': True + }, + r'SELECT ForeignKeyTestModel\S* ForeignKeyTestModel\S* ForeignKeyTestModel\S* ForeignKeyTestModel\S* ' + r'ARRAY\(SELECT AS STRUCT SmallTestModel\S* SmallTestModel\S* ' + r'SmallTestModel\S* FROM SmallTestModel WHERE SmallTestModel.key = ' + r'ForeignKeyTestModel.referencing_key_1\)', + ), + ( + 'foreign_key_relationship_with_object_arg', + { + 'relation': models.ForeignKeyTestModel.foreign_key_1, + 'foreign_key_relation': True + }, + r'SELECT ForeignKeyTestModel\S* ForeignKeyTestModel\S* ForeignKeyTestModel\S* ForeignKeyTestModel\S* ' + r'ARRAY\(SELECT AS STRUCT SmallTestModel\S* SmallTestModel\S* ' + r'SmallTestModel\S* FROM SmallTestModel WHERE SmallTestModel.key = ' + r'ForeignKeyTestModel.referencing_key_1\)', + ), + ) + def test_includes(self, includes_kwargs, expected_sql): + select_query = self.includes(**includes_kwargs) # The column order varies between test runs - expected_sql = ( - r'SELECT RelationshipTestModel\S* RelationshipTestModel\S* ' - r'ARRAY\(SELECT AS STRUCT SmallTestModel\S* SmallTestModel\S* ' - r'SmallTestModel\S* FROM SmallTestModel WHERE SmallTestModel.key = ' - r'RelationshipTestModel.parent_key\)') self.assertRegex(select_query.sql(), expected_sql) self.assertEmpty(select_query.parameters()) self.assertEmpty(select_query.types()) - def test_includes_with_object(self): - select_query = self.includes(models.RelationshipTestModel.parent) - - # The column order varies between test runs - expected_sql = ( - r'SELECT RelationshipTestModel\S* RelationshipTestModel\S* ' - r'ARRAY\(SELECT AS STRUCT SmallTestModel\S* SmallTestModel\S* ' - r'SmallTestModel\S* FROM SmallTestModel WHERE SmallTestModel.key = ' - r'RelationshipTestModel.parent_key\)') - self.assertRegex(select_query.sql(), expected_sql) - self.assertEmpty(select_query.parameters()) - self.assertEmpty(select_query.types()) + @parameterized.parameters(({ + 'relation': models.RelationshipTestModel.parent, + 'foreign_key_relation': True + },), ({ + 'relation': models.ForeignKeyTestModel.foreign_key_1, + 'foreign_key_relation': False + },)) + def test_error_mismatched_params(self, includes_kwargs): + with self.assertRaisesRegex(ValueError, 'Must pass'): + self.includes(**includes_kwargs) def test_includes_subconditions_query(self): select_query = self.includes('parents', condition.equal_to('key', 'value')) @@ -254,24 +307,87 @@ def includes_result(self, related=1): result.append(parents) return child, parent, [result] - def test_includes_single_related_object_result(self): - select_query = self.includes('parent') - child_values, parent_values, rows = self.includes_result(related=1) + def fk_includes_result(self, related=1): + child = { + 'referencing_key_1': 'parent_key', + 'referencing_key_2': 'child', + 'referencing_key_3': 'child', + 'self_referencing_key': 'child' + } + result = [child[name] for name in models.ForeignKeyTestModel.columns] + parent = {'key': 'key', 'value_1': 'value_1', 'value_2': None} + parents = [] + for _ in range(related): + parents.append([parent[name] for name in models.SmallTestModel.columns]) + result.append(parents) + return child, parent, [result] + + @parameterized.named_parameters( + ( + 'legacy_relationship', + { + 'relation': 'parent' + }, + lambda x: x.parent, + lambda x: x.includes_result(related=1), + ), + ( + 'foreign_key_relationship', + { + 'relation': 'foreign_key_1', + 'foreign_key_relation': True + }, + lambda x: x.foreign_key_1, + lambda x: x.fk_includes_result(related=1), + ), + ) + def test_includes_single_related_object_result( + self, + includes_kwargs, + referenced_table_fn, + includes_result_fn, + ): + select_query = self.includes(**includes_kwargs) + child_values, parent_values, rows = includes_result_fn(self) result = select_query.process_results(rows)[0] - self.assertIsInstance(result.parent, models.SmallTestModel) + self.assertIsInstance( + referenced_table_fn(result), + models.SmallTestModel, + ) for name, value in child_values.items(): self.assertEqual(getattr(result, name), value) for name, value in parent_values.items(): - self.assertEqual(getattr(result.parent, name), value) - - def test_includes_single_no_related_object_result(self): - select_query = self.includes('parent') - child_values, _, rows = self.includes_result(related=0) + self.assertEqual(getattr(referenced_table_fn(result), name), value) + + @parameterized.named_parameters( + ( + 'legacy_relationship', + { + 'relation': 'parent' + }, + lambda x: x.parent, + lambda x: x.includes_result(related=0), + ), + ( + 'foreign_key_relationship', + { + 'relation': 'foreign_key_1', + 'foreign_key_relation': True + }, + lambda x: x.foreign_key_1, + lambda x: x.fk_includes_result(related=0), + ), + ) + def test_includes_single_no_related_object_result(self, includes_kwargs, + referenced_table_fn, + includes_result_fn): + select_query = self.includes(**includes_kwargs) + child_values, _, rows = includes_result_fn(self) result = select_query.process_results(rows)[0] - self.assertIsNone(result.parent) + self.assertIsNone(referenced_table_fn(result)) for name, value in child_values.items(): self.assertEqual(getattr(result, name), value) @@ -288,21 +404,52 @@ def test_includes_subcondition_result(self): for name, value in parent_values.items(): self.assertEqual(getattr(result.parents[0], name), value) - def test_includes_error_on_multiple_results_for_single(self): - select_query = self.includes('parent') - _, _, rows = self.includes_result(related=2) + @parameterized.named_parameters( + ( + 'legacy_relationship', + { + 'relation': 'parent' + }, + lambda x: x.includes_result(related=2), + ), + ( + 'foreign_key_relationship', + { + 'relation': 'foreign_key_1', + 'foreign_key_relation': True + }, + lambda x: x.fk_includes_result(related=2), + ), + ) + def test_includes_error_on_multiple_results_for_single( + self, includes_kwargs, includes_result_fn): + select_query = self.includes(**includes_kwargs) + _, _, rows = includes_result_fn(self) with self.assertRaises(error.SpannerError): _ = select_query.process_results(rows) - def test_includes_error_on_invalid_relation(self): + @parameterized.parameters(True, False) + def test_includes_error_on_invalid_relation(self, foreign_key_relation): with self.assertRaises(error.ValidationError): - self.includes('bad_relation') + self.includes('bad_relation', foreign_key_relation=foreign_key_relation) - @parameterized.parameters(('bad_column', 0), ('child_key', 'good value'), - ('key', ['bad value'])) - def test_includes_error_on_invalid_subconditions(self, column, value): + @parameterized.parameters( + ('bad_column', 0, 'parent', False), + ('bad_column', 0, 'foreign_key_1', True), + ('child_key', 'good value', 'parent', False), + ('child_key', 'good value', 'foreign_key_1', False), + ('key', ['bad value'], 'parent', False), + ('key', ['bad value'], 'foreign_key_1', False), + ) + def test_includes_error_on_invalid_subconditions(self, column, value, + relation, + foreign_key_relation): with self.assertRaises(error.ValidationError): - self.includes('parent', condition.equal_to(column, value)) + self.includes( + relation, + condition.equal_to(column, value), + foreign_key_relation, + ) def test_or(self): condition_1 = condition.equal_to('int_', 1) @@ -313,8 +460,8 @@ def test_or(self): self.assertEndsWith(select_query.sql(), expected_sql) self.assertEqual(select_query.parameters(), {'int_0': 1, 'int_1': 2}) self.assertEqual(select_query.types(), { - 'int_0': field.Integer.grpc_type(), - 'int_1': field.Integer.grpc_type() + 'int_0': field.Integer().grpc_type(), + 'int_1': field.Integer().grpc_type() }) diff --git a/spanner_orm/tests/update_test.py b/spanner_orm/tests/update_test.py index 7b55965..6486cde 100644 --- a/spanner_orm/tests/update_test.py +++ b/spanner_orm/tests/update_test.py @@ -1,4 +1,3 @@ -# python3 # Copyright 2019 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,13 +15,32 @@ import unittest from unittest import mock +from absl.testing import parameterized + +import spanner_orm from spanner_orm import error from spanner_orm import field from spanner_orm.admin import update +from spanner_orm.testlib.spanner_emulator import testlib as spanner_emulator_testlib from spanner_orm.tests import models -class UpdateTest(unittest.TestCase): +class UpdateTest( + spanner_emulator_testlib.TestCase, + parameterized.TestCase, +): + + def setUp(self): + super().setUp() + _, project_id = self.spanner_emulator_client.project_name.split('/') + connection = spanner_orm.SpannerConnection( + instance=self.spanner_emulator_instance.instance_id, + database=self.spanner_emulator_database.database_id, + project=project_id, + credentials=self.spanner_emulator_client.credentials, + ) + spanner_orm.from_connection(connection) + spanner_orm.from_admin_connection(connection) @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') def test_add_column(self, get_model): @@ -61,14 +79,20 @@ def test_drop_column_error_on_primary_key(self, get_model, index_count): @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') def test_create_table(self, get_model): get_model.return_value = None - new_model = models.UnittestModel + new_model = models.UnittestModelWithoutSecondaryIndexes test_update = update.CreateTable(new_model) test_update.validate() test_model_ddl = ('CREATE TABLE table (int_ INT64 NOT NULL, int_2 INT64,' + ' float_ FLOAT64 NOT NULL, float_2 FLOAT64,' ' string STRING(MAX) NOT NULL, string_2 STRING(MAX),' - ' timestamp TIMESTAMP NOT NULL, string_array' - ' ARRAY) PRIMARY KEY (int_, string)') + ' string_3 STRING(20),' + ' bytes_ BYTES(MAX) NOT NULL, bytes_2 BYTES(MAX),' + ' bytes_3 BYTES(20),' + ' timestamp TIMESTAMP NOT NULL,' + ' string_array ARRAY,' + ' string_array_2 ARRAY)' + ' PRIMARY KEY (int_, float_, string, bytes_)') self.assertEqual(test_update.ddl(), test_model_ddl) @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') @@ -85,6 +109,31 @@ def test_create_table_interleaved(self, get_model): 'INTERLEAVE IN PARENT SmallTestModel ON DELETE CASCADE') self.assertEqual(test_update.ddl(), test_model_ddl) + @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') + def test_create_table_foreign_key(self, get_model): + self.maxDiff = 2000 + + get_model.return_value = None + new_model = models.ForeignKeyTestModel + test_update = update.CreateTable(new_model) + test_update.validate() + + test_model_ddl = ( + 'CREATE TABLE ForeignKeyTestModel (' + 'referencing_key_1 STRING(MAX) NOT NULL, ' + 'referencing_key_2 STRING(MAX) NOT NULL, ' + 'referencing_key_3 INT64 NOT NULL, ' + 'self_referencing_key STRING(MAX), ' + 'CONSTRAINT foreign_key_1 FOREIGN KEY (referencing_key_1) ' + 'REFERENCES SmallTestModel (key), ' + 'CONSTRAINT foreign_key_2 ' + 'FOREIGN KEY (referencing_key_2, referencing_key_3) ' + 'REFERENCES table (string, int_), ' + 'CONSTRAINT foreign_key_3 FOREIGN KEY (self_referencing_key) ' + 'REFERENCES ForeignKeyTestModel (referencing_key_1)) ' + 'PRIMARY KEY (referencing_key_1, referencing_key_2, referencing_key_3)') + self.assertEqual(test_update.ddl(), test_model_ddl) + @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') def test_create_table_error_on_existing_table(self, get_model): get_model.return_value = models.SmallTestModel @@ -93,6 +142,17 @@ def test_create_table_error_on_existing_table(self, get_model): with self.assertRaisesRegex(error.SpannerError, 'already exists'): test_update.validate() + @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') + def test_create_table_error_on_table_with_index(self, get_model): + get_model.return_value = None + new_model = models.IndexTestModel + test_update = update.CreateTable(new_model) + with self.assertRaisesRegex( + error.SpannerError, + 'indexes cannot be created', + ): + test_update.validate() + @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.indexes') @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.tables') @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') @@ -106,15 +166,58 @@ def test_drop_table(self, get_model, tables, indexes): test_update.validate() self.assertEqual(test_update.ddl(), 'DROP TABLE {}'.format(table_name)) + @parameterized.named_parameters( + ( + 'basic', + update.CreateIndex( + table_name=models.SmallTestModel.table, + index_name='foo', + columns=['value_1'], + ), + f'CREATE INDEX foo ON {models.SmallTestModel.table} (value_1)', + ), + ( + 'with_options', + update.CreateIndex( + table_name=models.SmallTestModel.table, + index_name='foo', + columns=['value_1'], + null_filtered=True, + unique=True, + ), + (f'CREATE UNIQUE NULL_FILTERED INDEX foo ' + f'ON {models.SmallTestModel.table} (value_1)'), + ), + ) @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') - def test_add_index(self, get_model): - table_name = models.SmallTestModel.table + def test_add_index(self, test_update, expected_ddl, get_model): get_model.return_value = models.SmallTestModel - test_update = update.CreateIndex(table_name, 'foo', ['value_1']) test_update.validate() - self.assertEqual(test_update.ddl(), - 'CREATE INDEX foo ON {} (value_1)'.format(table_name)) + self.assertEqual(test_update.ddl(), expected_ddl) + + def test_execute_partitioned_dml(self): + update.CreateTable(models.SmallTestModelWithoutSecondaryIndexes).execute() + test_model = models.SmallTestModel( + dict( + key='some-key', + value_1='foo', + value_2='bar', + )) + test_model.save() + update.ExecutePartitionedDml( + "UPDATE SmallTestModel SET value_2 = value_1 WHERE value_2 = 'bar'", + ).execute() + test_model.reload() + self.assertEqual( + models.SmallTestModel( + dict( + key='some-key', + value_1='foo', + value_2='foo', + )), + test_model, + ) if __name__ == '__main__':