@@ -104,6 +104,43 @@ def test_json_encoder(self):
104104 b = '\xff ' * 5
105105 self .assertEqual (JsonEncoder .encode (dt .Binary (), b ), json .dumps (base64 .b64encode (b )))
106106
107+ def test_json_encoder_union (self ):
108+ class S (object ):
109+ _field_names_ = {'f' }
110+ _fields_ = [('f' , dt .String ())]
111+ class U (object ):
112+ _fields_ = {'a' : dt .Int64 (),
113+ 'b' : dt .Symbol (),
114+ 'c' : dt .Struct (S ),
115+ 'd' : dt .List (dt .Int64 ())}
116+
117+ # Test primitive variant
118+ u = U ()
119+ u ._tag = 'a'
120+ u .a = 64
121+ self .assertEqual (JsonEncoder .encode (dt .Union (U ), u ), json .dumps ({'a' : 64 }))
122+
123+ # Test symbol variant
124+ u = U ()
125+ u ._tag = 'b'
126+ self .assertEqual (JsonEncoder .encode (dt .Union (U ), u ), json .dumps ('b' ))
127+
128+ # Test struct variant
129+ u = U ()
130+ u ._tag = 'c'
131+ u .c = S ()
132+ u .c .f = 'hello'
133+ self .assertEqual (JsonEncoder .encode (dt .Union (U ), u ), json .dumps ({'c' : {'f' : 'hello' }}))
134+
135+ # Test list variant
136+ u = U ()
137+ u ._tag = 'd'
138+ u .d = [1 , 2 , 3 , 'a' ]
139+ # lists should be re-validated during serialization
140+ self .assertRaises (dt .ValidationError , lambda : JsonEncoder .encode (dt .Union (U ), u ))
141+ u .d = [1 , 2 , 3 , 4 ]
142+ self .assertEqual (JsonEncoder .encode (dt .Union (U ), u ), json .dumps ({'d' : u .d }))
143+
107144 def test_json_decoder (self ):
108145 self .assertEqual (JsonDecoder .decode (dt .String (), json .dumps ('abc' )), 'abc' )
109146 self .assertEqual (JsonDecoder .decode (dt .UInt32 (), json .dumps (123 )), 123 )
@@ -115,3 +152,36 @@ def test_json_decoder(self):
115152 now )
116153 b = '\xff ' * 5
117154 self .assertEqual (JsonDecoder .decode (dt .Binary (), json .dumps (base64 .b64encode (b ))), b )
155+
156+ def test_json_decoder_union (self ):
157+ class S (object ):
158+ _field_names_ = {'f' }
159+ _fields_ = [('f' , dt .String ())]
160+ class U (object ):
161+ _fields_ = {'a' : dt .Int64 (),
162+ 'b' : dt .Symbol (),
163+ 'c' : dt .Struct (S ),
164+ 'd' : dt .List (dt .Int64 ())}
165+ _tag = None
166+ def set_b (self ):
167+ self ._tag = 'b'
168+
169+ # Test primitive variant
170+ u = JsonDecoder .decode (dt .Union (U ), json .dumps ({'a' : 64 }))
171+ self .assertEqual (u .a , 64 )
172+
173+ # Test symbol variant
174+ u = JsonDecoder .decode (dt .Union (U ), json .dumps ('b' ))
175+ self .assertEqual (u ._tag , 'b' )
176+
177+ # Test struct variant
178+ u = JsonDecoder .decode (dt .Union (U ), json .dumps ({'c' : {'f' : 'hello' }}))
179+ self .assertEqual (u .c .f , 'hello' )
180+
181+ # Test list variant
182+ l = [1 , 2 , 3 , 4 ]
183+ u = JsonDecoder .decode (dt .Union (U ), json .dumps ({'d' : l }))
184+ self .assertEqual (u .d , l )
185+
186+ # Raises if unknown tag
187+ self .assertRaises (dt .ValidationError , lambda : JsonDecoder .decode (dt .Union (U ), json .dumps ('z' )))
0 commit comments