2020
2121from spanner_orm import error
2222from spanner_orm import field
23+ from spanner_orm import foreign_key_relationship
2324from spanner_orm import index
2425from spanner_orm import relationship
2526
@@ -200,11 +201,38 @@ def _validate(self, model_class: Type[Any]) -> None:
200201class IncludesCondition (Condition ):
201202 """Used to include related model_classs via a relation in a Spanner query."""
202203
203- def __init__ (self ,
204- relation_or_name : Union [relationship .Relationship , str ],
205- conditions : List [Condition ] = None ):
204+ def __init__ (
205+ self ,
206+ relation_or_name : Union [relationship .Relationship ,
207+ foreign_key_relationship .ForeignKeyRelationship ,
208+ str ],
209+ conditions : List [Condition ] = None ,
210+ foreign_key_relation = False ,
211+ ):
212+ """Initializer.
213+
214+
215+ Args:
216+ relation: Name of the relationship on the origin model or the Relationship/
217+ ForeignKeyRelationship on the origin model class used to retrieve
218+ associated objects
219+ conditions: Conditions to apply on the subquery
220+ foreign_key_relation: True if the relation is a foreign key relation,
221+ False if it is a legacy relation (eg not enforced in Spanner)
222+ """
206223 super ().__init__ ()
224+ self .foreign_key_relation = foreign_key_relation
207225 if isinstance (relation_or_name , relationship .Relationship ):
226+ if foreign_key_relation :
227+ raise ValueError (
228+ 'Must pass foreign key relation if ' '`foreign_key_relation=True`.' )
229+ self .name = relation_or_name .name
230+ self .relation = relation_or_name
231+ elif isinstance (relation_or_name ,
232+ foreign_key_relationship .ForeignKeyRelationship ):
233+ if not foreign_key_relation :
234+ raise ValueError (
235+ 'Must pass legacy relation if `foreign_key_relation=False`.' )
208236 self .name = relation_or_name .name
209237 self .relation = relation_or_name
210238 else :
@@ -214,7 +242,10 @@ def __init__(self,
214242
215243 def bind (self , model_class : Type [Any ]) -> None :
216244 super ().bind (model_class )
217- self .relation = self .model_class .relations [self .name ]
245+ if self .foreign_key_relation :
246+ self .relation = self .model_class .foreign_key_relations [self .name ]
247+ else :
248+ self .relation = self .model_class .relations [self .name ]
218249
219250 @property
220251 def conditions (self ) -> List [Condition ]:
@@ -223,21 +254,31 @@ def conditions(self) -> List[Condition]:
223254 raise error .SpannerError (
224255 'Condition must be bound before conditions is called' )
225256 relation_conditions = []
226- for constraint in self .relation .constraints :
227- # This is backward from what you might imagine because the condition will
228- # be processed from the context of the destination model
229- relation_conditions .append (
230- ColumnsEqualCondition (constraint .destination_column ,
231- constraint .origin_class ,
232- constraint .origin_column ))
257+ if not self .foreign_key_relation :
258+ for constraint in self .relation .constraints :
259+ # This is backward from what you might imagine because the condition
260+ # will be processed from the context of the destination model.
261+ relation_conditions .append (
262+ ColumnsEqualCondition (constraint .destination_column ,
263+ constraint .origin_class ,
264+ constraint .origin_column ))
265+ else :
266+ for pair in self .relation .constraint .columns .items ():
267+ referencing_column , referenced_column = pair
268+ relation_conditions .append (
269+ ColumnsEqualCondition (referenced_column , self .model_class ,
270+ referencing_column ))
233271 return relation_conditions + self ._conditions
234272
235273 @property
236274 def destination (self ) -> Type [Any ]:
237275 if not self .relation :
238276 raise error .SpannerError (
239277 'Condition must be bound before destination is called' )
240- return self .relation .destination
278+ if self .foreign_key_relation :
279+ return self .relation .constraint .referenced_table
280+ else :
281+ return self .relation .destination
241282
242283 @property
243284 def relation_name (self ) -> str :
@@ -263,14 +304,25 @@ def _types(self) -> Dict[str, type_pb2.Type]:
263304 return {}
264305
265306 def _validate (self , model_class : Type [Any ]) -> None :
266- if self .name not in model_class .relations :
267- raise error .ValidationError ('{} is not a relation on {}' .format (
268- self .name , model_class .table ))
269- if self .relation and self .relation != model_class .relations [self .name ]:
270- raise error .ValidationError ('{} does not belong to {}' .format (
271- self .relation .name , model_class .table ))
307+ if self .foreign_key_relation :
308+ if self .name not in model_class .foreign_key_relations :
309+ raise error .ValidationError ('{} is not a relation on {}' .format (
310+ self .name , model_class .table ))
311+ if self .relation and self .relation != model_class .foreign_key_relations [
312+ self .name ]:
313+ raise error .ValidationError ('{} does not belong to {}' .format (
314+ self .relation .name , model_class .table ))
315+ other_model_class = model_class .foreign_key_relations [
316+ self .name ].constraint .referenced_table
317+ else :
318+ if self .name not in model_class .relations :
319+ raise error .ValidationError ('{} is not a relation on {}' .format (
320+ self .name , model_class .table ))
321+ if self .relation and self .relation != model_class .relations [self .name ]:
322+ raise error .ValidationError ('{} does not belong to {}' .format (
323+ self .relation .name , model_class .table ))
324+ other_model_class = model_class .relations [self .name ].destination
272325
273- other_model_class = model_class .relations [self .name ].destination
274326 for condition in self ._conditions :
275327 condition ._validate (other_model_class ) # pylint: disable=protected-access
276328
@@ -629,8 +681,11 @@ def greater_than_or_equal_to(column: Union[field.Field, str],
629681 return ComparisonCondition ('>=' , column , value )
630682
631683
632- def includes (relation : Union [relationship .Relationship , str ],
633- conditions : List [Condition ] = None ) -> IncludesCondition :
684+ def includes (relation : Union [relationship .Relationship ,
685+ foreign_key_relationship .ForeignKeyRelationship ,
686+ str ],
687+ conditions : List [Condition ] = None ,
688+ foreign_key_relation : bool = False ) -> IncludesCondition :
634689 """Condition where the objects associated with a relationship are retrieved.
635690
636691 Note that the query formed by this call is not a JOIN, but instead a
@@ -639,14 +694,18 @@ def includes(relation: Union[relationship.Relationship, str],
639694 subquery may be included, but not all conditions may apply
640695
641696 Args:
642- relation: Name of the relationship on the origin model or the Relationship
643- on the origin model class used to retrievec associated objects
697+ relation: Name of the relationship on the origin model or the Relationship/
698+ ForeignKeyRelationship on the origin model class used to retrieve
699+ associated objects
644700 conditions: Conditions to apply on the subquery
701+ foreign_key_relation: True if the relation is a foreign key relation,
702+ False if it is a legacy relation (eg not enforced in Spanner)
645703
646704 Returns:
647705 A Condition subclass that will be used in the query
648706 """
649- return IncludesCondition (relation , conditions )
707+ return IncludesCondition (
708+ relation , conditions , foreign_key_relation )
650709
651710
652711def in_list (column : Union [field .Field , str ],
0 commit comments