diff --git a/spanner_orm/admin/update.py b/spanner_orm/admin/update.py index 76a086e..ee3fd6e 100644 --- a/spanner_orm/admin/update.py +++ b/spanner_orm/admin/update.py @@ -45,16 +45,27 @@ class CreateTable(SchemaUpdate): def __init__(self, model): self._model = model - self._existing_model = metadata.SpannerMetadata.model(model.table) def ddl(self): - return self._model.creation_ddl + fields = [ + '{} {}'.format(name, field.ddl()) + for name, field in self._model.schema.items() + ] + index_ddl = 'PRIMARY KEY ({})'.format(', '.join(self._model.primary_keys)) + statement = 'CREATE TABLE {} ({}) {}'.format(self._model.table, + ', '.join(fields), index_ddl) + + if self._model.interleaved: + statement += ', INTERLEAVE IN PARENT {parent} ON CASCADE DELETE'.format( + parent=self._model.interleaved.table) + return statement def validate(self): if not self._model.table: raise error.SpannerError('New table has no name') - if self._existing_model: + existing_model = metadata.SpannerMetadata.model(self._model.table) + if existing_model: raise error.SpannerError('Table {} already exists'.format( self._model.table)) @@ -71,7 +82,7 @@ def _validate_parent(self): self._model.table, self._model.interleaved.table) for parent_key, key in zip(parent_primary_keys, primary_keys): if parent_key != key: - raise error.SpnnerError(message) + raise error.SpannerError(message) if len(parent_primary_keys) > len(primary_keys): raise error.SpannerError(message) @@ -92,17 +103,17 @@ class DropTable(SchemaUpdate): def __init__(self, table_name): self._table = table_name - self._existing_model = metadata.SpannerMetadata.model(table_name) def ddl(self): return 'DROP TABLE {}'.format(self._table) def validate(self): - if not self._existing_model: + existing_model = metadata.SpannerMetadata.model(self._table) + if not existing_model: raise error.SpannerError('Table {} does not exist'.format(self._table)) # Model indexes include the primary index - if len(self._existing_model.indexes) > 1: + if len(existing_model.indexes) > 1: raise error.SpannerError('Table {} has a secondary index'.format( self._table)) @@ -110,7 +121,7 @@ def validate(self): def _validate_not_interleaved(self): for model in metadata.SpannerMetadata.models().values(): - if model.interleaved == self._existing_model: + if model.interleaved == existing_model: raise error.SpannerError('Table {} has interleaved table {}'.format( self._table, model.table)) for index in model.indexes.values(): @@ -129,14 +140,14 @@ def __init__(self, table_name, column_name, field): self._table = table_name self._column = column_name self._field = field - self._model = metadata.SpannerMetadata.model(table_name) def ddl(self): return 'ALTER TABLE {} ADD COLUMN {} {}'.format(self._table, self._column, self._field.ddl()) def validate(self): - if not self._model: + model = metadata.SpannerMetadata.model(self._table) + if not model: raise error.SpannerError('Table {} does not exist'.format(self._table)) if not self._field.nullable(): raise error.SpannerError('Column {} is not nullable'.format(self._column)) @@ -151,16 +162,16 @@ class DropColumn(SchemaUpdate): def __init__(self, table_name, column_name): self._table = table_name self._column = column_name - self._model = metadata.SpannerMetadata.model(table_name) def ddl(self): return 'ALTER TABLE {} DROP COLUMN {}'.format(self._table, self._column) def validate(self): - if not self._model: + model = metadata.SpannerMetadata.model(self._table) + if not model: raise error.SpannerError('Table {} does not exist'.format(self._table)) - if self._column not in self._model.schema: + if self._column not in model.schema: raise error.SpannerError('Column {} does not exist on {}'.format( self._column, self._table)) @@ -182,25 +193,25 @@ def __init__(self, table_name, column_name, field): self._table = table_name self._column = column_name self._field = field - self._model = metadata.SpannerMetadata.model(table_name) def ddl(self): return 'ALTER TABLE {} ALTER COLUMN {} {}'.format(self._table, self._column, self._field.ddl()) def validate(self): - if not self._model: + model = metadata.SpannerMetadata.model(self._table) + if not model: raise error.SpannerError('Table {} does not exist'.format(self._table)) - if self._column not in self._model.schema: + if self._column not in model.schema: raise error.SpannerError('Column {} does not exist on {}'.format( self._column, self._table)) - if self._column in self._model.primary_keys: + if self._column in model.primary_keys: raise error.SpannerError('Column {} is a primary key on {}'.format( self._column, self._table)) - old_field = self._model.schema[self._column] + old_field = model.schema[self._column] # Validate that the only alteration is to change column nullability if self._field.field_type() != old_field.field_type(): raise error.SpannerError('Column {} is changing type'.format( @@ -223,7 +234,6 @@ def __init__(self, self._columns = columns self._parent_table = interleaved self._storing_columns = storing_columns or [] - self._model = metadata.SpannerMetadata.model(table_name) def ddl(self): statement = 'CREATE INDEX {} ON {} ({})'.format(self._index, self._table, @@ -235,40 +245,42 @@ def ddl(self): return statement def validate(self): - if not self._model: + model = metadata.SpannerMetadata.model(self._table) + if not model: raise error.SpannerError('Table {} does not exist'.format(self._table)) if not self._columns: raise error.SpannerError('Index {} has no columns'.format(self._index)) - if self._index in self._model.indexes: + if self._index in model.indexes: raise error.SpannerError('Index {} already exists'.format(self._index)) - self._validate_columns() + self._validate_columns(model) if self._parent_table: - self._validate_parent() + self._validate_parent(model) - def _validate_columns(self): + def _validate_columns(self, model): for column in self._columns: - if not column in self._model.columns: + if not column in model.columns: raise error.SpannerError('Table {} has no column {}'.format( self._table, column)) for column in self._storing_columns: - if not column in self._model.columns: + if not column in model.columns: raise error.SpannerError('Table {} has no column {}'.format( self._table, column)) - if column in self._model.primary_keys: + if column in model.primary_keys: raise error.SpannerError('{} is part of the primary key for {}'.format( column, self._table)) - def validate_parent(self): - parent = self._model.interleaved + def validate_parent(self, model): + parent = model.interleaved while parent: if parent == self._parent_table: break parent = parent.interleaved + if not parent: raise error.SpannerError('{} is not a parent of table {}'.format( self._parent_table, self._table)) @@ -280,13 +292,13 @@ class DropIndex(SchemaUpdate): def __init__(self, table_name, index_name): self._table = table_name self._index = index_name - self._model = metadata.SpannerMetadata.model(table_name) def ddl(self): return 'DROP INDEX {}'.format(self._index) def validate(self): - if not self._model: + model = metadata.SpannerMetadata.model(self._table) + if not model: raise error.SpannerError('Table {} does not exist'.format(self._table)) db_index = self._model.indexes.get(self._index) @@ -308,3 +320,20 @@ def execute(self): def validate(self): pass + + +def model_creation_ddl(model): + ddl_list = [CreateTable(model).ddl()] + + for model_index in model.indexes: + if model_index.primary_key: + continue + create_index = CreateIndex( + model.table, + model_index.name, + model_index.columns, + interleaved=model_index.parent, + storing_columns=model_index.storing_columns) + ddl_list.append(create_index.ddl()) + + return ddl_list diff --git a/spanner_orm/model.py b/spanner_orm/model.py index 9e291d4..0f934e7 100644 --- a/spanner_orm/model.py +++ b/spanner_orm/model.py @@ -138,22 +138,6 @@ def __getattr__(cls, name): return cls.indexes[name] raise AttributeError(name) - @property - def creation_ddl(cls): - fields = [ - '{} {}'.format(name, field.ddl()) for name, field in cls.schema.items() - ] - field_ddl = '({})'.format(', '.join(fields)) - index_ddl = 'PRIMARY KEY ({})'.format(', '.join(cls.primary_keys)) - trailing_statements = [index_ddl] - if cls.interleaved: - interleave_ddl = 'INTERLEAVE IN PARENT {parent} ON CASCADE DELETE'.format( - parent=cls.interleaved.table) - trailing_statements.append(interleave_ddl) - trailing_ddl = ', '.join(trailing_statements) - return 'CREATE TABLE {table_name} {fields} {trailing}'.format( - table_name=cls.table, fields=field_ddl, trailing=trailing_ddl) - @property def column_prefix(cls): return cls.table.split('.')[-1] diff --git a/spanner_orm/tests/model_test.py b/spanner_orm/tests/model_test.py index 63ff2ad..55f1568 100644 --- a/spanner_orm/tests/model_test.py +++ b/spanner_orm/tests/model_test.py @@ -92,21 +92,6 @@ def test_object_changes(self): test_model.string_array.append('bat') self.assertIn('string_array', test_model.changes()) - def test_creation_ddl(self): - test_model_ddl = ('CREATE TABLE table (int_ INT64 NOT NULL, int_2 INT64,' - ' string STRING(MAX) NOT NULL, string_2 STRING(MAX),' - ' timestamp TIMESTAMP NOT NULL, string_array' - ' ARRAY) PRIMARY KEY (int_, string)') - self.assertEqual(models.UnittestModel.creation_ddl, test_model_ddl) - - def test_interleaved_creation_ddl(self): - test_model_ddl = ('CREATE TABLE ChildTestModel (' - 'parent_key STRING(MAX) NOT NULL, ' - 'child_key STRING(MAX) NOT NULL) ' - 'PRIMARY KEY (parent_key, child_key), ' - 'INTERLEAVE IN PARENT SmallTestModel ON CASCADE DELETE') - self.assertEqual(models.ChildTestModel.creation_ddl, test_model_ddl) - def test_field_exists_on_model_class(self): self.assertIsInstance(models.SmallTestModel.key, field.Field) self.assertEqual(models.SmallTestModel.key.field_type(), field.String) @@ -122,7 +107,7 @@ def test_field_inheritance(self): self.assertEqual(getattr(test_model, name), value) def test_relation_get(self): - test_model = models.ChildTestModel({ + test_model = models.RelationshipTestModel({ 'parent_key': 'parent', 'child_key': 'child', 'parent': [] @@ -130,7 +115,7 @@ def test_relation_get(self): self.assertEqual(test_model.parent, []) def test_relation_get_error_on_unretrieved(self): - test_model = models.ChildTestModel({ + test_model = models.RelationshipTestModel({ 'parent_key': 'parent', 'child_key': 'child' }) diff --git a/spanner_orm/tests/models.py b/spanner_orm/tests/models.py index 98b5966..76d6ab8 100644 --- a/spanner_orm/tests/models.py +++ b/spanner_orm/tests/models.py @@ -21,7 +21,7 @@ class SmallTestModel(model.Model): - """Model class used for testing""" + """Model class used for testing.""" __table__ = 'SmallTestModel' key = field.Field(field.String, primary_key=True) @@ -30,10 +30,18 @@ class SmallTestModel(model.Model): class ChildTestModel(model.Model): - """Model class for testing relationships""" - + """Model class for testing interleaved tables.""" __table__ = 'ChildTestModel' __interleaved__ = SmallTestModel + + key = field.Field(field.String, primary_key=True) + child_key = field.Field(field.String, primary_key=True) + + +class RelationshipTestModel(model.Model): + """Model class for testing relationships.""" + + __table__ = 'RelationshipTestModel' parent_key = field.Field(field.String, primary_key=True) child_key = field.Field(field.String, primary_key=True) parent = relationship.Relationship( @@ -44,12 +52,12 @@ class ChildTestModel(model.Model): class InheritanceTestModel(SmallTestModel): - """Model class used for testing model inheritance""" + """Model class used for testing model inheritance.""" value_3 = field.Field(field.String, nullable=True) class UnittestModel(model.Model): - """Model class used for model testing""" + """Model class used for model testing.""" __table__ = 'table' int_ = field.Field(field.Integer, primary_key=True) diff --git a/spanner_orm/tests/query_test.py b/spanner_orm/tests/query_test.py index 7f02859..9acd474 100644 --- a/spanner_orm/tests/query_test.py +++ b/spanner_orm/tests/query_test.py @@ -156,30 +156,30 @@ def test_force_index(self): def includes(self, relation, *conditions): include_condition = condition.includes(relation, list(conditions)) - return query.SelectQuery(models.ChildTestModel, [include_condition]) + return query.SelectQuery(models.RelationshipTestModel, [include_condition]) def test_includes(self): select_query = self.includes('parent') # The column order varies between test runs expected_sql = ( - r'SELECT ChildTestModel\S* ChildTestModel\S* ARRAY\(SELECT AS ' + r'SELECT RelationshipTestModel\S* RelationshipTestModel\S* ARRAY\(SELECT AS ' r'STRUCT SmallTestModel\S* SmallTestModel\S* SmallTestModel\S* FROM ' r'SmallTestModel WHERE SmallTestModel.key = ' - r'ChildTestModel.parent_key\)') + r'RelationshipTestModel.parent_key\)') self.assertRegex(select_query.sql(), expected_sql) self.assertEmpty(select_query.parameters()) self.assertEmpty(select_query.types()) def test_includes_subconditions_query(self): select_query = self.includes('parents', condition.equal_to('key', 'value')) - expected_sql = ('WHERE SmallTestModel.key = ChildTestModel.parent_key ' + expected_sql = ('WHERE SmallTestModel.key = RelationshipTestModel.parent_key ' 'AND SmallTestModel.key = @key0') self.assertRegex(select_query.sql(), expected_sql) def includes_result(self, related=1): child = {'parent_key': 'parent_key', 'child_key': 'child'} - result = [child[name] for name in models.ChildTestModel.columns] + result = [child[name] for name in models.RelationshipTestModel.columns] parent = {'key': 'key', 'value_1': 'value_1', 'value_2': None} parents = [] for _ in range(related): diff --git a/spanner_orm/tests/update_test.py b/spanner_orm/tests/update_test.py index 1a07254..4d84095 100644 --- a/spanner_orm/tests/update_test.py +++ b/spanner_orm/tests/update_test.py @@ -63,10 +63,29 @@ def test_drop_column_error_on_primary_key(self, get_model, index_count): @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') def test_create_table(self, get_model): get_model.return_value = None - new_model = models.SmallTestModel + new_model = models.UnittestModel + test_update = update.CreateTable(new_model) + test_update.validate() + + test_model_ddl = ('CREATE TABLE table (int_ INT64 NOT NULL, int_2 INT64,' + ' string STRING(MAX) NOT NULL, string_2 STRING(MAX),' + ' timestamp TIMESTAMP NOT NULL, string_array' + ' ARRAY) PRIMARY KEY (int_, string)') + self.assertEqual(test_update.ddl(), test_model_ddl) + + @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') + def test_create_table_interleaved(self, get_model): + get_model.return_value = None + new_model = models.ChildTestModel test_update = update.CreateTable(new_model) test_update.validate() - self.assertEqual(test_update.ddl(), new_model.creation_ddl) + + test_model_ddl = ('CREATE TABLE ChildTestModel (' + 'key STRING(MAX) NOT NULL, ' + 'child_key STRING(MAX) NOT NULL) ' + 'PRIMARY KEY (key, child_key), ' + 'INTERLEAVE IN PARENT SmallTestModel ON CASCADE DELETE') + self.assertEqual(test_update.ddl(), test_model_ddl) @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') def test_create_table_error_on_existing_table(self, get_model):