Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 60 additions & 31 deletions spanner_orm/admin/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,27 @@ class CreateTable(SchemaUpdate):

def __init__(self, model):
self._model = model
self._existing_model = metadata.SpannerMetadata.model(model.table)

def ddl(self):
return self._model.creation_ddl
fields = [
'{} {}'.format(name, field.ddl())
for name, field in self._model.schema.items()
]
index_ddl = 'PRIMARY KEY ({})'.format(', '.join(self._model.primary_keys))
statement = 'CREATE TABLE {} ({}) {}'.format(self._model.table,
', '.join(fields), index_ddl)

if self._model.interleaved:
statement += ', INTERLEAVE IN PARENT {parent} ON CASCADE DELETE'.format(
parent=self._model.interleaved.table)
return statement

def validate(self):
if not self._model.table:
raise error.SpannerError('New table has no name')

if self._existing_model:
existing_model = metadata.SpannerMetadata.model(self._model.table)
if existing_model:
raise error.SpannerError('Table {} already exists'.format(
self._model.table))

Expand All @@ -71,7 +82,7 @@ def _validate_parent(self):
self._model.table, self._model.interleaved.table)
for parent_key, key in zip(parent_primary_keys, primary_keys):
if parent_key != key:
raise error.SpnnerError(message)
raise error.SpannerError(message)
if len(parent_primary_keys) > len(primary_keys):
raise error.SpannerError(message)

Expand All @@ -92,25 +103,25 @@ class DropTable(SchemaUpdate):

def __init__(self, table_name):
self._table = table_name
self._existing_model = metadata.SpannerMetadata.model(table_name)

def ddl(self):
return 'DROP TABLE {}'.format(self._table)

def validate(self):
if not self._existing_model:
existing_model = metadata.SpannerMetadata.model(self._table)
if not existing_model:
raise error.SpannerError('Table {} does not exist'.format(self._table))

# Model indexes include the primary index
if len(self._existing_model.indexes) > 1:
if len(existing_model.indexes) > 1:
raise error.SpannerError('Table {} has a secondary index'.format(
self._table))

self._validate_not_interleaved()

def _validate_not_interleaved(self):
for model in metadata.SpannerMetadata.models().values():
if model.interleaved == self._existing_model:
if model.interleaved == existing_model:
raise error.SpannerError('Table {} has interleaved table {}'.format(
self._table, model.table))
for index in model.indexes.values():
Expand All @@ -129,14 +140,14 @@ def __init__(self, table_name, column_name, field):
self._table = table_name
self._column = column_name
self._field = field
self._model = metadata.SpannerMetadata.model(table_name)

def ddl(self):
return 'ALTER TABLE {} ADD COLUMN {} {}'.format(self._table, self._column,
self._field.ddl())

def validate(self):
if not self._model:
model = metadata.SpannerMetadata.model(self._table)
if not model:
raise error.SpannerError('Table {} does not exist'.format(self._table))
if not self._field.nullable():
raise error.SpannerError('Column {} is not nullable'.format(self._column))
Expand All @@ -151,16 +162,16 @@ class DropColumn(SchemaUpdate):
def __init__(self, table_name, column_name):
self._table = table_name
self._column = column_name
self._model = metadata.SpannerMetadata.model(table_name)

def ddl(self):
return 'ALTER TABLE {} DROP COLUMN {}'.format(self._table, self._column)

def validate(self):
if not self._model:
model = metadata.SpannerMetadata.model(self._table)
if not model:
raise error.SpannerError('Table {} does not exist'.format(self._table))

if self._column not in self._model.schema:
if self._column not in model.schema:
raise error.SpannerError('Column {} does not exist on {}'.format(
self._column, self._table))

Expand All @@ -182,25 +193,25 @@ def __init__(self, table_name, column_name, field):
self._table = table_name
self._column = column_name
self._field = field
self._model = metadata.SpannerMetadata.model(table_name)

def ddl(self):
return 'ALTER TABLE {} ALTER COLUMN {} {}'.format(self._table, self._column,
self._field.ddl())

def validate(self):
if not self._model:
model = metadata.SpannerMetadata.model(self._table)
if not model:
raise error.SpannerError('Table {} does not exist'.format(self._table))

if self._column not in self._model.schema:
if self._column not in model.schema:
raise error.SpannerError('Column {} does not exist on {}'.format(
self._column, self._table))

if self._column in self._model.primary_keys:
if self._column in model.primary_keys:
raise error.SpannerError('Column {} is a primary key on {}'.format(
self._column, self._table))

old_field = self._model.schema[self._column]
old_field = model.schema[self._column]
# Validate that the only alteration is to change column nullability
if self._field.field_type() != old_field.field_type():
raise error.SpannerError('Column {} is changing type'.format(
Expand All @@ -223,7 +234,6 @@ def __init__(self,
self._columns = columns
self._parent_table = interleaved
self._storing_columns = storing_columns or []
self._model = metadata.SpannerMetadata.model(table_name)

def ddl(self):
statement = 'CREATE INDEX {} ON {} ({})'.format(self._index, self._table,
Expand All @@ -235,40 +245,42 @@ def ddl(self):
return statement

def validate(self):
if not self._model:
model = metadata.SpannerMetadata.model(self._table)
if not model:
raise error.SpannerError('Table {} does not exist'.format(self._table))

if not self._columns:
raise error.SpannerError('Index {} has no columns'.format(self._index))

if self._index in self._model.indexes:
if self._index in model.indexes:
raise error.SpannerError('Index {} already exists'.format(self._index))

self._validate_columns()
self._validate_columns(model)

if self._parent_table:
self._validate_parent()
self._validate_parent(model)

def _validate_columns(self):
def _validate_columns(self, model):
for column in self._columns:
if not column in self._model.columns:
if not column in model.columns:
raise error.SpannerError('Table {} has no column {}'.format(
self._table, column))

for column in self._storing_columns:
if not column in self._model.columns:
if not column in model.columns:
raise error.SpannerError('Table {} has no column {}'.format(
self._table, column))
if column in self._model.primary_keys:
if column in model.primary_keys:
raise error.SpannerError('{} is part of the primary key for {}'.format(
column, self._table))

def validate_parent(self):
parent = self._model.interleaved
def validate_parent(self, model):
parent = model.interleaved
while parent:
if parent == self._parent_table:
break
parent = parent.interleaved

if not parent:
raise error.SpannerError('{} is not a parent of table {}'.format(
self._parent_table, self._table))
Expand All @@ -280,13 +292,13 @@ class DropIndex(SchemaUpdate):
def __init__(self, table_name, index_name):
self._table = table_name
self._index = index_name
self._model = metadata.SpannerMetadata.model(table_name)

def ddl(self):
return 'DROP INDEX {}'.format(self._index)

def validate(self):
if not self._model:
model = metadata.SpannerMetadata.model(self._table)
if not model:
raise error.SpannerError('Table {} does not exist'.format(self._table))

db_index = self._model.indexes.get(self._index)
Expand All @@ -308,3 +320,20 @@ def execute(self):

def validate(self):
pass


def model_creation_ddl(model):
ddl_list = [CreateTable(model).ddl()]

for model_index in model.indexes:
if model_index.primary_key:
continue
create_index = CreateIndex(
model.table,
model_index.name,
model_index.columns,
interleaved=model_index.parent,
storing_columns=model_index.storing_columns)
ddl_list.append(create_index.ddl())

return ddl_list
16 changes: 0 additions & 16 deletions spanner_orm/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,22 +138,6 @@ def __getattr__(cls, name):
return cls.indexes[name]
raise AttributeError(name)

@property
def creation_ddl(cls):
fields = [
'{} {}'.format(name, field.ddl()) for name, field in cls.schema.items()
]
field_ddl = '({})'.format(', '.join(fields))
index_ddl = 'PRIMARY KEY ({})'.format(', '.join(cls.primary_keys))
trailing_statements = [index_ddl]
if cls.interleaved:
interleave_ddl = 'INTERLEAVE IN PARENT {parent} ON CASCADE DELETE'.format(
parent=cls.interleaved.table)
trailing_statements.append(interleave_ddl)
trailing_ddl = ', '.join(trailing_statements)
return 'CREATE TABLE {table_name} {fields} {trailing}'.format(
table_name=cls.table, fields=field_ddl, trailing=trailing_ddl)

@property
def column_prefix(cls):
return cls.table.split('.')[-1]
Expand Down
19 changes: 2 additions & 17 deletions spanner_orm/tests/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,21 +92,6 @@ def test_object_changes(self):
test_model.string_array.append('bat')
self.assertIn('string_array', test_model.changes())

def test_creation_ddl(self):
test_model_ddl = ('CREATE TABLE table (int_ INT64 NOT NULL, int_2 INT64,'
' string STRING(MAX) NOT NULL, string_2 STRING(MAX),'
' timestamp TIMESTAMP NOT NULL, string_array'
' ARRAY<STRING(MAX)>) PRIMARY KEY (int_, string)')
self.assertEqual(models.UnittestModel.creation_ddl, test_model_ddl)

def test_interleaved_creation_ddl(self):
test_model_ddl = ('CREATE TABLE ChildTestModel ('
'parent_key STRING(MAX) NOT NULL, '
'child_key STRING(MAX) NOT NULL) '
'PRIMARY KEY (parent_key, child_key), '
'INTERLEAVE IN PARENT SmallTestModel ON CASCADE DELETE')
self.assertEqual(models.ChildTestModel.creation_ddl, test_model_ddl)

def test_field_exists_on_model_class(self):
self.assertIsInstance(models.SmallTestModel.key, field.Field)
self.assertEqual(models.SmallTestModel.key.field_type(), field.String)
Expand All @@ -122,15 +107,15 @@ def test_field_inheritance(self):
self.assertEqual(getattr(test_model, name), value)

def test_relation_get(self):
test_model = models.ChildTestModel({
test_model = models.RelationshipTestModel({
'parent_key': 'parent',
'child_key': 'child',
'parent': []
})
self.assertEqual(test_model.parent, [])

def test_relation_get_error_on_unretrieved(self):
test_model = models.ChildTestModel({
test_model = models.RelationshipTestModel({
'parent_key': 'parent',
'child_key': 'child'
})
Expand Down
18 changes: 13 additions & 5 deletions spanner_orm/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


class SmallTestModel(model.Model):
"""Model class used for testing"""
"""Model class used for testing."""

__table__ = 'SmallTestModel'
key = field.Field(field.String, primary_key=True)
Expand All @@ -30,10 +30,18 @@ class SmallTestModel(model.Model):


class ChildTestModel(model.Model):
"""Model class for testing relationships"""

"""Model class for testing interleaved tables."""
__table__ = 'ChildTestModel'
__interleaved__ = SmallTestModel

key = field.Field(field.String, primary_key=True)
child_key = field.Field(field.String, primary_key=True)


class RelationshipTestModel(model.Model):
"""Model class for testing relationships."""

__table__ = 'RelationshipTestModel'
parent_key = field.Field(field.String, primary_key=True)
child_key = field.Field(field.String, primary_key=True)
parent = relationship.Relationship(
Expand All @@ -44,12 +52,12 @@ class ChildTestModel(model.Model):


class InheritanceTestModel(SmallTestModel):
"""Model class used for testing model inheritance"""
"""Model class used for testing model inheritance."""
value_3 = field.Field(field.String, nullable=True)


class UnittestModel(model.Model):
"""Model class used for model testing"""
"""Model class used for model testing."""

__table__ = 'table'
int_ = field.Field(field.Integer, primary_key=True)
Expand Down
10 changes: 5 additions & 5 deletions spanner_orm/tests/query_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,30 +156,30 @@ def test_force_index(self):

def includes(self, relation, *conditions):
include_condition = condition.includes(relation, list(conditions))
return query.SelectQuery(models.ChildTestModel, [include_condition])
return query.SelectQuery(models.RelationshipTestModel, [include_condition])

def test_includes(self):
select_query = self.includes('parent')

# The column order varies between test runs
expected_sql = (
r'SELECT ChildTestModel\S* ChildTestModel\S* ARRAY\(SELECT AS '
r'SELECT RelationshipTestModel\S* RelationshipTestModel\S* ARRAY\(SELECT AS '
r'STRUCT SmallTestModel\S* SmallTestModel\S* SmallTestModel\S* FROM '
r'SmallTestModel WHERE SmallTestModel.key = '
r'ChildTestModel.parent_key\)')
r'RelationshipTestModel.parent_key\)')
self.assertRegex(select_query.sql(), expected_sql)
self.assertEmpty(select_query.parameters())
self.assertEmpty(select_query.types())

def test_includes_subconditions_query(self):
select_query = self.includes('parents', condition.equal_to('key', 'value'))
expected_sql = ('WHERE SmallTestModel.key = ChildTestModel.parent_key '
expected_sql = ('WHERE SmallTestModel.key = RelationshipTestModel.parent_key '
'AND SmallTestModel.key = @key0')
self.assertRegex(select_query.sql(), expected_sql)

def includes_result(self, related=1):
child = {'parent_key': 'parent_key', 'child_key': 'child'}
result = [child[name] for name in models.ChildTestModel.columns]
result = [child[name] for name in models.RelationshipTestModel.columns]
parent = {'key': 'key', 'value_1': 'value_1', 'value_2': None}
parents = []
for _ in range(related):
Expand Down
Loading