Skip to content

Commit e1b4ade

Browse files
authored
AdditionalProperties and oneOf in one Schema (#765)
* demonstrate error * updated templates * fix template * test
1 parent 1b63982 commit e1b4ade

File tree

7 files changed

+319
-1
lines changed

7 files changed

+319
-1
lines changed

internal/test/components/components.gen.go

Lines changed: 170 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/test/components/components.yaml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,23 @@ components:
325325
- oneOf:
326326
- $ref: '#/components/schemas/OneOfVariant3'
327327
- $ref: '#/components/schemas/OneOfVariant4'
328+
OneOfObject13:
329+
description: oneOf with fixed discriminator and other fields allowed
330+
type: object
331+
properties:
332+
type:
333+
type: string
334+
oneOf:
335+
- $ref: '#/components/schemas/OneOfVariant1'
336+
- $ref: '#/components/schemas/OneOfVariant6'
337+
discriminator:
338+
propertyName: type
339+
mapping:
340+
v1: '#/components/schemas/OneOfVariant1'
341+
v6: '#/components/schemas/OneOfVariant6'
342+
required:
343+
- type
344+
additionalProperties: true
328345
AnyOfObject1:
329346
description: simple anyOf case
330347
anyOf:

internal/test/components/components_test.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"testing"
66

77
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/require"
89
)
910

1011
func assertJsonEqual(t *testing.T, j1 []byte, j2 []byte) {
@@ -203,6 +204,34 @@ func TestAnyOf(t *testing.T) {
203204
assert.Equal(t, OneOfVariant5{Discriminator: "all", Id: 456}, v5)
204205
}
205206

207+
func TestOneOfWithAdditional(t *testing.T) {
208+
x := OneOfObject13{
209+
AdditionalProperties: map[string]interface{}{"x": "y"},
210+
}
211+
err := x.MergeOneOfVariant1(OneOfVariant1{Name: "test-name"})
212+
require.NoError(t, err)
213+
b, err := json.Marshal(x)
214+
require.NoError(t, err)
215+
assert.JSONEq(t, `{"x":"y", "name":"test-name", "type":"v1"}`, string(b))
216+
var y OneOfObject13
217+
err = json.Unmarshal(b, &y)
218+
require.NoError(t, err)
219+
assert.Equal(t, x.Type, y.Type)
220+
xVariant, err := x.AsOneOfVariant1()
221+
require.NoError(t, err)
222+
yVariant, err := y.AsOneOfVariant1()
223+
require.NoError(t, err)
224+
assert.Equal(t, xVariant, yVariant)
225+
xAdditional, ok := x.Get("x")
226+
assert.True(t, ok)
227+
yAdditional, ok := y.Get("x")
228+
assert.True(t, ok)
229+
assert.Equal(t, xAdditional, yAdditional)
230+
b, err = json.Marshal(y)
231+
require.NoError(t, err)
232+
assert.JSONEq(t, `{"x":"y", "name":"test-name", "type":"v1"}`, string(b))
233+
}
234+
206235
func TestMarshalWhenNoUnionValueSet(t *testing.T) {
207236
const expected = `{}`
208237

pkg/codegen/codegen.go

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,12 @@ func GenerateTypeDefinitions(t *template.Template, swagger *openapi3.T, ops []Op
396396
return "", fmt.Errorf("error generating union boilerplate: %w", err)
397397
}
398398

399-
typeDefinitions := strings.Join([]string{enumsOut, typesOut, operationsOut, allOfBoilerplate, unionBoilerplate}, "")
399+
unionAndAdditionalBoilerplate, err := GenerateUnionAndAdditionalProopertiesBoilerplate(t, allTypes)
400+
if err != nil {
401+
return "", fmt.Errorf("error generating boilerplate for union types with additionalProperties: %w", err)
402+
}
403+
404+
typeDefinitions := strings.Join([]string{enumsOut, typesOut, operationsOut, allOfBoilerplate, unionBoilerplate, unionAndAdditionalBoilerplate}, "")
400405
return typeDefinitions, nil
401406
}
402407

@@ -778,6 +783,26 @@ func GenerateUnionBoilerplate(t *template.Template, typeDefs []TypeDefinition) (
778783
return GenerateTemplates([]string{"union.tmpl"}, t, context)
779784
}
780785

786+
func GenerateUnionAndAdditionalProopertiesBoilerplate(t *template.Template, typeDefs []TypeDefinition) (string, error) {
787+
var filteredTypes []TypeDefinition
788+
for _, t := range typeDefs {
789+
if len(t.Schema.UnionElements) != 0 && t.Schema.HasAdditionalProperties {
790+
filteredTypes = append(filteredTypes, t)
791+
}
792+
}
793+
794+
if len(filteredTypes) == 0 {
795+
return "", nil
796+
}
797+
context := struct {
798+
Types []TypeDefinition
799+
}{
800+
Types: filteredTypes,
801+
}
802+
803+
return GenerateTemplates([]string{"union-and-additional-properties.tmpl"}, t, context)
804+
}
805+
781806
// SanitizeCode runs sanitizers across the generated Go code to ensure the
782807
// generated code will be able to compile.
783808
func SanitizeCode(goCode string) string {

pkg/codegen/templates/additional-properties.tmpl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ func (a *{{.TypeName}}) Set(fieldName string, value {{$addType}}) {
1717
a.AdditionalProperties[fieldName] = value
1818
}
1919

20+
{{if eq 0 (len .Schema.UnionElements) -}}
2021
// Override default JSON handling for {{.TypeName}} to handle AdditionalProperties
2122
func (a *{{.TypeName}}) UnmarshalJSON(b []byte) error {
2223
object := make(map[string]json.RawMessage)
@@ -68,3 +69,4 @@ func (a {{.TypeName}}) MarshalJSON() ([]byte, error) {
6869
return json.Marshal(object)
6970
}
7071
{{end}}
72+
{{end}}
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
{{range .Types}}
2+
3+
{{$addType := .Schema.AdditionalPropertiesType.TypeDecl}}
4+
{{$typeName := .TypeName -}}
5+
{{$discriminator := .Schema.Discriminator}}
6+
{{$properties := .Schema.Properties -}}
7+
8+
// Override default JSON handling for {{.TypeName}} to handle AdditionalProperties and union
9+
func (a *{{.TypeName}}) UnmarshalJSON(b []byte) error {
10+
err := a.union.UnmarshalJSON(b)
11+
if err != nil {
12+
return err
13+
}
14+
object := make(map[string]json.RawMessage)
15+
err = json.Unmarshal(b, &object)
16+
if err != nil {
17+
return err
18+
}
19+
{{range .Schema.Properties}}
20+
if raw, found := object["{{.JsonFieldName}}"]; found {
21+
err = json.Unmarshal(raw, &a.{{.GoFieldName}})
22+
if err != nil {
23+
return fmt.Errorf("error reading '{{.JsonFieldName}}': %w", err)
24+
}
25+
delete(object, "{{.JsonFieldName}}")
26+
}
27+
{{end}}
28+
if len(object) != 0 {
29+
a.AdditionalProperties = make(map[string]{{$addType}})
30+
for fieldName, fieldBuf := range object {
31+
var fieldVal {{$addType}}
32+
err := json.Unmarshal(fieldBuf, &fieldVal)
33+
if err != nil {
34+
return fmt.Errorf("error unmarshaling field %s: %w", fieldName, err)
35+
}
36+
a.AdditionalProperties[fieldName] = fieldVal
37+
}
38+
}
39+
return nil
40+
}
41+
42+
// Override default JSON handling for {{.TypeName}} to handle AdditionalProperties and union
43+
func (a {{.TypeName}}) MarshalJSON() ([]byte, error) {
44+
var err error
45+
b, err := a.union.MarshalJSON()
46+
if err != nil {
47+
return nil, err
48+
}
49+
object := make(map[string]json.RawMessage)
50+
if a.union != nil {
51+
err = json.Unmarshal(b, &object)
52+
if err != nil {
53+
return nil, err
54+
}
55+
}
56+
{{range .Schema.Properties}}
57+
{{if not .Required}}if a.{{.GoFieldName}} != nil { {{end}}
58+
object["{{.JsonFieldName}}"], err = json.Marshal(a.{{.GoFieldName}})
59+
if err != nil {
60+
return nil, fmt.Errorf("error marshaling '{{.JsonFieldName}}': %w", err)
61+
}
62+
{{if not .Required}} }{{end}}
63+
{{end}}
64+
for fieldName, field := range a.AdditionalProperties {
65+
object[fieldName], err = json.Marshal(field)
66+
if err != nil {
67+
return nil, fmt.Errorf("error marshaling '%s': %w", fieldName, err)
68+
}
69+
}
70+
return json.Marshal(object)
71+
}
72+
{{end}}

0 commit comments

Comments
 (0)