Skip to content

Commit 2aabec5

Browse files
authored
fix: tweak oneof detection (#505)
Oneof detection and assignment to fields is tricky. This patch fixes detection of oneof fields, fixes uses in generated clients and tweaks generated tests to use them correctly.
1 parent 15fc9a0 commit 2aabec5

File tree

5 files changed

+124
-6
lines changed

5 files changed

+124
-6
lines changed

packages/gapic-generator/gapic/schema/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,7 @@ def _get_fields(self,
615615
# `_load_message` method.
616616
answer: Dict[str, wrappers.Field] = collections.OrderedDict()
617617
for i, field_pb in enumerate(field_pbs):
618-
is_oneof = oneofs and field_pb.oneof_index > 0
618+
is_oneof = oneofs and field_pb.HasField('oneof_index')
619619
oneof_name = nth(
620620
(oneofs or {}).keys(),
621621
field_pb.oneof_index

packages/gapic-generator/gapic/schema/wrappers.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,15 @@ def __hash__(self):
239239
# Identity is sufficiently unambiguous.
240240
return hash(self.ident)
241241

242+
def oneof_fields(self, include_optional=False):
243+
oneof_fields = collections.defaultdict(list)
244+
for field in self.fields.values():
245+
# Only include proto3 optional oneofs if explicitly looked for.
246+
if field.oneof and not field.proto3_optional or include_optional:
247+
oneof_fields[field.oneof].append(field)
248+
249+
return oneof_fields
250+
242251
@utils.cached_property
243252
def field_types(self) -> Sequence[Union['MessageType', 'EnumType']]:
244253
answer = tuple(
@@ -583,6 +592,15 @@ def client_output(self):
583592
def client_output_async(self):
584593
return self._client_output(enable_asyncio=True)
585594

595+
def flattened_oneof_fields(self, include_optional=False):
596+
oneof_fields = collections.defaultdict(list)
597+
for field in self.flattened_fields.values():
598+
# Only include proto3 optional oneofs if explicitly looked for.
599+
if field.oneof and not field.proto3_optional or include_optional:
600+
oneof_fields[field.oneof].append(field)
601+
602+
return oneof_fields
603+
586604
def _client_output(self, enable_asyncio: bool):
587605
"""Return the output from the client layer.
588606
@@ -685,6 +703,10 @@ def filter_fields(sig: str) -> Iterable[Tuple[str, Field]]:
685703

686704
return answer
687705

706+
@utils.cached_property
707+
def flattened_field_to_key(self):
708+
return {field.name: key for key, field in self.flattened_fields.items()}
709+
688710
@utils.cached_property
689711
def legacy_flattened_fields(self) -> Mapping[str, Field]:
690712
"""Return the legacy flattening interface: top level fields only,

packages/gapic-generator/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -288,9 +288,15 @@ def test_{{ method.name|snake_case }}(transport: str = 'grpc'):
288288
call.return_value = iter([{{ method.output.ident }}()])
289289
{% else -%}
290290
call.return_value = {{ method.output.ident }}(
291-
{%- for field in method.output.fields.values() | rejectattr('message')%}{% if not (field.oneof and not field.proto3_optional) %}
291+
{%- for field in method.output.fields.values() | rejectattr('message')%}{% if not field.oneof or field.proto3_optional %}
292292
{{ field.name }}={{ field.mock_value }},
293293
{% endif %}{%- endfor %}
294+
{#- This is a hack to only pick one field #}
295+
{%- for oneof_fields in method.output.oneof_fields().values() %}
296+
{% with field = oneof_fields[0] %}
297+
{{ field.name }}={{ field.mock_value }},
298+
{%- endwith %}
299+
{%- endfor %}
294300
)
295301
{% endif -%}
296302
{% if method.client_streaming %}
@@ -567,9 +573,15 @@ def test_{{ method.name|snake_case }}_flattened():
567573
# request object values.
568574
assert len(call.mock_calls) == 1
569575
_, args, _ = call.mock_calls[0]
570-
{% for key, field in method.flattened_fields.items() -%}
576+
{% for key, field in method.flattened_fields.items() -%}{%- if not field.oneof or field.proto3_optional %}
571577
assert args[0].{{ key }} == {{ field.mock_value }}
572-
{% endfor %}
578+
{% endif %}{% endfor %}
579+
{%- for oneofs in method.flattened_oneof_fields().values() %}
580+
{%- with field = oneofs[-1] %}
581+
assert args[0].{{ method.flattened_field_to_key[field.name] }} == {{ field.mock_value }}
582+
{%- endwith %}
583+
{%- endfor %}
584+
573585

574586

575587
def test_{{ method.name|snake_case }}_flattened_error():
@@ -640,9 +652,14 @@ async def test_{{ method.name|snake_case }}_flattened_async():
640652
# request object values.
641653
assert len(call.mock_calls)
642654
_, args, _ = call.mock_calls[0]
643-
{% for key, field in method.flattened_fields.items() -%}
655+
{% for key, field in method.flattened_fields.items() -%}{%- if not field.oneof or field.proto3_optional %}
644656
assert args[0].{{ key }} == {{ field.mock_value }}
645-
{% endfor %}
657+
{% endif %}{% endfor %}
658+
{%- for oneofs in method.flattened_oneof_fields().values() %}
659+
{%- with field = oneofs[-1] %}
660+
assert args[0].{{ method.flattened_field_to_key[field.name] }} == {{ field.mock_value }}
661+
{%- endwith %}
662+
{%- endfor %}
646663

647664

648665
@pytest.mark.asyncio

packages/gapic-generator/tests/unit/schema/wrappers/test_message.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,3 +235,26 @@ def test_field_map():
235235
entry_field = make_field('foos', message=entry_msg, repeated=True)
236236
assert entry_msg.map
237237
assert entry_field.map
238+
239+
240+
def test_oneof_fields():
241+
mass_kg = make_field(name="mass_kg", oneof="mass", type=5)
242+
mass_lbs = make_field(name="mass_lbs", oneof="mass", type=5)
243+
length_m = make_field(name="length_m", oneof="length", type=5)
244+
length_f = make_field(name="length_f", oneof="length", type=5)
245+
color = make_field(name="color", type=5)
246+
request = make_message(
247+
name="CreateMolluscReuqest",
248+
fields=(
249+
mass_kg,
250+
mass_lbs,
251+
length_m,
252+
length_f,
253+
color,
254+
),
255+
)
256+
actual_oneofs = request.oneof_fields()
257+
expected_oneofs = {
258+
"mass": [mass_kg, mass_lbs],
259+
"length": [length_m, length_f],
260+
}

packages/gapic-generator/tests/unit/schema/wrappers/test_method.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,3 +364,59 @@ def test_method_legacy_flattened_fields():
364364
])
365365

366366
assert method.legacy_flattened_fields == expected
367+
368+
369+
def test_flattened_oneof_fields():
370+
mass_kg = make_field(name="mass_kg", oneof="mass", type=5)
371+
mass_lbs = make_field(name="mass_lbs", oneof="mass", type=5)
372+
373+
length_m = make_field(name="length_m", oneof="length", type=5)
374+
length_f = make_field(name="length_f", oneof="length", type=5)
375+
376+
color = make_field(name="color", type=5)
377+
mantle = make_field(
378+
name="mantle",
379+
message=make_message(
380+
name="Mantle",
381+
fields=(
382+
make_field(name="color", type=5),
383+
mass_kg,
384+
mass_lbs,
385+
),
386+
),
387+
)
388+
request = make_message(
389+
name="CreateMolluscReuqest",
390+
fields=(
391+
length_m,
392+
length_f,
393+
color,
394+
mantle,
395+
),
396+
)
397+
method = make_method(
398+
name="CreateMollusc",
399+
input_message=request,
400+
signatures=[
401+
"length_m,",
402+
"length_f,",
403+
"mantle.mass_kg,",
404+
"mantle.mass_lbs,",
405+
"color",
406+
]
407+
)
408+
409+
expected = {"mass": [mass_kg, mass_lbs], "length": [length_m, length_f]}
410+
actual = method.flattened_oneof_fields()
411+
assert expected == actual
412+
413+
# Check this method too becasue the setup is a lot of work.
414+
expected = {
415+
"color": "color",
416+
"length_m": "length_m",
417+
"length_f": "length_f",
418+
"mass_kg": "mantle.mass_kg",
419+
"mass_lbs": "mantle.mass_lbs",
420+
}
421+
actual = method.flattened_field_to_key
422+
assert expected == actual

0 commit comments

Comments
 (0)