diff --git a/spanner_orm/admin/metadata.py b/spanner_orm/admin/metadata.py index 1a2ab5c..648e5d8 100644 --- a/spanner_orm/admin/metadata.py +++ b/spanner_orm/admin/metadata.py @@ -28,41 +28,36 @@ class SpannerMetadata(object): """Gathers information about a table from Spanner.""" - _models = None - _tables = None - _indexes = None @classmethod def models(cls): """Constructs model classes from Spanner table schema.""" - if cls._models is None: - tables = cls.tables() - indexes = cls.indexes() - models = {} - - for table_name, table_data in tables.items(): - primary_index = indexes[table_name][index.Index.PRIMARY_INDEX] - primary_keys = set(primary_index.columns) - klass = model.ModelBase('Model_{}'.format(table_name), (model.Model,), - {}) - for field in table_data['fields'].values(): - field._primary_key = field.name in primary_keys # pylint: disable=protected-access - - klass.meta = model.Metadata( - table=table_name, - fields=table_data['fields'], - interleaved=table_data['parent_table'], - indexes=indexes[table_name], - model_class=klass) - models[table_name] = klass - - for table_model in models.values(): - if table_model.meta.interleaved: - table_model.meta.interleaved = models[table_model.meta.interleaved] - table_model.meta.finalize() - - cls._models = models - return cls._models + tables = cls.tables() + indexes = cls.indexes() + models = {} + + for table_name, table_data in tables.items(): + primary_index = indexes[table_name][index.Index.PRIMARY_INDEX] + primary_keys = set(primary_index.columns) + klass = model.ModelBase('Model_{}'.format(table_name), (model.Model,), + {}) + for field in table_data['fields'].values(): + field._primary_key = field.name in primary_keys # pylint: disable=protected-access + + klass.meta = model.Metadata( + table=table_name, + fields=table_data['fields'], + interleaved=table_data['parent_table'], + indexes=indexes[table_name], + model_class=klass) + models[table_name] = klass + + for table_model in models.values(): + if table_model.meta.interleaved: + table_model.meta.interleaved = models[table_model.meta.interleaved] + table_model.meta.finalize() + + return models @classmethod def model(cls, table_name): @@ -71,63 +66,58 @@ def model(cls, table_name): @classmethod def tables(cls): """Compiles table information from column schema.""" - if cls._tables is None: - column_data = collections.defaultdict(dict) - columns = column.ColumnSchema.where( - None, condition.equal_to('table_catalog', ''), - condition.equal_to('table_schema', '')) - for column_row in columns: - new_field = field.Field( - column_row.field_type(), nullable=column_row.nullable()) - new_field.name = column_row.column_name - new_field.position = column_row.ordinal_position - column_data[column_row.table_name][column_row.column_name] = new_field - - table_data = collections.defaultdict(dict) - tables = table.TableSchema.where( - None, condition.equal_to('table_catalog', ''), - condition.equal_to('table_schema', '')) - for table_row in tables: - name = table_row.table_name - table_data[name]['parent_table'] = table_row.parent_table_name - table_data[name]['fields'] = column_data[name] - cls._tables = table_data - return cls._tables + column_data = collections.defaultdict(dict) + columns = column.ColumnSchema.where( + None, condition.equal_to('table_catalog', ''), + condition.equal_to('table_schema', '')) + for column_row in columns: + new_field = field.Field( + column_row.field_type(), nullable=column_row.nullable()) + new_field.name = column_row.column_name + new_field.position = column_row.ordinal_position + column_data[column_row.table_name][column_row.column_name] = new_field + + table_data = collections.defaultdict(dict) + tables = table.TableSchema.where( + None, condition.equal_to('table_catalog', ''), + condition.equal_to('table_schema', '')) + for table_row in tables: + name = table_row.table_name + table_data[name]['parent_table'] = table_row.parent_table_name + table_data[name]['fields'] = column_data[name] + return table_data @classmethod def indexes(cls): """Compiles index information from index and index columns schemas.""" - if cls._indexes is None: - # ordinal_position is the position of the column in the indicated index. - # Results are ordered by that so the index columns are added in the - # correct order. - index_column_schemas = index_column.IndexColumnSchema.where( - None, condition.equal_to('table_catalog', ''), - condition.equal_to('table_schema', ''), - condition.order_by(('ordinal_position', condition.OrderType.ASC))) - - index_columns = collections.defaultdict(list) - storing_columns = collections.defaultdict(list) - for schema in index_column_schemas: - key = (schema.table_name, schema.index_name) - if schema.ordinal_position is not None: - index_columns[key].append(schema.column_name) - else: - storing_columns[key].append(schema.column_name) - - index_schemas = index_schema.IndexSchema.where( - None, condition.equal_to('table_catalog', ''), - condition.equal_to('table_schema', '')) - indexes = collections.defaultdict(dict) - for schema in index_schemas: - key = (schema.table_name, schema.index_name) - indexes[schema.table_name][schema.index_name] = index.Index( - index_columns[key], - schema.index_name, - parent=schema.parent_table_name, - null_filtered=schema.is_null_filtered, - unique=schema.is_unique, - storing_columns=storing_columns[key]) - cls._indexes = indexes - - return cls._indexes + # ordinal_position is the position of the column in the indicated index. + # Results are ordered by that so the index columns are added in the + # correct order. + index_column_schemas = index_column.IndexColumnSchema.where( + None, condition.equal_to('table_catalog', ''), + condition.equal_to('table_schema', ''), + condition.order_by(('ordinal_position', condition.OrderType.ASC))) + + index_columns = collections.defaultdict(list) + storing_columns = collections.defaultdict(list) + for schema in index_column_schemas: + key = (schema.table_name, schema.index_name) + if schema.ordinal_position is not None: + index_columns[key].append(schema.column_name) + else: + storing_columns[key].append(schema.column_name) + + index_schemas = index_schema.IndexSchema.where( + None, condition.equal_to('table_catalog', ''), + condition.equal_to('table_schema', '')) + indexes = collections.defaultdict(dict) + for schema in index_schemas: + key = (schema.table_name, schema.index_name) + indexes[schema.table_name][schema.index_name] = index.Index( + index_columns[key], + schema.index_name, + parent=schema.parent_table_name, + null_filtered=schema.is_null_filtered, + unique=schema.is_unique, + storing_columns=storing_columns[key]) + return indexes diff --git a/spanner_orm/admin/update.py b/spanner_orm/admin/update.py index 51312a7..074acdd 100644 --- a/spanner_orm/admin/update.py +++ b/spanner_orm/admin/update.py @@ -17,6 +17,8 @@ import abc from spanner_orm import condition +from spanner_orm import error +from spanner_orm import index from spanner_orm.admin import api from spanner_orm.admin import index_column from spanner_orm.admin import metadata @@ -38,7 +40,7 @@ def validate(self): raise NotImplementedError -class CreateTableUpdate(SchemaUpdate): +class CreateTable(SchemaUpdate): """Update that allows creating a new table.""" def __init__(self, model): @@ -48,30 +50,80 @@ def __init__(self, model): def ddl(self): return self._model.creation_ddl + def validate(self): + if not self._model.table: + raise error.SpannerError('New table has no name') + + if self._existing_model: + raise error.SpannerError('Table {} already exists'.format( + self._model.table)) + + if self._model.interleaved: + self._validate_parent() + + self._validate_primary_keys() + def _validate_parent(self): parent_primary_keys = self._model.interleaved.primary_keys - primary_keys = self._model.primary_keys, ('Non-matching primary keys in ' - 'interleaved table') - assert len(parent_primary_keys) <= len(primary_keys) + primary_keys = self._model.primary_keys + + message = 'Table {} is not a child of parent table {}'.format( + self._model.table, self._model.interleaved.table) for parent_key, key in zip(parent_primary_keys, primary_keys): - assert parent_key == key, 'Non-matching primary keys in interleaved table' + if parent_key != key: + raise error.SpnnerError(message) + if len(parent_primary_keys) > len(primary_keys): + raise error.SpannerError(message) def _validate_primary_keys(self): - assert self._model.primary_keys, 'Creating a table with no primary key' + if not self._model.primary_keys: + raise error.SpannerError('Table {} has no primary key'.format( + self._model.table)) + for key in self._model.primary_keys: - assert key in self._model.schema, 'Trying to index fields not in table' + if key not in self._model.schema: + raise error.SpannerError( + 'Table {} column {} in primary key but not in schema'.format( + self._model.table, key)) + + +class DropTable(SchemaUpdate): + """Update for dropping an existing table.""" + + def __init__(self, table_name): + self._table = table_name + self._existing_model = metadata.SpannerMetadata.model(table_name) + + def ddl(self): + return 'DROP TABLE {}'.format(self._table) def validate(self): - assert self._model.table, 'Trying to create a table with no name' - assert not self._existing_model, ('Trying to create a table that already ' - 'exists') - if self._model.interleaved: - self._validate_parent() - self._validate_primary_keys() + if not self._existing_model: + raise error.SpannerError('Table {} does not exist'.format(self._table)) + # Model indexes include the primary index + if len(self._existing_model.indexes) > 1: + raise error.SpannerError('Table {} has a secondary index'.format( + self._table)) -class ColumnUpdate(SchemaUpdate): - """Specifies column updates such as ADD, DROP, and ALTER.""" + self._validate_not_interleaved() + + def _validate_not_interleaved(self): + for model in metadata.SpannerMetadata.models().values(): + if model.interleaved == self._existing_model: + raise error.SpannerError('Table {} has interleaved table {}'.format( + self._table, model.table)) + for index in model.indexes.values(): + if index.parent == self._table: + raise error.SpannerError('Table {} has interleaved index {}'.format( + self._table, index.name)) + + +class AddColumn(SchemaUpdate): + """Update for adding a column to an existing table. + + Only supports adding nullable columns + """ def __init__(self, table_name, column_name, field): self._table = table_name @@ -80,62 +132,173 @@ def __init__(self, table_name, column_name, field): self._model = metadata.SpannerMetadata.model(table_name) def ddl(self): - if self._field is None: - return 'ALTER TABLE {} DROP COLUMN {}'.format(self._table, self._column) - elif self._column in self._model.schema: - operation = 'ALTER COLUMN' - else: - operation = 'ADD COLUMN' - return 'ALTER TABLE {} {} {} {}'.format(self._table, operation, - self._column, self._field.ddl()) - - def _validate_alter_column(self): - assert self._column in self._model.schema, 'Altering a nonexistent column' - old_field = self._model.schema[self._column] - # Validate that the only alteration is to change column nullability - assert self._field.field_type() == old_field.field_type( - ), 'Changing the type of a column' - assert self._field.nullable() != old_field.nullable() + return 'ALTER TABLE {} ADD COLUMN {} {}'.format(self._table, self._column, + self._field.ddl()) + + def validate(self): + if not self._model: + raise error.SpannerError('Table {} does not exist'.format(self._table)) + if not self._field.nullable(): + raise error.SpannerError('Column {} is not nullable'.format(self._column)) + if self._field.primary_key(): + raise error.SpannerError('Column {} is a primary key'.format( + self._column)) + + +class DropColumn(SchemaUpdate): + """Update for dropping a column from an existing table.""" + + def __init__(self, table_name, column_name): + self._table = table_name + self._column = column_name + self._model = metadata.SpannerMetadata.model(table_name) + + def ddl(self): + return 'ALTER TABLE {} DROP COLUMN {}'.format(self._table, self._column) + + def validate(self): + if not self._model: + raise error.SpannerError('Table {} does not exist'.format(self._table)) + + if self._column not in self._model.schema: + raise error.SpannerError('Column {} does not exist on {}'.format( + self._column, self._table)) - def _validate_drop_column(self): - assert self._column in self._model.schema, 'Dropping a nonexistent column' # Verify no indices exist on the column we're trying to drop - num_index_columns = index_column.IndexColumnSchema.count( + num_indexed_columns = index_column.IndexColumnSchema.count( None, condition.equal_to('column_name', self._column), condition.equal_to('table_name', self._table)) - assert num_index_columns == 0, 'Dropping an indexed column' + if num_indexed_columns > 0: + raise error.SpannerError('Column {} is indexed'.format(self._column)) + + +class AlterColumn(SchemaUpdate): + """Update for altering a column an existing table. + + Only supports changing the nullability of a column + """ + + def __init__(self, table_name, column_name, field): + self._table = table_name + self._column = column_name + self._field = field + self._model = metadata.SpannerMetadata.model(table_name) + + def ddl(self): + return 'ALTER TABLE {} ALTER COLUMN {} {}'.format(self._table, self._column, + self._field.ddl()) def validate(self): - assert self._model - if self._field is None: - self._validate_drop_column() - elif self._column in self._model.schema: - self._validate_alter_column() - else: - assert self._field.nullable(), 'Adding a non-nullable column' + if not self._model: + raise error.SpannerError('Table {} does not exist'.format(self._table)) + + if self._column not in self._model.schema: + raise error.SpannerError('Column {} does not exist on {}'.format( + self._column, self._table)) + + if self._column in self._model.primary_keys: + raise error.SpannerError('Column {} is a primary key on {}'.format( + self._column, self._table)) + old_field = self._model.schema[self._column] + # Validate that the only alteration is to change column nullability + if self._field.field_type() != old_field.field_type(): + raise error.SpannerError('Column {} is changing type'.format( + self._column)) + if self._field.nullable() == old_field.nullable(): + raise error.SpannerError('Column {} has no changes'.format(self._column)) -class IndexUpdate(SchemaUpdate): - """Specifies index updates such as ADD and DROP.""" - def __init__(self, table_name, index_name, columns): +class CreateIndex(SchemaUpdate): + """Update for creating an index on an existing table.""" + + def __init__(self, + table_name, + index_name, + columns, + interleaved=None, + storing_columns=None): self._table = table_name self._index = index_name self._columns = columns + self._parent_table = interleaved + self._storing_columns = storing_columns or [] self._model = metadata.SpannerMetadata.model(table_name) - # TODO(dbrandao): implement def ddl(self): - raise NotImplementedError + statement = 'CREATE INDEX {} ON {} ({})'.format(self._index, self._table, + ', '.join(self._columns)) + if self._storing_columns: + statement += 'STORING ({})'.format(', '.join(self._storing_columns)) + if self._parent_table: + statement += ', INTERLEAVE IN {}'.format(self._parent_table) + return statement - # TODO(dbrandao): implement def validate(self): - assert self._model - raise NotImplementedError + if not self._model: + raise error.SpannerError('Table {} does not exist'.format(self._table)) + + if not self._columns: + raise error.SpannerError('Index {} has no columns'.format(self._index)) + + if self._index in self._model.indexes: + raise error.SpannerError('Index {} already exists'.format(self._index)) + + self._validate_columns() + + if self._parent_table: + self._validate_parent() + + def _validate_columns(self): + for column in self._columns: + if not column in self._model.columns: + raise error.SpannerError('Table {} has no column {}'.format( + self._table, column)) + + for column in self._storing_columns: + if not column in self._model.columns: + raise error.SpannerError('Table {} has no column {}'.format( + self._table, column)) + if column in self._model.primary_keys: + raise error.SpannerError('{} is part of the primary key for {}'.format( + column, self._table)) + + def validate_parent(self): + parent = self._model.interleaved + while parent: + if parent == self._parent_table: + break + parent = parent.interleaved + if not parent: + raise error.SpannerError('{} is not a parent of table {}'.format( + self._parent_table, self._table)) + + +class DropIndex(SchemaUpdate): + """Update for dropping a secondary index on an existing table.""" + + def __init__(self, table_name, index_name): + self._table = table_name + self._index = index_name + self._model = metadata.SpannerMetadata.model(table_name) + + def ddl(self): + return 'DROP INDEX {}'.format(self._index) + + def validate(self): + if not self._model: + raise error.SpannerError('Table {} does not exist'.format(self._table)) + + if self._index not in self._model.indexes: + raise error.SpannerError('Index {} does not exist'.format(self._index)) + if self._index == index.Index.PRIMARY_INDEX: + raise error.SpannerError('Index {} is the primary index'.format( + self._index)) class NoUpdate(SchemaUpdate): """Update that does nothing, for migrations that don't update db schemas.""" + def ddl(self): return '' diff --git a/spanner_orm/tests/admin_test.py b/spanner_orm/tests/admin_test.py index 6573610..ceb78fb 100644 --- a/spanner_orm/tests/admin_test.py +++ b/spanner_orm/tests/admin_test.py @@ -27,11 +27,6 @@ class AdminTest(unittest.TestCase): - def clear_metadata(self): - metadata.SpannerMetadata._models = None - metadata.SpannerMetadata._indexes = None - metadata.SpannerMetadata._tables = None - def make_test_tables(self, model, parent_table=None): tables = [{ 'table_catalog': '', @@ -102,7 +97,6 @@ def make_test_index(self, model, name=None): @mock.patch('spanner_orm.admin.column.ColumnSchema.where') @mock.patch('spanner_orm.admin.table.TableSchema.where') def test_metadata(self, tables, columns, index_columns, indexes): - self.clear_metadata() model = models.SmallTestModel tables.return_value = self.make_test_tables(model) columns.return_value = self.make_test_columns(model) @@ -127,7 +121,6 @@ def test_metadata(self, tables, columns, index_columns, indexes): @mock.patch('spanner_orm.admin.column.ColumnSchema.where') @mock.patch('spanner_orm.admin.table.TableSchema.where') def test_interleaved(self, tables, columns, index_columns, indexes): - self.clear_metadata() model = models.SmallTestModel parent_model = models.UnittestModel tables.return_value = ( @@ -151,7 +144,6 @@ def test_interleaved(self, tables, columns, index_columns, indexes): @mock.patch('spanner_orm.admin.column.ColumnSchema.where') @mock.patch('spanner_orm.admin.table.TableSchema.where') def test_secondary_index(self, tables, columns, index_columns, indexes): - self.clear_metadata() model = models.SmallTestModel name = 'secondary_index' index_cols = ['value_1'] diff --git a/spanner_orm/tests/update_test.py b/spanner_orm/tests/update_test.py index bc614d0..1a07254 100644 --- a/spanner_orm/tests/update_test.py +++ b/spanner_orm/tests/update_test.py @@ -16,6 +16,7 @@ import unittest from unittest import mock +from spanner_orm import error from spanner_orm import field from spanner_orm.admin import update from spanner_orm.tests import models @@ -25,38 +26,81 @@ class UpdateTest(unittest.TestCase): @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') - def test_column_update_add_column(self, get_model): + def test_add_column(self, get_model): + table_name = models.SmallTestModel.table get_model.return_value = models.SmallTestModel + + new_field = field.Field(field.String, nullable=True) + test_update = update.AddColumn(table_name, 'bar', new_field) + test_update.validate() + self.assertEqual( + test_update.ddl(), + 'ALTER TABLE {} ADD COLUMN bar STRING(MAX)'.format(table_name)) + + @mock.patch('spanner_orm.admin.index_column.IndexColumnSchema.count') + @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') + def test_drop_column(self, get_model, index_count): + table_name = models.SmallTestModel.table + get_model.return_value = models.SmallTestModel + index_count.return_value = 0 + new_field = field.Field(field.String, nullable=True) - test_update = update.ColumnUpdate('foo', 'bar', new_field) + test_update = update.DropColumn(table_name, 'value_2') test_update.validate() self.assertEqual(test_update.ddl(), - 'ALTER TABLE foo ADD COLUMN bar STRING(MAX)') + 'ALTER TABLE {} DROP COLUMN value_2'.format(table_name)) @mock.patch('spanner_orm.admin.index_column.IndexColumnSchema.count') @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') - def test_column_update_error_on_primary_key(self, get_model, index_count): - index_count.return_value = 1 + def test_drop_column_error_on_primary_key(self, get_model, index_count): get_model.return_value = models.SmallTestModel - test_update = update.ColumnUpdate(models.SmallTestModel.table, 'key', None) - with self.assertRaisesRegex(AssertionError, 'indexed column'): + index_count.return_value = 1 + + test_update = update.DropColumn(models.SmallTestModel.table, 'key') + with self.assertRaisesRegex(error.SpannerError, 'Column key is indexed'): test_update.validate() @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') def test_create_table(self, get_model): get_model.return_value = None new_model = models.SmallTestModel - test_update = update.CreateTableUpdate(new_model) + test_update = update.CreateTable(new_model) + test_update.validate() self.assertEqual(test_update.ddl(), new_model.creation_ddl) @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') def test_create_table_error_on_existing_table(self, get_model): get_model.return_value = models.SmallTestModel new_model = models.SmallTestModel - test_update = update.CreateTableUpdate(new_model) - with self.assertRaisesRegex(AssertionError, 'already exists'): + test_update = update.CreateTable(new_model) + with self.assertRaisesRegex(error.SpannerError, 'already exists'): test_update.validate() + @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.indexes') + @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.tables') + @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') + def test_drop_table(self, get_model, tables, indexes): + table_name = models.SmallTestModel.table + get_model.return_value = models.SmallTestModel + tables.return_value = {} + indexes.return_value = {} + + test_update = update.DropTable(table_name) + test_update.validate() + self.assertEqual(test_update.ddl(), 'DROP TABLE {}'.format(table_name)) + + @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') + def test_add_index(self, get_model): + table_name = models.SmallTestModel.table + get_model.return_value = models.SmallTestModel + + test_update = update.CreateIndex(table_name, 'foo', ['value_1']) + test_update.validate() + self.assertEqual( + test_update.ddl(), + 'CREATE INDEX foo ON {} (value_1)'.format(table_name)) + + if __name__ == '__main__': logging.basicConfig() unittest.main()