Skip to content

Commit 5aedcc6

Browse files
committed
Implement most of functionality for includes
1 parent 5249af6 commit 5aedcc6

6 files changed

Lines changed: 332 additions & 91 deletions

File tree

spanner_orm/admin/update.py

Lines changed: 54 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from spanner_orm import condition
2121
from spanner_orm import error
2222
from spanner_orm import field
23+
from spanner_orm import foreign_key_relationship
2324
from spanner_orm import model
2425
from spanner_orm.admin import api
2526
from spanner_orm.admin import index_column
@@ -55,16 +56,7 @@ def ddl(self) -> str:
5556
]
5657
key_fields_ddl = ', '.join(key_fields)
5758
for relation in self._model.foreign_key_relations.values():
58-
referencing_columns_ddl = ', '.join(relation.constraint.columns.keys())
59-
referenced_columns_ddl = ', '.join(relation.constraint.columns.values())
60-
key_fields_ddl += (
61-
', CONSTRAINT {fk_name} FOREIGN KEY ({referencing_columns}) REFERENCES'
62-
' {referenced_table} ({referenced_columns})').format(
63-
fk_name=relation.name,
64-
referencing_columns=referencing_columns_ddl,
65-
referenced_table=relation.constraint.referenced_table_name,
66-
referenced_columns=referenced_columns_ddl,
67-
)
59+
key_fields_ddl += f', {relation.ddl}'
6860
index_ddl = 'PRIMARY KEY ({})'.format(', '.join(self._model.primary_keys))
6961
statement = 'CREATE TABLE {} ({}) {}'.format(self._model.table,
7062
key_fields_ddl, index_ddl)
@@ -327,6 +319,58 @@ def validate(self) -> None:
327319
raise error.SpannerError('Index {} is the primary index'.format(
328320
self._index))
329321

322+
class AddForeignKeyRelationship(SchemaUpdate):
323+
"""Update for adding a column to an existing table.
324+
325+
Only supports adding nullable columns
326+
"""
327+
328+
def __init__(
329+
self,
330+
referencing_table_name: str,
331+
referenced_table_name: str,
332+
column_mapping,
333+
):
334+
self._table = table_name
335+
self._column = column_name
336+
self._field = field_
337+
338+
def ddl(self) -> str:
339+
return 'ALTER TABLE {} ADD'.format(self._table, self._column,
340+
self._field.ddl())
341+
342+
def validate(self) -> None:
343+
model_ = metadata.SpannerMetadata.model(self._table)
344+
if not model_:
345+
raise error.SpannerError('Table {} does not exist'.format(self._table))
346+
347+
348+
class DropForeignKeyRelationship(SchemaUpdate):
349+
"""Update for dropping a column from an existing table."""
350+
351+
def __init__(self, table_name: str, column_name: str):
352+
self._table = table_name
353+
self._column = column_name
354+
355+
def ddl(self) -> str:
356+
return 'ALTER TABLE {} DROP COLUMN {}'.format(self._table, self._column)
357+
358+
def validate(self) -> None:
359+
model_ = metadata.SpannerMetadata.model(self._table)
360+
if not model_:
361+
raise error.SpannerError('Table {} does not exist'.format(self._table))
362+
363+
if self._column not in model_.fields:
364+
raise error.SpannerError('Column {} does not exist on {}'.format(
365+
self._column, self._table))
366+
367+
# Verify no indices exist on the column we're trying to drop
368+
num_indexed_columns = index_column.IndexColumnSchema.count(
369+
None, condition.equal_to('column_name', self._column),
370+
condition.equal_to('table_name', self._table))
371+
if num_indexed_columns > 0:
372+
raise error.SpannerError('Column {} is indexed'.format(self._column))
373+
330374

331375
class NoUpdate(SchemaUpdate):
332376
"""Update that does nothing, for migrations that don't update db schemas."""

spanner_orm/condition.py

Lines changed: 83 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from spanner_orm import error
2222
from spanner_orm import field
23+
from spanner_orm import foreign_key_relationship
2324
from spanner_orm import index
2425
from spanner_orm import relationship
2526

@@ -200,11 +201,38 @@ def _validate(self, model_class: Type[Any]) -> None:
200201
class IncludesCondition(Condition):
201202
"""Used to include related model_classs via a relation in a Spanner query."""
202203

203-
def __init__(self,
204-
relation_or_name: Union[relationship.Relationship, str],
205-
conditions: List[Condition] = None):
204+
def __init__(
205+
self,
206+
relation_or_name: Union[relationship.Relationship,
207+
foreign_key_relationship.ForeignKeyRelationship,
208+
str],
209+
conditions: List[Condition] = None,
210+
foreign_key_relation=False,
211+
):
212+
"""Initializer.
213+
214+
215+
Args:
216+
relation: Name of the relationship on the origin model or the Relationship/
217+
ForeignKeyRelationship on the origin model class used to retrieve
218+
associated objects
219+
conditions: Conditions to apply on the subquery
220+
foreign_key_relation: True if the relation is a foreign key relation,
221+
False if it is a legacy relation (eg not enforced in Spanner)
222+
"""
206223
super().__init__()
224+
self.foreign_key_relation = foreign_key_relation
207225
if isinstance(relation_or_name, relationship.Relationship):
226+
if foreign_key_relation:
227+
raise ValueError(
228+
'Must pass foreign key relation if ''`foreign_key_relation=True`.')
229+
self.name = relation_or_name.name
230+
self.relation = relation_or_name
231+
elif isinstance(relation_or_name,
232+
foreign_key_relationship.ForeignKeyRelationship):
233+
if not foreign_key_relation:
234+
raise ValueError(
235+
'Must pass legacy relation if `foreign_key_relation=False`.')
208236
self.name = relation_or_name.name
209237
self.relation = relation_or_name
210238
else:
@@ -214,7 +242,10 @@ def __init__(self,
214242

215243
def bind(self, model_class: Type[Any]) -> None:
216244
super().bind(model_class)
217-
self.relation = self.model_class.relations[self.name]
245+
if self.foreign_key_relation:
246+
self.relation = self.model_class.foreign_key_relations[self.name]
247+
else:
248+
self.relation = self.model_class.relations[self.name]
218249

219250
@property
220251
def conditions(self) -> List[Condition]:
@@ -223,21 +254,31 @@ def conditions(self) -> List[Condition]:
223254
raise error.SpannerError(
224255
'Condition must be bound before conditions is called')
225256
relation_conditions = []
226-
for constraint in self.relation.constraints:
227-
# This is backward from what you might imagine because the condition will
228-
# be processed from the context of the destination model
229-
relation_conditions.append(
230-
ColumnsEqualCondition(constraint.destination_column,
231-
constraint.origin_class,
232-
constraint.origin_column))
257+
if not self.foreign_key_relation:
258+
for constraint in self.relation.constraints:
259+
# This is backward from what you might imagine because the condition
260+
# will be processed from the context of the destination model.
261+
relation_conditions.append(
262+
ColumnsEqualCondition(constraint.destination_column,
263+
constraint.origin_class,
264+
constraint.origin_column))
265+
else:
266+
for pair in self.relation.constraint.columns.items():
267+
referencing_column, referenced_column = pair
268+
relation_conditions.append(
269+
ColumnsEqualCondition(referenced_column, self.model_class,
270+
referencing_column))
233271
return relation_conditions + self._conditions
234272

235273
@property
236274
def destination(self) -> Type[Any]:
237275
if not self.relation:
238276
raise error.SpannerError(
239277
'Condition must be bound before destination is called')
240-
return self.relation.destination
278+
if self.foreign_key_relation:
279+
return self.relation.constraint.referenced_table
280+
else:
281+
return self.relation.destination
241282

242283
@property
243284
def relation_name(self) -> str:
@@ -263,14 +304,25 @@ def _types(self) -> Dict[str, type_pb2.Type]:
263304
return {}
264305

265306
def _validate(self, model_class: Type[Any]) -> None:
266-
if self.name not in model_class.relations:
267-
raise error.ValidationError('{} is not a relation on {}'.format(
268-
self.name, model_class.table))
269-
if self.relation and self.relation != model_class.relations[self.name]:
270-
raise error.ValidationError('{} does not belong to {}'.format(
271-
self.relation.name, model_class.table))
307+
if self.foreign_key_relation:
308+
if self.name not in model_class.foreign_key_relations:
309+
raise error.ValidationError('{} is not a relation on {}'.format(
310+
self.name, model_class.table))
311+
if self.relation and self.relation != model_class.foreign_key_relations[
312+
self.name]:
313+
raise error.ValidationError('{} does not belong to {}'.format(
314+
self.relation.name, model_class.table))
315+
other_model_class = model_class.foreign_key_relations[
316+
self.name].constraint.referenced_table
317+
else:
318+
if self.name not in model_class.relations:
319+
raise error.ValidationError('{} is not a relation on {}'.format(
320+
self.name, model_class.table))
321+
if self.relation and self.relation != model_class.relations[self.name]:
322+
raise error.ValidationError('{} does not belong to {}'.format(
323+
self.relation.name, model_class.table))
324+
other_model_class = model_class.relations[self.name].destination
272325

273-
other_model_class = model_class.relations[self.name].destination
274326
for condition in self._conditions:
275327
condition._validate(other_model_class) # pylint: disable=protected-access
276328

@@ -629,8 +681,11 @@ def greater_than_or_equal_to(column: Union[field.Field, str],
629681
return ComparisonCondition('>=', column, value)
630682

631683

632-
def includes(relation: Union[relationship.Relationship, str],
633-
conditions: List[Condition] = None) -> IncludesCondition:
684+
def includes(relation: Union[relationship.Relationship,
685+
foreign_key_relationship.ForeignKeyRelationship,
686+
str],
687+
conditions: List[Condition] = None,
688+
foreign_key_relation: bool = False) -> IncludesCondition:
634689
"""Condition where the objects associated with a relationship are retrieved.
635690
636691
Note that the query formed by this call is not a JOIN, but instead a
@@ -639,14 +694,18 @@ def includes(relation: Union[relationship.Relationship, str],
639694
subquery may be included, but not all conditions may apply
640695
641696
Args:
642-
relation: Name of the relationship on the origin model or the Relationship
643-
on the origin model class used to retrievec associated objects
697+
relation: Name of the relationship on the origin model or the Relationship/
698+
ForeignKeyRelationship on the origin model class used to retrieve
699+
associated objects
644700
conditions: Conditions to apply on the subquery
701+
foreign_key_relation: True if the relation is a foreign key relation,
702+
False if it is a legacy relation (eg not enforced in Spanner)
645703
646704
Returns:
647705
A Condition subclass that will be used in the query
648706
"""
649-
return IncludesCondition(relation, conditions)
707+
return IncludesCondition(
708+
relation, conditions, foreign_key_relation)
650709

651710

652711
def in_list(column: Union[field.Field, str],

spanner_orm/foreign_key_relationship.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
"""Helps define a foreign key relationship between two models."""
1616

17-
from typing import Mapping
17+
from typing import Any, Mapping
1818

1919
import dataclasses
2020
from spanner_orm import registry
@@ -24,6 +24,7 @@
2424
class ForeignKeyRelationshipConstraint:
2525
columns: Mapping[str, str]
2626
referenced_table_name: str
27+
referenced_table: Any
2728

2829

2930
class ForeignKeyRelationship(object):
@@ -49,11 +50,29 @@ def __init__(self,
4950
def constraint(self) -> ForeignKeyRelationshipConstraint:
5051
return self._parse_constraint()
5152

53+
@property
54+
def ddl(self) -> str:
55+
referencing_columns_ddl = ', '.join(self.constraint.columns.keys())
56+
referenced_columns_ddl = ', '.join(self.constraint.columns.values())
57+
return (
58+
'CONSTRAINT {fk_name} FOREIGN KEY ({referencing_columns}) REFERENCES'
59+
' {referenced_table} ({referenced_columns})').format(
60+
fk_name=self.name,
61+
referencing_columns=referencing_columns_ddl,
62+
referenced_table=self.constraint.referenced_table_name,
63+
referenced_columns=referenced_columns_ddl,
64+
)
65+
5266
def _parse_constraint(self) -> ForeignKeyRelationshipConstraint:
5367
"""Return the relationship constraint."""
5468
referenced_table = registry.model_registry().get(
5569
self._referenced_table_name)
5670
return ForeignKeyRelationshipConstraint(
5771
self._columns,
5872
referenced_table.table,
73+
referenced_table,
5974
)
75+
76+
@property
77+
def single(self) -> bool:
78+
return True

spanner_orm/model.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,19 @@ def __new__(mcs, name: str, bases: Any, attrs: Dict[str, Any], **kwargs: Any):
7979

8080
def __getattr__(
8181
cls,
82-
name: str) -> Union[field.Field, relationship.Relationship, index.Index]:
82+
name: str) -> Union[
83+
field.Field,
84+
relationship.Relationship,
85+
foreign_key_relationship.ForeignKeyRelationship,
86+
index.Index]:
8387
# Unclear why pylint doesn't like this
8488
# pylint: disable=unsupported-membership-test
8589
if name in cls.fields:
8690
return cls.fields[name]
8791
elif name in cls.relations:
8892
return cls.relations[name]
93+
elif name in cls.foreign_key_relations:
94+
return cls.foreign_key_relations[name]
8995
elif name in cls.indexes:
9096
return cls.indexes[name]
9197
# pylint: enable=unsupported-membership-test
@@ -484,6 +490,10 @@ def __init__(self, values: Dict[str, Any], persisted: bool = False):
484490
if relation in values:
485491
self.__dict__[relation] = values[relation]
486492

493+
for foreign_key_relation in self._foreign_key_relations:
494+
if foreign_key_relation in values:
495+
self.__dict__[foreign_key_relation] = values[foreign_key_relation]
496+
487497
def __setattr__(self, name: str, value: Any) -> None:
488498
if name in self._relations:
489499
raise AttributeError(name)
@@ -513,6 +523,11 @@ def _primary_keys(self) -> List[str]:
513523
def _relations(self) -> Dict[str, relationship.Relationship]:
514524
return self._metaclass.relations
515525

526+
@property
527+
def _foreign_key_relations(
528+
self) -> Dict[str, foreign_key_relationship.ForeignKeyRelationship]:
529+
return self._metaclass.foreign_key_relations
530+
516531
@property
517532
def _table(self) -> str:
518533
return self._metaclass.table

0 commit comments

Comments
 (0)