88from __future__ import absolute_import , division , print_function , unicode_literals
99
1010from abc import ABCMeta , abstractmethod
11- from collections import OrderedDict
11+ from collections import OrderedDict , deque
1212import copy
1313import datetime
1414import math
@@ -728,38 +728,47 @@ def set_enumerated_subtypes(self, subtype_fields, is_catch_all):
728728 self ._is_catch_all = is_catch_all
729729 self ._enumerated_subtypes = []
730730
731+ if self .parent_type :
732+ raise InvalidSpec (
733+ "'%s' enumerates subtypes so it cannot extend another struct."
734+ % self .name , self ._token .lineno , self ._token .path )
735+
731736 # Require that if this struct enumerates subtypes, its parent (and thus
732737 # the entire hierarchy above this struct) does as well.
733738 if self .parent_type and not self .parent_type .has_enumerated_subtypes ():
734739 raise InvalidSpec (
735740 "'%s' cannot enumerate subtypes if parent '%s' does not." %
736- (self .name , self .parent_type .name ), self ._token .lineno )
741+ (self .name , self .parent_type .name ),
742+ self ._token .lineno , self ._token .path )
737743
738744 enumerated_subtype_names = set () # Set[str]
739745 for subtype_field in subtype_fields :
746+ path = subtype_field ._token .path
740747 lineno = subtype_field ._token .lineno
741748
742749 # Require that a subtype only has a single type tag.
743750 if subtype_field .data_type .name in enumerated_subtype_names :
744751 raise InvalidSpec (
745752 "Subtype '%s' can only be specified once." %
746- subtype_field .data_type .name , lineno )
753+ subtype_field .data_type .name , lineno , path )
747754
748755 # Require that a subtype has this struct as its parent.
749756 if subtype_field .data_type .parent_type != self :
750757 raise InvalidSpec (
751758 "'%s' is not a subtype of '%s'." %
752- (subtype_field .data_type .name , self .name ), lineno )
759+ (subtype_field .data_type .name , self .name ), lineno , path )
753760
754761 # Check for subtype tags that conflict with this struct's
755762 # non-inherited fields.
756763 if subtype_field .name in self ._fields_by_name :
757764 # Since the union definition comes first, use its line number
758765 # as the source of the field's original declaration.
766+ orig_field = self ._fields_by_name [subtype_field .name ]
759767 raise InvalidSpec (
760768 "Field '%s' already defined on line %d." %
761769 (subtype_field .name , lineno ),
762- self ._fields_by_name [subtype_field .name ]._token .lineno )
770+ orig_field ._token .lineno ,
771+ orig_field ._token .path )
763772
764773 # Walk up parent tree hierarchy to ensure no field conflicts.
765774 # Checks for conflicts with subtype tags and regular fields.
@@ -768,10 +777,10 @@ def set_enumerated_subtypes(self, subtype_fields, is_catch_all):
768777 if subtype_field .name in cur_type ._fields_by_name :
769778 orig_field = cur_type ._fields_by_name [subtype_field .name ]
770779 raise InvalidSpec (
771- "Field '%s' already defined in parent '%s' on line %d ."
780+ "Field '%s' already defined in parent '%s' (%s:%d) ."
772781 % (subtype_field .name , cur_type .name ,
773- orig_field ._token .lineno ),
774- lineno )
782+ orig_field ._token .path , orig_field . _token . lineno ),
783+ lineno , path )
775784 cur_type = cur_type .parent_type
776785
777786 # Note the discrepancy between `fields` which contains only the
@@ -795,9 +804,7 @@ def get_all_subtypes_with_tags(self):
795804 """
796805 Unlike other enumerated-subtypes-related functionality, this method
797806 returns not just direct subtypes, but all subtypes of this struct. The
798- tag of each subtype is the tag of the enumerated subtype from which it
799- descended, which means that it's likely that subtypes will share the
800- same tag.
807+ tag of each subtype is the list of tags from which the type descends.
801808
802809 This method only applies to structs that enumerate subtypes.
803810
@@ -806,18 +813,44 @@ def get_all_subtypes_with_tags(self):
806813 in the serialized format.
807814
808815 Returns:
809- List[Tuple[String, Struct]]
816+ List[Tuple[List[ String] , Struct]]
810817 """
811818 assert self .has_enumerated_subtypes (), 'Enumerated subtypes not set.'
812- subtypes_with_tags = [] # List[Tuple[String, Struct]]
813- for subtype_field in self .get_enumerated_subtypes ():
814- subtypes_with_tags .append (
815- (subtype_field .name , subtype_field .data_type ))
816- for subtype in subtype_field .data_type .subtypes :
817- subtypes_with_tags .append (
818- (subtype_field .name , subtype ))
819+ subtypes_with_tags = [] # List[Tuple[List[String], Struct]]
820+ fifo = deque ([subtype_field .data_type
821+ for subtype_field in self .get_enumerated_subtypes ()])
822+ # Traverse down the hierarchy registering subtypes as they're found.
823+ while fifo :
824+ data_type = fifo .popleft ()
825+ subtypes_with_tags .append ((data_type ._get_subtype_tags (), data_type ))
826+ if data_type .has_enumerated_subtypes ():
827+ for subtype_field in data_type .get_enumerated_subtypes ():
828+ fifo .append (subtype_field .data_type )
819829 return subtypes_with_tags
820830
831+ def _get_subtype_tags (self ):
832+ """
833+ Returns a list of type tags that refer to this type starting from the
834+ base of the struct hierarchy.
835+ """
836+ assert self .is_member_of_enumerated_subtypes_tree (), \
837+ 'Not a part of a subtypes tree.'
838+ cur = self .parent_type
839+ cur_dt = self
840+ tags = []
841+ while cur :
842+ assert cur .has_enumerated_subtypes ()
843+ for subtype_field in cur .get_enumerated_subtypes ():
844+ if subtype_field .data_type is cur_dt :
845+ tags .append (subtype_field .name )
846+ break
847+ else :
848+ assert False , 'Could not find?!'
849+ cur_dt = cur
850+ cur = cur .parent_type
851+ tags .reverse ()
852+ return tuple (tags )
853+
821854 def _add_example (self , example ):
822855 """Adds a "raw example" for this type.
823856
@@ -1009,7 +1042,7 @@ def _compute_example_flat_helper(self, label):
10091042
10101043 return Example (example .label , example .text , ex_val )
10111044
1012- def _compute_example_enumerated_subtypes (self , label , root = True ):
1045+ def _compute_example_enumerated_subtypes (self , label ):
10131046 """
10141047 Analogous to :meth:`_compute_example_flat_helper` but for structs with
10151048 enumerated subtypes.
@@ -1018,36 +1051,15 @@ def _compute_example_enumerated_subtypes(self, label, root=True):
10181051
10191052 example = self ._raw_examples [label ]
10201053
1021- if self .has_enumerated_subtypes ():
1022- tag , (ref , data_type ) = list (example .value .items ())[0 ]
1023- if not data_type ._has_example (ref .label ):
1024- raise InvalidSpec (
1025- "Reference to example for '%s' with label '%s' does not "
1026- "exist." % (data_type .name , ref .label ),
1027- ref .lineno , ref .path )
1028- flat_example , ex_sub_value = \
1029- data_type ._compute_example_enumerated_subtypes (ref .label , False )
1030- ex_value = OrderedDict ()
1031- ex_value [tag ] = ex_sub_value
1032- for field in self .fields :
1033- if field .name in flat_example .value :
1034- ex_value [field .name ] = flat_example .value [field .name ]
1035- del flat_example .value [field .name ]
1036- if root :
1037- return Example (label , example .text , ex_value )
1038- else :
1039- return flat_example , ex_value
1040- else :
1041- # If we're at a leaf of a subtypes tree, then compute the example
1042- # as if it were a flat struct. The caller is responsible for moving
1043- # fields into different nesting levels based on the subtypes tree.
1044- flat_example = self ._compute_example_flat_helper (label )
1045- ex_value = OrderedDict ()
1046- for field in self .fields :
1047- if field .name in flat_example .value :
1048- ex_value [field .name ] = flat_example .value [field .name ]
1049- del flat_example .value [field .name ]
1050- return flat_example , ex_value
1054+ tag , (ref , data_type ) = list (example .value .items ())[0 ]
1055+ if not data_type ._has_example (ref .label ):
1056+ raise InvalidSpec (
1057+ "Reference to example for '%s' with label '%s' does not "
1058+ "exist." % (data_type .name , ref .label ),
1059+ ref .lineno , ref .path )
1060+ flat_example = data_type ._compute_example_flat_helper (ref .label )
1061+ flat_example .value ['.tag' ] = tag
1062+ return flat_example
10511063
10521064 def __repr__ (self ):
10531065 return 'Struct(%r, %r)' % (self .name , self .fields )
@@ -1224,11 +1236,12 @@ def _compute_example(self, label):
12241236 # Find the field referenced by this tag.
12251237 for field in self .all_fields :
12261238 if tag == field .name :
1227- break
1239+ break
12281240 else :
12291241 raise AssertionError ('Unknown tag %r' % tag )
12301242
1231- dt , _ = get_underlying_type (field .data_type )
1243+ orig_dt , _ = get_underlying_type (field .data_type )
1244+ dt = orig_dt
12321245 list_nesting_count = 0
12331246 while is_list_type (dt ):
12341247 dt = dt .data_type
@@ -1246,7 +1259,11 @@ def _compute_example(self, label):
12461259 while list_nesting_count > 0 :
12471260 ex_val = [ex_val ]
12481261 list_nesting_count -= 1
1249- example_copy .value = {tag : ex_val }
1262+ if isinstance (orig_dt , Struct ) and not dt .has_enumerated_subtypes ():
1263+ ex_val .update ({'.tag' : tag })
1264+ example_copy .value = ex_val
1265+ else :
1266+ example_copy .value = {'.tag' : tag , tag : ex_val }
12501267
12511268 return example_copy
12521269
0 commit comments