Skip to content

Commit 8c9efb1

Browse files
committed
Changed how union types should be instantiated.
Variants of symbol/any type are accessible as class attributes that return an instance of the class with the tag set. Variants of all other types are accessible as class methods.
1 parent e26199b commit 8c9efb1

File tree

3 files changed

+86
-89
lines changed

3 files changed

+86
-89
lines changed

babelapi/generator/target/python/babel_serializers.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def _json_encode_helper(data_type, obj, needs_validation=True):
108108
if isinstance(field_data_type, (dt.Any, dt.Symbol)):
109109
return obj._tag
110110
else:
111-
val = getattr(obj, obj._tag)
111+
val = getattr(obj, '_'+obj._tag)
112112
return {obj._tag: _json_encode_helper(field_data_type, val, False)}
113113
else:
114114
return obj._tag
@@ -166,7 +166,7 @@ def _json_decode_helper(data_type, obj, strict, validate_primitives=True):
166166
setattr(o, name, v)
167167
data_type.validate(o)
168168
elif isinstance(data_type, dt.Union):
169-
o = data_type.definition()
169+
val = None # Symbols do not have values
170170
if isinstance(obj, six.string_types):
171171
# Variant is a symbol
172172
tag = obj
@@ -180,32 +180,28 @@ def _json_decode_helper(data_type, obj, strict, validate_primitives=True):
180180
tag = data_type.definition._catch_all_
181181
else:
182182
raise dt.ValidationError("unknown tag '%s'" % tag)
183-
o._tag = tag
184183
elif isinstance(obj, dict):
185184
# Variant is not a symbol
186185
if len(obj) != 1:
187186
raise dt.ValidationError('expected 1 key, got %s', len(obj))
188187
tag = list(obj)[0]
189-
val = obj[tag]
188+
raw_val = obj[tag]
190189
if tag in data_type.definition._fields_:
191190
val_data_type = data_type.definition._fields_[tag]
192-
if isinstance(val_data_type, dt.Any):
193-
o._tag = tag
194-
elif isinstance(val_data_type, dt.Symbol):
191+
if isinstance(val_data_type, dt.Symbol):
195192
raise dt.ValidationError("expected symbol '%s', got object"
196193
% tag)
197-
else:
198-
v = _json_decode_helper(val_data_type, val, strict, False)
199-
setattr(o, tag, v)
194+
elif not isinstance(val_data_type, dt.Any):
195+
val = _json_decode_helper(val_data_type, raw_val, strict, False)
200196
else:
201197
if not strict and data_type.definition._catch_all_:
202198
tag = data_type.definition._catch_all_
203-
o._tag = tag
204199
else:
205200
raise dt.ValidationError("unknown tag '%s'" % tag)
206201
else:
207202
raise dt.ValidationError("expected string or object, got %s"
208203
% dt.generic_type_name((obj)))
204+
o = data_type.definition(tag, val)
209205
elif isinstance(data_type, dt.List):
210206
if not isinstance(obj, list):
211207
raise dt.ValidationError(

babelapi/generator/target/python/python.babelg.py

Lines changed: 45 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -262,13 +262,6 @@ def _generate_struct_class_fields_for_reflection(self, data_type):
262262
def _generate_union_class_fields_for_reflection(self, data_type):
263263
assert not data_type.super_type, 'Unsupported: Inheritance of unions'
264264

265-
self.emit_line('_field_names_ = {')
266-
with self.indent():
267-
for field in data_type.fields:
268-
self.emit_line("'{}',".format(self.lang.format_variable(field.name)))
269-
self.emit_line('}')
270-
self.emit_empty_line()
271-
272265
self.emit_line('_fields_ = {')
273266
with self.indent():
274267
for field in data_type.fields:
@@ -443,21 +436,19 @@ def _generate_union_class(self, data_type):
443436
self._generate_union_class_vars(data_type)
444437
self._generate_union_class_fields_for_reflection(data_type)
445438
self._generate_union_class_init(data_type)
446-
self._generate_union_class_symbol_creators(data_type)
439+
self._generate_union_class_variant_creators(data_type)
447440
self._generate_union_class_is_set(data_type)
448-
self._generate_union_class_properties(data_type)
441+
self._generate_union_class_get_helpers(data_type)
449442
self._generate_union_class_repr(data_type)
443+
self._generate_union_class_symbol_creators(data_type)
450444

451445
def _generate_union_class_vars(self, data_type):
452446
"""
453-
Each class has a class attribute for each field that is a primitive type.
454-
Each class has a class attribute for each field that is a primitive type.
455-
The attribute is a validator for the field.
447+
Each class has a class attribute for each field specifying its data type.
448+
If a catch all field exists, it's also specified here.
456449
"""
457450
lineno = self.lineno
458451
for field in data_type.fields:
459-
#if is_symbol_type(field.data_type):
460-
# continue
461452
field_name = self.lang.format_variable(field.name)
462453
validator_name = self._determine_validator_type(field.data_type)
463454
self.emit_line('__{}_data_type = {}'.format(field_name,
@@ -466,12 +457,19 @@ def _generate_union_class_vars(self, data_type):
466457
self.emit_line('_catch_all_ = %r' % data_type.catch_all_field.name)
467458
else:
468459
self.emit_line('_catch_all_ = None')
460+
461+
for field in data_type.fields:
462+
if is_symbol_type(field.data_type) or is_any_type(field.data_type):
463+
field_name = self.lang.format_variable(field.name)
464+
self.emit_line('# Attribute is overwritten below the class definition')
465+
self.emit_line('{} = None'.format(field_name))
466+
469467
if lineno != self.lineno:
470468
self.emit_empty_line()
471469

472470
def _generate_union_class_init(self, data_type):
473471
"""Generates the __init__ method for the class."""
474-
self.emit_line('def __init__(self):')
472+
self.emit_line('def __init__(self, tag, value=None):')
475473
with self.indent():
476474
# Call the parent constructor if a super type exists
477475
if data_type.super_type:
@@ -480,29 +478,39 @@ def _generate_union_class_init(self, data_type):
480478

481479
for field in data_type.fields:
482480
field_var_name = self.lang.format_variable(field.name)
483-
if not is_symbol_type(field.data_type):
481+
if not is_symbol_type(field.data_type) and not is_any_type(field.data_type):
484482
self.emit_line('self._{} = None'.format(field_var_name))
485-
self.emit_line('self._tag = None')
483+
self.emit_line("assert tag in self._fields_, 'Invalid tag %r.' % tag")
484+
self.emit_line('if isinstance(self._fields_[tag], (dt.Any, dt.Symbol)):')
485+
with self.indent():
486+
self.emit_line(
487+
"assert value is None, 'Do not set a value for Symbol or Any variant.'")
488+
self.emit_line('else:')
489+
with self.indent():
490+
self.emit_line('self._fields_[tag].validate(value)')
491+
self.emit_line('self._tag = tag')
486492
self.emit_empty_line()
487493

488-
def _generate_union_class_symbol_creators(self, data_type):
489-
class_name = self.lang.format_class(data_type.name)
494+
def _generate_union_class_variant_creators(self, data_type):
490495
for field in data_type.fields:
491-
if is_symbol_type(field.data_type) or is_any_type(field.data_type):
496+
if not is_symbol_type(field.data_type) and not is_any_type(field.data_type):
492497
field_name = self.lang.format_method(field.name)
493498
self.emit_line('@classmethod')
494-
self.emit_line('def create_and_set_{}(cls):'.format(field_name))
499+
self.emit_line('def {}(cls, val):'.format(field_name))
495500
with self.indent():
496-
self.emit_line('"""')
497-
self.emit_wrapped_indented_lines(
498-
':rtype: {}'.format(class_name)
499-
)
500-
self.emit_line('"""')
501-
self.emit_line('c = cls()')
502-
self.emit_line('c.set_{}()'.format(field_name))
503-
self.emit_line('return c')
501+
self.emit_line('return cls({!r}, val)'.format(field_name))
504502
self.emit_empty_line()
505503

504+
def _generate_union_class_symbol_creators(self, data_type):
505+
class_name = self.lang.format_class(data_type.name)
506+
lineno = self.lineno
507+
for field in data_type.fields:
508+
if is_symbol_type(field.data_type) or is_any_type(field.data_type):
509+
field_name = self.lang.format_method(field.name)
510+
self.emit_line('{0}.{1} = {0}({1!r})'.format(class_name, field_name))
511+
if lineno != self.lineno:
512+
self.emit_empty_line()
513+
506514
def _generate_union_class_is_set(self, data_type):
507515
for field in data_type.fields:
508516
field_name = self.lang.format_method(field.name)
@@ -511,41 +519,19 @@ def _generate_union_class_is_set(self, data_type):
511519
self.emit_line('return self._tag == {!r}'.format(field_name))
512520
self.emit_empty_line()
513521

514-
def _generate_union_class_properties(self, data_type):
522+
def _generate_union_class_get_helpers(self, data_type):
515523
for field in data_type.fields:
516524
field_name = self.lang.format_method(field.name)
517525

518-
if is_symbol_type(field.data_type) or is_any_type(field.data_type):
519-
self.emit_line('def set_{}(self):'.format(field_name))
526+
if not is_symbol_type(field.data_type) and not is_any_type(field.data_type):
527+
# generate getter for field
528+
self.emit_line('def get_{}(self):'.format(field_name))
520529
with self.indent():
521-
self.emit_line('self._tag = {!r}'.format(field_name))
522-
self.emit_empty_line()
523-
continue
524-
525-
# generate getter for field
526-
self.emit_line('@property')
527-
self.emit_line('def {}(self):'.format(field_name))
528-
with self.indent():
529-
self.emit_line('if not self.is_{}():'.format(field_name))
530-
with self.indent():
531-
self.emit_line('raise AttributeError("tag {!r} not set")'.format(field_name))
532-
if is_symbol_type(field.data_type):
533-
self.emit_line('return {!r}'.format(field_name))
534-
else:
530+
self.emit_line('if not self.is_{}():'.format(field_name))
531+
with self.indent():
532+
self.emit_line('raise AttributeError("tag {!r} not set")'.format(field_name))
535533
self.emit_line('return self._{}'.format(field_name))
536-
self.emit_empty_line()
537-
538-
# generate setter for field
539-
self.emit_line('@{}.setter'.format(field_name))
540-
self.emit_line('def {}(self, val):'.format(field_name))
541-
with self.indent():
542-
if is_composite_type(field.data_type):
543-
self.emit_line('self.__{}_data_type.validate_type_only(val)'.format(field_name))
544-
else:
545-
self.emit_line('val = self.__{}_data_type.validate(val)'.format(field_name))
546-
self.emit_line('self._{} = val'.format(field_name))
547-
self.emit_line('self._tag = {!r}'.format(field_name))
548-
self.emit_empty_line()
534+
self.emit_empty_line()
549535

550536
def _generate_union_class_repr(self, data_type):
551537
# The special __repr__() function will return a string of the class

test/test_python_gen.py

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -113,33 +113,40 @@ class U(object):
113113
'b': dt.Symbol(),
114114
'c': dt.Struct(S),
115115
'd': dt.List(dt.Int64())}
116+
_tag = None
117+
def __init__(self, tag, value=None):
118+
self._tag = tag
119+
setattr(self, '_' + tag, value)
120+
def get_a(self):
121+
return self._a
122+
def get_c(self):
123+
return self._c
124+
def get_d(self):
125+
return self._d
126+
127+
U.b = U('b')
116128

117129
# Test primitive variant
118-
u = U()
119-
u._tag = 'a'
120-
u.a = 64
130+
u = U('a', 64)
121131
self.assertEqual(json_encode(dt.Union(U), u), json.dumps({'a': 64}))
122132

123133
# Test symbol variant
124-
u = U()
125-
u._tag = 'b'
134+
u = U('b')
126135
self.assertEqual(json_encode(dt.Union(U), u), json.dumps('b'))
127136

128137
# Test struct variant
129-
u = U()
130-
u._tag = 'c'
131-
u.c = S()
132-
u.c.f = 'hello'
138+
c = S()
139+
c.f = 'hello'
140+
u = U('c', c)
133141
self.assertEqual(json_encode(dt.Union(U), u), json.dumps({'c': {'f': 'hello'}}))
134142

135143
# Test list variant
136-
u = U()
137-
u._tag = 'd'
138-
u.d = [1, 2, 3, 'a']
144+
u = U('d', [1, 2, 3, 'a'])
139145
# lists should be re-validated during serialization
140146
self.assertRaises(dt.ValidationError, lambda: json_encode(dt.Union(U), u))
141-
u.d = [1, 2, 3, 4]
142-
self.assertEqual(json_encode(dt.Union(U), u), json.dumps({'d': u.d}))
147+
l = [1, 2, 3, 4]
148+
u = U('d', [1, 2, 3, 4])
149+
self.assertEqual(json_encode(dt.Union(U), u), json.dumps({'d': l}))
143150

144151
def test_json_decoder(self):
145152
self.assertEqual(json_decode(dt.String(), json.dumps('abc')), 'abc')
@@ -175,12 +182,20 @@ class U(object):
175182
'd': dt.List(dt.Int64())}
176183
_catch_all_ = 'b'
177184
_tag = None
178-
def set_b(self):
179-
self._tag = 'b'
185+
def __init__(self, tag, value=None):
186+
self._tag = tag
187+
setattr(self, '_' + tag, value)
188+
def get_a(self):
189+
return self._a
190+
def get_c(self):
191+
return self._c
192+
def get_d(self):
193+
return self._d
194+
U.b = U('b')
180195

181196
# Test primitive variant
182197
u = json_decode(dt.Union(U), json.dumps({'a': 64}))
183-
self.assertEqual(u.a, 64)
198+
self.assertEqual(u.get_a(), 64)
184199

185200
# Test symbol variant
186201
u = json_decode(dt.Union(U), json.dumps('b'))
@@ -190,14 +205,14 @@ def set_b(self):
190205

191206
# Test struct variant
192207
u = json_decode(dt.Union(U), json.dumps({'c': {'f': 'hello'}}))
193-
self.assertEqual(u.c.f, 'hello')
208+
self.assertEqual(u.get_c().f, 'hello')
194209
self.assertRaises(dt.ValidationError,
195210
lambda: json_decode(dt.Union(U), json.dumps({'c': [1,2,3]})))
196211

197212
# Test list variant
198213
l = [1, 2, 3, 4]
199214
u = json_decode(dt.Union(U), json.dumps({'d': l}))
200-
self.assertEqual(u.d, l)
215+
self.assertEqual(u.get_d(), l)
201216

202217
# Raises if unknown tag
203218
self.assertRaises(dt.ValidationError, lambda: json_decode(dt.Union(U), json.dumps('z')))

0 commit comments

Comments
 (0)