From c5136a0704523f13fa7faf50ff227486bee67512 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C5=91rinc=20B=C3=B3dy?= Date: Fri, 30 Sep 2022 12:31:48 +0200 Subject: [PATCH 1/4] demonstrate error --- internal/test/components/components.gen.go | 7 +++ internal/test/components/components.yaml | 17 ++++++ pkg/codegen/codegen.go | 32 +++++++++-- .../union-and-additional-properties.tmpl | 53 +++++++++++++++++++ 4 files changed, 106 insertions(+), 3 deletions(-) create mode 100644 pkg/codegen/templates/union-and-additional-properties.tmpl diff --git a/internal/test/components/components.gen.go b/internal/test/components/components.gen.go index 82bfceaa40..4f4c86884b 100644 --- a/internal/test/components/components.gen.go +++ b/internal/test/components/components.gen.go @@ -217,6 +217,13 @@ type OneOfObject120 = string // OneOfObject121 defines model for . type OneOfObject121 = float32 +// OneOfObject13 oneOf with fixed discriminator and other fields allowed +type OneOfObject13 struct { + Type string `json:"type"` + AdditionalProperties map[string]interface{} `json:"-"` + union json.RawMessage +} + // OneOfObject2 oneOf with inline elements type OneOfObject2 struct { union json.RawMessage diff --git a/internal/test/components/components.yaml b/internal/test/components/components.yaml index f49173a586..a170fdae04 100644 --- a/internal/test/components/components.yaml +++ b/internal/test/components/components.yaml @@ -325,6 +325,23 @@ components: - oneOf: - $ref: '#/components/schemas/OneOfVariant3' - $ref: '#/components/schemas/OneOfVariant4' + OneOfObject13: + description: oneOf with fixed discriminator and other fields allowed + type: object + properties: + type: + type: string + oneOf: + - $ref: '#/components/schemas/OneOfVariant1' + - $ref: '#/components/schemas/OneOfVariant6' + discriminator: + propertyName: type + mapping: + v1: '#/components/schemas/OneOfVariant1' + v6: '#/components/schemas/OneOfVariant6' + required: + - type + additionalProperties: true AnyOfObject1: description: simple anyOf case anyOf: diff --git a/pkg/codegen/codegen.go b/pkg/codegen/codegen.go index 7d03752986..c7257695ee 100644 --- a/pkg/codegen/codegen.go +++ b/pkg/codegen/codegen.go @@ -386,7 +386,12 @@ func GenerateTypeDefinitions(t *template.Template, swagger *openapi3.T, ops []Op return "", fmt.Errorf("error generating union boilerplate: %w", err) } - typeDefinitions := strings.Join([]string{enumsOut, typesOut, operationsOut, allOfBoilerplate, unionBoilerplate}, "") + unionAndAdditionalBoilerplate, err := GenerateUnionBoilerplate(t, allTypes) + if err != nil { + return "", fmt.Errorf("error generating boilerplate for union types with additionalProperties: %w", err) + } + + typeDefinitions := strings.Join([]string{enumsOut, typesOut, operationsOut, allOfBoilerplate, unionBoilerplate, unionAndAdditionalBoilerplate}, "") return typeDefinitions, nil } @@ -725,7 +730,7 @@ func GenerateAdditionalPropertyBoilerplate(t *template.Template, typeDefs []Type m[t.TypeName] = true - if t.Schema.HasAdditionalProperties { + if t.Schema.HasAdditionalProperties && len(t.Schema.UnionElements) == 0 { filteredTypes = append(filteredTypes, t) } } @@ -742,7 +747,7 @@ func GenerateAdditionalPropertyBoilerplate(t *template.Template, typeDefs []Type func GenerateUnionBoilerplate(t *template.Template, typeDefs []TypeDefinition) (string, error) { var filteredTypes []TypeDefinition for _, t := range typeDefs { - if len(t.Schema.UnionElements) != 0 { + if len(t.Schema.UnionElements) != 0 && !t.Schema.HasAdditionalProperties { filteredTypes = append(filteredTypes, t) } } @@ -760,6 +765,27 @@ func GenerateUnionBoilerplate(t *template.Template, typeDefs []TypeDefinition) ( return GenerateTemplates([]string{"union.tmpl"}, t, context) } +func GenerateUnionAndAdditionalProopertiesBoilerplate(t *template.Template, typeDefs []TypeDefinition) (string, error) { + var filteredTypes []TypeDefinition + for _, t := range typeDefs { + if len(t.Schema.UnionElements) != 0 && t.Schema.HasAdditionalProperties { + filteredTypes = append(filteredTypes, t) + } + } + + if len(filteredTypes) == 0 { + return "", nil + } + + context := struct { + Types []TypeDefinition + }{ + Types: filteredTypes, + } + + return GenerateTemplates([]string{"union-and-additonal-properties.tmpl"}, t, context) +} + // SanitizeCode runs sanitizers across the generated Go code to ensure the // generated code will be able to compile. func SanitizeCode(goCode string) string { diff --git a/pkg/codegen/templates/union-and-additional-properties.tmpl b/pkg/codegen/templates/union-and-additional-properties.tmpl new file mode 100644 index 0000000000..212f67f024 --- /dev/null +++ b/pkg/codegen/templates/union-and-additional-properties.tmpl @@ -0,0 +1,53 @@ +{{range .Types}}{{$addType := .Schema.AdditionalPropertiesType.TypeDecl}} + +// Override default JSON handling for {{.TypeName}} to handle AdditionalProperties and union +func (a *{{.TypeName}}) UnmarshalJSON(b []byte) error { + object := make(map[string]json.RawMessage) + err := json.Unmarshal(b, &object) + if err != nil { + return err + } +{{range .Schema.Properties}} + if raw, found := object["{{.JsonFieldName}}"]; found { + err = json.Unmarshal(raw, &a.{{.GoFieldName}}) + if err != nil { + return fmt.Errorf("error reading '{{.JsonFieldName}}': %w", err) + } + delete(object, "{{.JsonFieldName}}") + } +{{end}} + if len(object) != 0 { + a.AdditionalProperties = make(map[string]{{$addType}}) + for fieldName, fieldBuf := range object { + var fieldVal {{$addType}} + err := json.Unmarshal(fieldBuf, &fieldVal) + if err != nil { + return fmt.Errorf("error unmarshaling field %s: %w", fieldName, err) + } + a.AdditionalProperties[fieldName] = fieldVal + } + } + return nil +} + +// Override default JSON handling for {{.TypeName}} to handle AdditionalProperties and union +func (a {{.TypeName}}) MarshalJSON() ([]byte, error) { + var err error + object := make(map[string]json.RawMessage) +{{range .Schema.Properties}} +{{if not .Required}}if a.{{.GoFieldName}} != nil { {{end}} + object["{{.JsonFieldName}}"], err = json.Marshal(a.{{.GoFieldName}}) + if err != nil { + return nil, fmt.Errorf("error marshaling '{{.JsonFieldName}}': %w", err) + } +{{if not .Required}} }{{end}} +{{end}} + for fieldName, field := range a.AdditionalProperties { + object[fieldName], err = json.Marshal(field) + if err != nil { + return nil, fmt.Errorf("error marshaling '%s': %w", fieldName, err) + } + } + return json.Marshal(object) +} +{{end}} From 1892b3c1087b2458fcb77b6ca604115662c0dd2e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C5=91rinc=20B=C3=B3dy?= Date: Fri, 30 Sep 2022 12:48:26 +0200 Subject: [PATCH 2/4] updated templates --- pkg/codegen/codegen.go | 10 ++++---- .../templates/additional-properties.tmpl | 2 ++ .../union-and-additional-properties.tmpl | 23 +++++++++++++++++-- pkg/codegen/templates/union.tmpl | 3 +++ 4 files changed, 31 insertions(+), 7 deletions(-) diff --git a/pkg/codegen/codegen.go b/pkg/codegen/codegen.go index c7257695ee..509854d031 100644 --- a/pkg/codegen/codegen.go +++ b/pkg/codegen/codegen.go @@ -386,7 +386,7 @@ func GenerateTypeDefinitions(t *template.Template, swagger *openapi3.T, ops []Op return "", fmt.Errorf("error generating union boilerplate: %w", err) } - unionAndAdditionalBoilerplate, err := GenerateUnionBoilerplate(t, allTypes) + unionAndAdditionalBoilerplate, err := GenerateUnionAndAdditionalProopertiesBoilerplate(t, allTypes) if err != nil { return "", fmt.Errorf("error generating boilerplate for union types with additionalProperties: %w", err) } @@ -730,7 +730,7 @@ func GenerateAdditionalPropertyBoilerplate(t *template.Template, typeDefs []Type m[t.TypeName] = true - if t.Schema.HasAdditionalProperties && len(t.Schema.UnionElements) == 0 { + if t.Schema.HasAdditionalProperties { filteredTypes = append(filteredTypes, t) } } @@ -747,7 +747,7 @@ func GenerateAdditionalPropertyBoilerplate(t *template.Template, typeDefs []Type func GenerateUnionBoilerplate(t *template.Template, typeDefs []TypeDefinition) (string, error) { var filteredTypes []TypeDefinition for _, t := range typeDefs { - if len(t.Schema.UnionElements) != 0 && !t.Schema.HasAdditionalProperties { + if len(t.Schema.UnionElements) != 0 { filteredTypes = append(filteredTypes, t) } } @@ -776,14 +776,14 @@ func GenerateUnionAndAdditionalProopertiesBoilerplate(t *template.Template, type if len(filteredTypes) == 0 { return "", nil } - + return "", nil context := struct { Types []TypeDefinition }{ Types: filteredTypes, } - return GenerateTemplates([]string{"union-and-additonal-properties.tmpl"}, t, context) + return GenerateTemplates([]string{"union-and-additional-properties.tmpl"}, t, context) } // SanitizeCode runs sanitizers across the generated Go code to ensure the diff --git a/pkg/codegen/templates/additional-properties.tmpl b/pkg/codegen/templates/additional-properties.tmpl index 46b66c86e2..5ddcd5ddba 100644 --- a/pkg/codegen/templates/additional-properties.tmpl +++ b/pkg/codegen/templates/additional-properties.tmpl @@ -17,6 +17,7 @@ func (a *{{.TypeName}}) Set(fieldName string, value {{$addType}}) { a.AdditionalProperties[fieldName] = value } +{{if eq 0 (len .Schema.UnionElements) -}}} // Override default JSON handling for {{.TypeName}} to handle AdditionalProperties func (a *{{.TypeName}}) UnmarshalJSON(b []byte) error { object := make(map[string]json.RawMessage) @@ -68,3 +69,4 @@ func (a {{.TypeName}}) MarshalJSON() ([]byte, error) { return json.Marshal(object) } {{end}} +{{end}} \ No newline at end of file diff --git a/pkg/codegen/templates/union-and-additional-properties.tmpl b/pkg/codegen/templates/union-and-additional-properties.tmpl index 212f67f024..a1dd480ab1 100644 --- a/pkg/codegen/templates/union-and-additional-properties.tmpl +++ b/pkg/codegen/templates/union-and-additional-properties.tmpl @@ -1,9 +1,18 @@ -{{range .Types}}{{$addType := .Schema.AdditionalPropertiesType.TypeDecl}} +{{range .Types}} + +{{$addType := .Schema.AdditionalPropertiesType.TypeDecl}} +{{$typeName := .TypeName -}} +{{$discriminator := .Schema.Discriminator}} +{{$properties := .Schema.Properties -}} // Override default JSON handling for {{.TypeName}} to handle AdditionalProperties and union func (a *{{.TypeName}}) UnmarshalJSON(b []byte) error { + err := t.union.UnmarshalJSON(b) + if err != nil { + return err + } object := make(map[string]json.RawMessage) - err := json.Unmarshal(b, &object) + err = json.Unmarshal(b, &object) if err != nil { return err } @@ -33,7 +42,17 @@ func (a *{{.TypeName}}) UnmarshalJSON(b []byte) error { // Override default JSON handling for {{.TypeName}} to handle AdditionalProperties and union func (a {{.TypeName}}) MarshalJSON() ([]byte, error) { var err error + b, err := t.union.MarshalJSON() + if err != nil { + return nil, err + } object := make(map[string]json.RawMessage) + if t.union != nil { + err = json.Unmarshal(b, &object) + if err != nil { + return nil, err + } + } {{range .Schema.Properties}} {{if not .Required}}if a.{{.GoFieldName}} != nil { {{end}} object["{{.JsonFieldName}}"], err = json.Marshal(a.{{.GoFieldName}}) diff --git a/pkg/codegen/templates/union.tmpl b/pkg/codegen/templates/union.tmpl index d233e13bee..8f64105519 100644 --- a/pkg/codegen/templates/union.tmpl +++ b/pkg/codegen/templates/union.tmpl @@ -86,6 +86,8 @@ {{end}} {{end}} + {{if not .Schema.HasAdditionalProperties}} + func (t {{.TypeName}}) MarshalJSON() ([]byte, error) { b, err := t.union.MarshalJSON() {{if ne 0 (len .Schema.Properties) -}} @@ -132,4 +134,5 @@ {{end -}} return err } + {{end}} {{end}} From 0cd4f2b5df8fee375d5fab270ad31ff574b9e1ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C5=91rinc=20B=C3=B3dy?= Date: Fri, 30 Sep 2022 13:01:23 +0200 Subject: [PATCH 3/4] fix template --- internal/test/components/components.gen.go | 163 ++++++++++++++++++ pkg/codegen/codegen.go | 1 - .../templates/additional-properties.tmpl | 2 +- .../union-and-additional-properties.tmpl | 6 +- 4 files changed, 167 insertions(+), 5 deletions(-) diff --git a/internal/test/components/components.gen.go b/internal/test/components/components.gen.go index 4f4c86884b..3a646f4128 100644 --- a/internal/test/components/components.gen.go +++ b/internal/test/components/components.gen.go @@ -779,6 +779,23 @@ func (a AdditionalPropertiesObject4_Inner) MarshalJSON() ([]byte, error) { return json.Marshal(object) } +// Getter for additional properties for OneOfObject13. Returns the specified +// element and whether it was found +func (a OneOfObject13) Get(fieldName string) (value interface{}, found bool) { + if a.AdditionalProperties != nil { + value, found = a.AdditionalProperties[fieldName] + } + return +} + +// Setter for additional properties for OneOfObject13 +func (a *OneOfObject13) Set(fieldName string, value interface{}) { + if a.AdditionalProperties == nil { + a.AdditionalProperties = make(map[string]interface{}) + } + a.AdditionalProperties[fieldName] = value +} + // AsOneOfVariant4 returns the union data inside the AnyOfObject1 as a OneOfVariant4 func (t AnyOfObject1) AsOneOfVariant4() (OneOfVariant4, error) { var body OneOfVariant4 @@ -1249,6 +1266,89 @@ func (t *OneOfObject12) UnmarshalJSON(b []byte) error { return err } +// AsOneOfVariant1 returns the union data inside the OneOfObject13 as a OneOfVariant1 +func (t OneOfObject13) AsOneOfVariant1() (OneOfVariant1, error) { + var body OneOfVariant1 + err := json.Unmarshal(t.union, &body) + return body, err +} + +// FromOneOfVariant1 overwrites any union data inside the OneOfObject13 as the provided OneOfVariant1 +func (t *OneOfObject13) FromOneOfVariant1(v OneOfVariant1) error { + t.Type = "v1" + + b, err := json.Marshal(v) + t.union = b + return err +} + +// MergeOneOfVariant1 performs a merge with any union data inside the OneOfObject13, using the provided OneOfVariant1 +func (t *OneOfObject13) MergeOneOfVariant1(v OneOfVariant1) error { + t.Type = "v1" + + b, err := json.Marshal(v) + if err != nil { + return err + } + + merged, err := runtime.JsonMerge(b, t.union) + t.union = merged + return err +} + +// AsOneOfVariant6 returns the union data inside the OneOfObject13 as a OneOfVariant6 +func (t OneOfObject13) AsOneOfVariant6() (OneOfVariant6, error) { + var body OneOfVariant6 + err := json.Unmarshal(t.union, &body) + return body, err +} + +// FromOneOfVariant6 overwrites any union data inside the OneOfObject13 as the provided OneOfVariant6 +func (t *OneOfObject13) FromOneOfVariant6(v OneOfVariant6) error { + t.Type = "v6" + + b, err := json.Marshal(v) + t.union = b + return err +} + +// MergeOneOfVariant6 performs a merge with any union data inside the OneOfObject13, using the provided OneOfVariant6 +func (t *OneOfObject13) MergeOneOfVariant6(v OneOfVariant6) error { + t.Type = "v6" + + b, err := json.Marshal(v) + if err != nil { + return err + } + + merged, err := runtime.JsonMerge(b, t.union) + t.union = merged + return err +} + +func (t OneOfObject13) Discriminator() (string, error) { + var discriminator struct { + Discriminator string `json:"type"` + } + err := json.Unmarshal(t.union, &discriminator) + return discriminator.Discriminator, err +} + +func (t OneOfObject13) ValueByDiscriminator() (interface{}, error) { + discriminator, err := t.Discriminator() + if err != nil { + return nil, err + } + switch discriminator { + case "v1": + return t.AsOneOfVariant1() + case "v6": + return t.AsOneOfVariant6() + default: + return nil, errors.New("unknown discriminator value: " + discriminator) + } +} + // AsOneOfObject20 returns the union data inside the OneOfObject2 as a OneOfObject20 func (t OneOfObject2) AsOneOfObject20() (OneOfObject20, error) { var body OneOfObject20 @@ -1984,3 +2084,66 @@ func (t *OneOfObject9) UnmarshalJSON(b []byte) error { return err } + +// Override default JSON handling for OneOfObject13 to handle AdditionalProperties and union +func (a *OneOfObject13) UnmarshalJSON(b []byte) error { + err := a.union.UnmarshalJSON(b) + if err != nil { + return err + } + object := make(map[string]json.RawMessage) + err = json.Unmarshal(b, &object) + if err != nil { + return err + } + + if raw, found := object["type"]; found { + err = json.Unmarshal(raw, &a.Type) + if err != nil { + return fmt.Errorf("error reading 'type': %w", err) + } + delete(object, "type") + } + + if len(object) != 0 { + a.AdditionalProperties = make(map[string]interface{}) + for fieldName, fieldBuf := range object { + var fieldVal interface{} + err := json.Unmarshal(fieldBuf, &fieldVal) + if err != nil { + return fmt.Errorf("error unmarshaling field %s: %w", fieldName, err) + } + a.AdditionalProperties[fieldName] = fieldVal + } + } + return nil +} + +// Override default JSON handling for OneOfObject13 to handle AdditionalProperties and union +func (a OneOfObject13) MarshalJSON() ([]byte, error) { + var err error + b, err := a.union.MarshalJSON() + if err != nil { + return nil, err + } + object := make(map[string]json.RawMessage) + if a.union != nil { + err = json.Unmarshal(b, &object) + if err != nil { + return nil, err + } + } + + object["type"], err = json.Marshal(a.Type) + if err != nil { + return nil, fmt.Errorf("error marshaling 'type': %w", err) + } + + for fieldName, field := range a.AdditionalProperties { + object[fieldName], err = json.Marshal(field) + if err != nil { + return nil, fmt.Errorf("error marshaling '%s': %w", fieldName, err) + } + } + return json.Marshal(object) +} diff --git a/pkg/codegen/codegen.go b/pkg/codegen/codegen.go index 509854d031..e6e8150526 100644 --- a/pkg/codegen/codegen.go +++ b/pkg/codegen/codegen.go @@ -776,7 +776,6 @@ func GenerateUnionAndAdditionalProopertiesBoilerplate(t *template.Template, type if len(filteredTypes) == 0 { return "", nil } - return "", nil context := struct { Types []TypeDefinition }{ diff --git a/pkg/codegen/templates/additional-properties.tmpl b/pkg/codegen/templates/additional-properties.tmpl index 5ddcd5ddba..7b7c0ace4b 100644 --- a/pkg/codegen/templates/additional-properties.tmpl +++ b/pkg/codegen/templates/additional-properties.tmpl @@ -17,7 +17,7 @@ func (a *{{.TypeName}}) Set(fieldName string, value {{$addType}}) { a.AdditionalProperties[fieldName] = value } -{{if eq 0 (len .Schema.UnionElements) -}}} +{{if eq 0 (len .Schema.UnionElements) -}} // Override default JSON handling for {{.TypeName}} to handle AdditionalProperties func (a *{{.TypeName}}) UnmarshalJSON(b []byte) error { object := make(map[string]json.RawMessage) diff --git a/pkg/codegen/templates/union-and-additional-properties.tmpl b/pkg/codegen/templates/union-and-additional-properties.tmpl index a1dd480ab1..79b4c67b2e 100644 --- a/pkg/codegen/templates/union-and-additional-properties.tmpl +++ b/pkg/codegen/templates/union-and-additional-properties.tmpl @@ -7,7 +7,7 @@ // Override default JSON handling for {{.TypeName}} to handle AdditionalProperties and union func (a *{{.TypeName}}) UnmarshalJSON(b []byte) error { - err := t.union.UnmarshalJSON(b) + err := a.union.UnmarshalJSON(b) if err != nil { return err } @@ -42,12 +42,12 @@ func (a *{{.TypeName}}) UnmarshalJSON(b []byte) error { // Override default JSON handling for {{.TypeName}} to handle AdditionalProperties and union func (a {{.TypeName}}) MarshalJSON() ([]byte, error) { var err error - b, err := t.union.MarshalJSON() + b, err := a.union.MarshalJSON() if err != nil { return nil, err } object := make(map[string]json.RawMessage) - if t.union != nil { + if a.union != nil { err = json.Unmarshal(b, &object) if err != nil { return nil, err From 2a536c436d976d77810d6c7135d5acddf6698cd5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C5=91rinc=20B=C3=B3dy?= Date: Fri, 30 Sep 2022 13:25:14 +0200 Subject: [PATCH 4/4] test --- internal/test/components/components_test.go | 29 +++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/internal/test/components/components_test.go b/internal/test/components/components_test.go index bcb34d7524..ed51d6357d 100644 --- a/internal/test/components/components_test.go +++ b/internal/test/components/components_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func assertJsonEqual(t *testing.T, j1 []byte, j2 []byte) { @@ -203,6 +204,34 @@ func TestAnyOf(t *testing.T) { assert.Equal(t, OneOfVariant5{Discriminator: "all", Id: 456}, v5) } +func TestOneOfWithAdditional(t *testing.T) { + x := OneOfObject13{ + AdditionalProperties: map[string]interface{}{"x": "y"}, + } + err := x.MergeOneOfVariant1(OneOfVariant1{Name: "test-name"}) + require.NoError(t, err) + b, err := json.Marshal(x) + require.NoError(t, err) + assert.JSONEq(t, `{"x":"y", "name":"test-name", "type":"v1"}`, string(b)) + var y OneOfObject13 + err = json.Unmarshal(b, &y) + require.NoError(t, err) + assert.Equal(t, x.Type, y.Type) + xVariant, err := x.AsOneOfVariant1() + require.NoError(t, err) + yVariant, err := y.AsOneOfVariant1() + require.NoError(t, err) + assert.Equal(t, xVariant, yVariant) + xAdditional, ok := x.Get("x") + assert.True(t, ok) + yAdditional, ok := y.Get("x") + assert.True(t, ok) + assert.Equal(t, xAdditional, yAdditional) + b, err = json.Marshal(y) + require.NoError(t, err) + assert.JSONEq(t, `{"x":"y", "name":"test-name", "type":"v1"}`, string(b)) +} + func TestMarshalWhenNoUnionValueSet(t *testing.T) { const expected = `{"one":null,"three":null,"two":null}`