Skip to content

Commit 711573e

Browse files
committed
Improve support for unions
1 parent 22bbded commit 711573e

File tree

4 files changed

+43
-7
lines changed

4 files changed

+43
-7
lines changed

openapi_python_client/parser/openapi.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ class Model:
236236
is_union: bool = False
237237

238238
@staticmethod
239-
def from_data(*, data: oai.Schema, name: str) -> Union[Model, ParseError]:
239+
def from_data(*, data: oai.Schema, name: str, ref: Reference = None) -> Union[Model, ParseError]:
240240
""" A single Model from its OAI data
241241
242242
Args:
@@ -249,7 +249,7 @@ def from_data(*, data: oai.Schema, name: str) -> Union[Model, ParseError]:
249249
optional_properties: List[Property] = []
250250
relative_imports: Set[str] = set()
251251

252-
ref = Reference.from_ref(data.title or name)
252+
ref = ref or Reference.from_ref(data.title or name)
253253

254254
inherits = None
255255
props = data.properties
@@ -261,7 +261,6 @@ def from_data(*, data: oai.Schema, name: str) -> Union[Model, ParseError]:
261261
relative_imports.add(f"from .{inherits.module_name} import {inherits.class_name}")
262262
props = data.allOf[1].properties
263263
elif not props:
264-
breakpoint()
265264
props = {}
266265
# raise AssertionError('Found empty property!')
267266

@@ -302,11 +301,21 @@ def from_data(*, data: oai.Schema, name: str) -> Union[MyUnion]:
302301
ref = Reference.from_ref(data.title or name)
303302
name = data.title or name
304303

305-
return MyUnion(
304+
model = MyUnion(
306305
reference=ref,
307-
joins=[Model.from_data(data=opt, name=f'{name}_{idx + 1}') for idx, opt in enumerate(data.anyOf)],
306+
joins=[
307+
Model.from_data(
308+
data=opt,
309+
name=f'{name}_{idx + 1}',
310+
ref=Reference(
311+
class_name=f'{ref.class_name}{idx + 1}',
312+
module_name=ref.module_name,
313+
),
314+
) for idx, opt in enumerate(data.anyOf)],
308315
relative_imports=[]
309316
)
317+
ALL_MODELS[ref] = model
318+
return model
310319

311320
@dataclass
312321
class Schemas:

openapi_python_client/parser/properties.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,10 @@ def get_type_string(self, no_optional: bool = False) -> str:
309309
return f'Literal[{value_as_string}]'
310310
return f'Literal[{value_as_string}, None]'
311311

312+
@property
313+
def repr_value(self) -> str:
314+
return repr(self.value)
315+
312316
def get_imports(self, *, prefix: str) -> Set[str]:
313317
"""
314318
Get a set of import strings that should be included when this property is used somewhere

openapi_python_client/templates/property_templates/literal_property.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
{% macro construct(property, source) %}
2+
if {{ source }} != {{ property.repr_value }}:
3+
raise ValueError('{{ "Wrong value for " + property.python_name + ": "}}' + {{ source }})
24
{{ property.python_name }} = {{ source }}
35
{% endmacro %}
46

openapi_python_client/templates/property_templates/ref_property.pyi

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,31 @@
1+
{% macro _construct(property, source) %}
2+
{% set model = property.reference.lookup() %}
3+
{% if model.is_union %}
4+
from . import {{ model.joins[0].reference.module_name }}
5+
6+
err = None
7+
for opt in [{% for submodel in model.joins %}{{ submodel.reference.module_name }}.{{ submodel.reference.class_name }}, {% endfor %}]:
8+
try:
9+
{{ property.python_name }} = opt.from_dict(cast(Dict[str, Any], {{ source }}))
10+
except Exception as exc:
11+
err = exc
12+
else:
13+
break
14+
else:
15+
raise err
16+
del err
17+
{% else %}
18+
{{ property.python_name }} = {{ property.reference.class_name }}.from_dict(cast(Dict[str, Any], {{ source }}))
19+
{% endif %}
20+
{% endmacro %}
21+
122
{% macro construct(property, source) %}
223
{% if property.required %}
3-
{{ property.python_name }} = {{ property.reference.class_name }}.from_dict({{ source }})
24+
{{ _construct(property, source) }}
425
{% else %}
526
{{ property.python_name }} = None
627
if {{ source }} is not None:
7-
{{ property.python_name }} = {{ property.reference.class_name }}.from_dict(cast(Dict[str, Any], {{ source }}))
28+
{{ _construct(property, source) | indent(4) }}
829
{% endif %}
930
{% endmacro %}
1031

0 commit comments

Comments
 (0)