@@ -262,13 +262,6 @@ def _generate_struct_class_fields_for_reflection(self, data_type):
262262 def _generate_union_class_fields_for_reflection (self , data_type ):
263263 assert not data_type .super_type , 'Unsupported: Inheritance of unions'
264264
265- self .emit_line ('_field_names_ = {' )
266- with self .indent ():
267- for field in data_type .fields :
268- self .emit_line ("'{}'," .format (self .lang .format_variable (field .name )))
269- self .emit_line ('}' )
270- self .emit_empty_line ()
271-
272265 self .emit_line ('_fields_ = {' )
273266 with self .indent ():
274267 for field in data_type .fields :
@@ -443,21 +436,19 @@ def _generate_union_class(self, data_type):
443436 self ._generate_union_class_vars (data_type )
444437 self ._generate_union_class_fields_for_reflection (data_type )
445438 self ._generate_union_class_init (data_type )
446- self ._generate_union_class_symbol_creators (data_type )
439+ self ._generate_union_class_variant_creators (data_type )
447440 self ._generate_union_class_is_set (data_type )
448- self ._generate_union_class_properties (data_type )
441+ self ._generate_union_class_get_helpers (data_type )
449442 self ._generate_union_class_repr (data_type )
443+ self ._generate_union_class_symbol_creators (data_type )
450444
451445 def _generate_union_class_vars (self , data_type ):
452446 """
453- Each class has a class attribute for each field that is a primitive type.
454- Each class has a class attribute for each field that is a primitive type.
455- The attribute is a validator for the field.
447+ Each class has a class attribute for each field specifying its data type.
448+ If a catch all field exists, it's also specified here.
456449 """
457450 lineno = self .lineno
458451 for field in data_type .fields :
459- #if is_symbol_type(field.data_type):
460- # continue
461452 field_name = self .lang .format_variable (field .name )
462453 validator_name = self ._determine_validator_type (field .data_type )
463454 self .emit_line ('__{}_data_type = {}' .format (field_name ,
@@ -466,12 +457,19 @@ def _generate_union_class_vars(self, data_type):
466457 self .emit_line ('_catch_all_ = %r' % data_type .catch_all_field .name )
467458 else :
468459 self .emit_line ('_catch_all_ = None' )
460+
461+ for field in data_type .fields :
462+ if is_symbol_type (field .data_type ) or is_any_type (field .data_type ):
463+ field_name = self .lang .format_variable (field .name )
464+ self .emit_line ('# Attribute is overwritten below the class definition' )
465+ self .emit_line ('{} = None' .format (field_name ))
466+
469467 if lineno != self .lineno :
470468 self .emit_empty_line ()
471469
472470 def _generate_union_class_init (self , data_type ):
473471 """Generates the __init__ method for the class."""
474- self .emit_line ('def __init__(self):' )
472+ self .emit_line ('def __init__(self, tag, value=None ):' )
475473 with self .indent ():
476474 # Call the parent constructor if a super type exists
477475 if data_type .super_type :
@@ -480,29 +478,39 @@ def _generate_union_class_init(self, data_type):
480478
481479 for field in data_type .fields :
482480 field_var_name = self .lang .format_variable (field .name )
483- if not is_symbol_type (field .data_type ):
481+ if not is_symbol_type (field .data_type ) and not is_any_type ( field . data_type ) :
484482 self .emit_line ('self._{} = None' .format (field_var_name ))
485- self .emit_line ('self._tag = None' )
483+ self .emit_line ("assert tag in self._fields_, 'Invalid tag %r.' % tag" )
484+ self .emit_line ('if isinstance(self._fields_[tag], (dt.Any, dt.Symbol)):' )
485+ with self .indent ():
486+ self .emit_line (
487+ "assert value is None, 'Do not set a value for Symbol or Any variant.'" )
488+ self .emit_line ('else:' )
489+ with self .indent ():
490+ self .emit_line ('self._fields_[tag].validate(value)' )
491+ self .emit_line ('self._tag = tag' )
486492 self .emit_empty_line ()
487493
488- def _generate_union_class_symbol_creators (self , data_type ):
489- class_name = self .lang .format_class (data_type .name )
494+ def _generate_union_class_variant_creators (self , data_type ):
490495 for field in data_type .fields :
491- if is_symbol_type (field .data_type ) or is_any_type (field .data_type ):
496+ if not is_symbol_type (field .data_type ) and not is_any_type (field .data_type ):
492497 field_name = self .lang .format_method (field .name )
493498 self .emit_line ('@classmethod' )
494- self .emit_line ('def create_and_set_ {}(cls):' .format (field_name ))
499+ self .emit_line ('def {}(cls, val ):' .format (field_name ))
495500 with self .indent ():
496- self .emit_line ('"""' )
497- self .emit_wrapped_indented_lines (
498- ':rtype: {}' .format (class_name )
499- )
500- self .emit_line ('"""' )
501- self .emit_line ('c = cls()' )
502- self .emit_line ('c.set_{}()' .format (field_name ))
503- self .emit_line ('return c' )
501+ self .emit_line ('return cls({!r}, val)' .format (field_name ))
504502 self .emit_empty_line ()
505503
504+ def _generate_union_class_symbol_creators (self , data_type ):
505+ class_name = self .lang .format_class (data_type .name )
506+ lineno = self .lineno
507+ for field in data_type .fields :
508+ if is_symbol_type (field .data_type ) or is_any_type (field .data_type ):
509+ field_name = self .lang .format_method (field .name )
510+ self .emit_line ('{0}.{1} = {0}({1!r})' .format (class_name , field_name ))
511+ if lineno != self .lineno :
512+ self .emit_empty_line ()
513+
506514 def _generate_union_class_is_set (self , data_type ):
507515 for field in data_type .fields :
508516 field_name = self .lang .format_method (field .name )
@@ -511,41 +519,19 @@ def _generate_union_class_is_set(self, data_type):
511519 self .emit_line ('return self._tag == {!r}' .format (field_name ))
512520 self .emit_empty_line ()
513521
514- def _generate_union_class_properties (self , data_type ):
522+ def _generate_union_class_get_helpers (self , data_type ):
515523 for field in data_type .fields :
516524 field_name = self .lang .format_method (field .name )
517525
518- if is_symbol_type (field .data_type ) or is_any_type (field .data_type ):
519- self .emit_line ('def set_{}(self):' .format (field_name ))
526+ if not is_symbol_type (field .data_type ) and not is_any_type (field .data_type ):
527+ # generate getter for field
528+ self .emit_line ('def get_{}(self):' .format (field_name ))
520529 with self .indent ():
521- self .emit_line ('self._tag = {!r}' .format (field_name ))
522- self .emit_empty_line ()
523- continue
524-
525- # generate getter for field
526- self .emit_line ('@property' )
527- self .emit_line ('def {}(self):' .format (field_name ))
528- with self .indent ():
529- self .emit_line ('if not self.is_{}():' .format (field_name ))
530- with self .indent ():
531- self .emit_line ('raise AttributeError("tag {!r} not set")' .format (field_name ))
532- if is_symbol_type (field .data_type ):
533- self .emit_line ('return {!r}' .format (field_name ))
534- else :
530+ self .emit_line ('if not self.is_{}():' .format (field_name ))
531+ with self .indent ():
532+ self .emit_line ('raise AttributeError("tag {!r} not set")' .format (field_name ))
535533 self .emit_line ('return self._{}' .format (field_name ))
536- self .emit_empty_line ()
537-
538- # generate setter for field
539- self .emit_line ('@{}.setter' .format (field_name ))
540- self .emit_line ('def {}(self, val):' .format (field_name ))
541- with self .indent ():
542- if is_composite_type (field .data_type ):
543- self .emit_line ('self.__{}_data_type.validate_type_only(val)' .format (field_name ))
544- else :
545- self .emit_line ('val = self.__{}_data_type.validate(val)' .format (field_name ))
546- self .emit_line ('self._{} = val' .format (field_name ))
547- self .emit_line ('self._tag = {!r}' .format (field_name ))
548- self .emit_empty_line ()
534+ self .emit_empty_line ()
549535
550536 def _generate_union_class_repr (self , data_type ):
551537 # The special __repr__() function will return a string of the class
0 commit comments