From 027c31677e646bbefef61518f1d112aa5d664077 Mon Sep 17 00:00:00 2001 From: Anton Telyshev Date: Sun, 22 Jan 2023 16:41:37 +0200 Subject: [PATCH] OneOf: Implicit mapping --- internal/test/components/components.gen.go | 215 +++++++++++++++++++- internal/test/components/components.yaml | 30 ++- internal/test/components/components_test.go | 122 ++++++++++- pkg/codegen/schema.go | 16 ++ pkg/codegen/utils.go | 17 ++ pkg/codegen/utils_test.go | 14 ++ 6 files changed, 403 insertions(+), 11 deletions(-) diff --git a/internal/test/components/components.gen.go b/internal/test/components/components.gen.go index 4ed8b65a1e..53e71d3cb1 100644 --- a/internal/test/components/components.gen.go +++ b/internal/test/components/components.gen.go @@ -256,7 +256,7 @@ type OneOfObject4 struct { union json.RawMessage } -// OneOfObject5 oneOf with disciminator but no mapping +// OneOfObject5 oneOf with discriminator but no mapping type OneOfObject5 struct { union json.RawMessage } @@ -266,6 +266,16 @@ type OneOfObject6 struct { union json.RawMessage } +// OneOfObject61 oneOf with discriminator and partial mapping +type OneOfObject61 struct { + union json.RawMessage +} + +// OneOfObject62 oneOf with snake_case discriminator and partial snake_case mapping +type OneOfObject62 struct { + union json.RawMessage +} + // OneOfObject7 array of oneOf type OneOfObject7 = []OneOfObject7_Item @@ -337,6 +347,12 @@ type SchemaObject struct { WriteOnlyRequiredProp *int `json:"writeOnlyRequiredProp,omitempty"` } +// OneOfVariant51 defines model for one_of_variant51. +type OneOfVariant51 struct { + Discriminator string `json:"discriminator"` + Id int `json:"id"` +} + // EnumParam1 defines model for EnumParam1. type EnumParam1 string @@ -1662,6 +1678,7 @@ func (t OneOfObject5) AsOneOfVariant4() (OneOfVariant4, error) { // FromOneOfVariant4 overwrites any union data inside the OneOfObject5 as the provided OneOfVariant4 func (t *OneOfObject5) FromOneOfVariant4(v OneOfVariant4) error { + v.Discriminator = "OneOfVariant4" b, err := json.Marshal(v) t.union = b return err @@ -1669,6 +1686,7 @@ func (t *OneOfObject5) FromOneOfVariant4(v OneOfVariant4) error { // MergeOneOfVariant4 performs a merge with any union data inside the OneOfObject5, using the provided OneOfVariant4 func (t *OneOfObject5) MergeOneOfVariant4(v OneOfVariant4) error { + v.Discriminator = "OneOfVariant4" b, err := json.Marshal(v) if err != nil { return err @@ -1688,6 +1706,7 @@ func (t OneOfObject5) AsOneOfVariant5() (OneOfVariant5, error) { // FromOneOfVariant5 overwrites any union data inside the OneOfObject5 as the provided OneOfVariant5 func (t *OneOfObject5) FromOneOfVariant5(v OneOfVariant5) error { + v.Discriminator = "OneOfVariant5" b, err := json.Marshal(v) t.union = b return err @@ -1695,6 +1714,7 @@ func (t *OneOfObject5) FromOneOfVariant5(v OneOfVariant5) error { // MergeOneOfVariant5 performs a merge with any union data inside the OneOfObject5, using the provided OneOfVariant5 func (t *OneOfObject5) MergeOneOfVariant5(v OneOfVariant5) error { + v.Discriminator = "OneOfVariant5" b, err := json.Marshal(v) if err != nil { return err @@ -1713,6 +1733,21 @@ func (t OneOfObject5) Discriminator() (string, error) { return discriminator.Discriminator, err } +func (t OneOfObject5) ValueByDiscriminator() (interface{}, error) { + discriminator, err := t.Discriminator() + if err != nil { + return nil, err + } + switch discriminator { + case "OneOfVariant4": + return t.AsOneOfVariant4() + case "OneOfVariant5": + return t.AsOneOfVariant5() + default: + return nil, errors.New("unknown discriminator value: " + discriminator) + } +} + func (t OneOfObject5) MarshalJSON() ([]byte, error) { b, err := t.union.MarshalJSON() return b, err @@ -1812,6 +1847,184 @@ func (t *OneOfObject6) UnmarshalJSON(b []byte) error { return err } +// AsOneOfVariant4 returns the union data inside the OneOfObject61 as a OneOfVariant4 +func (t OneOfObject61) AsOneOfVariant4() (OneOfVariant4, error) { + var body OneOfVariant4 + err := json.Unmarshal(t.union, &body) + return body, err +} + +// FromOneOfVariant4 overwrites any union data inside the OneOfObject61 as the provided OneOfVariant4 +func (t *OneOfObject61) FromOneOfVariant4(v OneOfVariant4) error { + v.Discriminator = "v4" + b, err := json.Marshal(v) + t.union = b + return err +} + +// MergeOneOfVariant4 performs a merge with any union data inside the OneOfObject61, using the provided OneOfVariant4 +func (t *OneOfObject61) MergeOneOfVariant4(v OneOfVariant4) error { + v.Discriminator = "v4" + b, err := json.Marshal(v) + if err != nil { + return err + } + + merged, err := runtime.JsonMerge(b, t.union) + t.union = merged + return err +} + +// AsOneOfVariant5 returns the union data inside the OneOfObject61 as a OneOfVariant5 +func (t OneOfObject61) AsOneOfVariant5() (OneOfVariant5, error) { + var body OneOfVariant5 + err := json.Unmarshal(t.union, &body) + return body, err +} + +// FromOneOfVariant5 overwrites any union data inside the OneOfObject61 as the provided OneOfVariant5 +func (t *OneOfObject61) FromOneOfVariant5(v OneOfVariant5) error { + v.Discriminator = "OneOfVariant5" + b, err := json.Marshal(v) + t.union = b + return err +} + +// MergeOneOfVariant5 performs a merge with any union data inside the OneOfObject61, using the provided OneOfVariant5 +func (t *OneOfObject61) MergeOneOfVariant5(v OneOfVariant5) error { + v.Discriminator = "OneOfVariant5" + b, err := json.Marshal(v) + if err != nil { + return err + } + + merged, err := runtime.JsonMerge(b, t.union) + t.union = merged + return err +} + +func (t OneOfObject61) Discriminator() (string, error) { + var discriminator struct { + Discriminator string `json:"discriminator"` + } + err := json.Unmarshal(t.union, &discriminator) + return discriminator.Discriminator, err +} + +func (t OneOfObject61) ValueByDiscriminator() (interface{}, error) { + discriminator, err := t.Discriminator() + if err != nil { + return nil, err + } + switch discriminator { + case "OneOfVariant5": + return t.AsOneOfVariant5() + case "v4": + return t.AsOneOfVariant4() + default: + return nil, errors.New("unknown discriminator value: " + discriminator) + } +} + +func (t OneOfObject61) MarshalJSON() ([]byte, error) { + b, err := t.union.MarshalJSON() + return b, err +} + +func (t *OneOfObject61) UnmarshalJSON(b []byte) error { + err := t.union.UnmarshalJSON(b) + return err +} + +// AsOneOfVariant4 returns the union data inside the OneOfObject62 as a OneOfVariant4 +func (t OneOfObject62) AsOneOfVariant4() (OneOfVariant4, error) { + var body OneOfVariant4 + err := json.Unmarshal(t.union, &body) + return body, err +} + +// FromOneOfVariant4 overwrites any union data inside the OneOfObject62 as the provided OneOfVariant4 +func (t *OneOfObject62) FromOneOfVariant4(v OneOfVariant4) error { + v.Discriminator = "variant_four" + b, err := json.Marshal(v) + t.union = b + return err +} + +// MergeOneOfVariant4 performs a merge with any union data inside the OneOfObject62, using the provided OneOfVariant4 +func (t *OneOfObject62) MergeOneOfVariant4(v OneOfVariant4) error { + v.Discriminator = "variant_four" + b, err := json.Marshal(v) + if err != nil { + return err + } + + merged, err := runtime.JsonMerge(b, t.union) + t.union = merged + return err +} + +// AsOneOfVariant51 returns the union data inside the OneOfObject62 as a OneOfVariant51 +func (t OneOfObject62) AsOneOfVariant51() (OneOfVariant51, error) { + var body OneOfVariant51 + err := json.Unmarshal(t.union, &body) + return body, err +} + +// FromOneOfVariant51 overwrites any union data inside the OneOfObject62 as the provided OneOfVariant51 +func (t *OneOfObject62) FromOneOfVariant51(v OneOfVariant51) error { + v.Discriminator = "one_of_variant51" + b, err := json.Marshal(v) + t.union = b + return err +} + +// MergeOneOfVariant51 performs a merge with any union data inside the OneOfObject62, using the provided OneOfVariant51 +func (t *OneOfObject62) MergeOneOfVariant51(v OneOfVariant51) error { + v.Discriminator = "one_of_variant51" + b, err := json.Marshal(v) + if err != nil { + return err + } + + merged, err := runtime.JsonMerge(b, t.union) + t.union = merged + return err +} + +func (t OneOfObject62) Discriminator() (string, error) { + var discriminator struct { + Discriminator string `json:"discriminator"` + } + err := json.Unmarshal(t.union, &discriminator) + return discriminator.Discriminator, err +} + +func (t OneOfObject62) ValueByDiscriminator() (interface{}, error) { + discriminator, err := t.Discriminator() + if err != nil { + return nil, err + } + switch discriminator { + case "one_of_variant51": + return t.AsOneOfVariant51() + case "variant_four": + return t.AsOneOfVariant4() + default: + return nil, errors.New("unknown discriminator value: " + discriminator) + } +} + +func (t OneOfObject62) MarshalJSON() ([]byte, error) { + b, err := t.union.MarshalJSON() + return b, err +} + +func (t *OneOfObject62) UnmarshalJSON(b []byte) error { + err := t.union.UnmarshalJSON(b) + return err +} + // AsOneOfVariant1 returns the union data inside the OneOfObject7_Item as a OneOfVariant1 func (t OneOfObject7_Item) AsOneOfVariant1() (OneOfVariant1, error) { var body OneOfVariant1 diff --git a/internal/test/components/components.yaml b/internal/test/components/components.yaml index 98224ed5f6..8cb21f859d 100644 --- a/internal/test/components/components.yaml +++ b/internal/test/components/components.yaml @@ -244,7 +244,7 @@ components: - $ref: '#/components/schemas/OneOfVariant2' - $ref: '#/components/schemas/OneOfVariant3' OneOfObject5: - description: oneOf with disciminator but no mapping + description: oneOf with discriminator but no mapping oneOf: - $ref: '#/components/schemas/OneOfVariant4' - $ref: '#/components/schemas/OneOfVariant5' @@ -260,6 +260,24 @@ components: mapping: v4: '#/components/schemas/OneOfVariant4' v5: '#/components/schemas/OneOfVariant5' + OneOfObject61: + description: oneOf with discriminator and partial mapping + oneOf: + - $ref: '#/components/schemas/OneOfVariant4' + - $ref: '#/components/schemas/OneOfVariant5' + discriminator: + propertyName: discriminator + mapping: + v4: '#/components/schemas/OneOfVariant4' + OneOfObject62: + description: oneOf with snake_case discriminator and partial snake_case mapping + oneOf: + - $ref: '#/components/schemas/OneOfVariant4' + - $ref: '#/components/schemas/one_of_variant51' + discriminator: + propertyName: discriminator + mapping: + variant_four: '#/components/schemas/OneOfVariant4' OneOfObject7: description: array of oneOf type: array @@ -380,6 +398,16 @@ components: required: - discriminator - id + one_of_variant51: + type: object + properties: + discriminator: + type: string + id: + type: integer + required: + - discriminator + - id OneOfVariant6: type: object properties: diff --git a/internal/test/components/components_test.go b/internal/test/components/components_test.go index e25ede26bc..7810ea20d6 100644 --- a/internal/test/components/components_test.go +++ b/internal/test/components/components_test.go @@ -9,15 +9,8 @@ import ( ) func assertJsonEqual(t *testing.T, j1 []byte, j2 []byte) { - var v1, v2 interface{} - - err := json.Unmarshal(j1, &v1) - assert.NoError(t, err) - - err = json.Unmarshal(j2, &v2) - assert.NoError(t, err) - - assert.EqualValues(t, v1, v2) + t.Helper() + assert.JSONEq(t, string(j1), string(j2)) } func TestRawJSON(t *testing.T) { @@ -152,6 +145,117 @@ func TestOneOfWithDiscriminator(t *testing.T) { assertJsonEqual(t, []byte(variant5), marshaled) } +func TestOneOfWithDiscriminator_PartialMapping(t *testing.T) { + const variant4 = `{"discriminator": "v4", "name": "123"}` + const variant5 = `{"discriminator": "OneOfVariant5", "id": 321}` + var dst OneOfObject61 + + err := json.Unmarshal([]byte(variant4), &dst) + assert.NoError(t, err) + discriminator, err := dst.Discriminator() + require.NoError(t, err) + assert.Equal(t, "v4", discriminator) + v4, err := dst.ValueByDiscriminator() + require.NoError(t, err) + assert.Equal(t, OneOfVariant4{Discriminator: "v4", Name: "123"}, v4) + + err = json.Unmarshal([]byte(variant5), &dst) + require.NoError(t, err) + discriminator, err = dst.Discriminator() + require.NoError(t, err) + assert.Equal(t, "OneOfVariant5", discriminator) + v5, err := dst.ValueByDiscriminator() + require.NoError(t, err) + assert.Equal(t, OneOfVariant5{Discriminator: "OneOfVariant5", Id: 321}, v5) + + // discriminator value will be filled by the generated code + err = dst.FromOneOfVariant4(OneOfVariant4{Name: "123"}) + require.NoError(t, err) + marshaled, err := json.Marshal(dst) + require.NoError(t, err) + assertJsonEqual(t, []byte(variant4), marshaled) + + err = dst.FromOneOfVariant5(OneOfVariant5{Id: 321}) + require.NoError(t, err) + marshaled, err = json.Marshal(dst) + require.NoError(t, err) + assertJsonEqual(t, []byte(variant5), marshaled) +} + +func TestOneOfWithDiscriminator_SchemaNameUsed(t *testing.T) { + const variant4 = `{"discriminator": "variant_four", "name": "789"}` + const variant51 = `{"discriminator": "one_of_variant51", "id": 987}` + var dst OneOfObject62 + + err := json.Unmarshal([]byte(variant4), &dst) + assert.NoError(t, err) + discriminator, err := dst.Discriminator() + require.NoError(t, err) + assert.Equal(t, "variant_four", discriminator) + v4, err := dst.ValueByDiscriminator() + require.NoError(t, err) + assert.Equal(t, OneOfVariant4{Discriminator: "variant_four", Name: "789"}, v4) + + err = json.Unmarshal([]byte(variant51), &dst) + require.NoError(t, err) + discriminator, err = dst.Discriminator() + require.NoError(t, err) + assert.Equal(t, "one_of_variant51", discriminator) + v5, err := dst.ValueByDiscriminator() + require.NoError(t, err) + assert.Equal(t, OneOfVariant51{Discriminator: "one_of_variant51", Id: 987}, v5) + + // discriminator value will be filled by the generated code + err = dst.FromOneOfVariant4(OneOfVariant4{Name: "789"}) + require.NoError(t, err) + marshaled, err := json.Marshal(dst) + require.NoError(t, err) + assertJsonEqual(t, []byte(variant4), marshaled) + + err = dst.FromOneOfVariant51(OneOfVariant51{Id: 987}) + require.NoError(t, err) + marshaled, err = json.Marshal(dst) + require.NoError(t, err) + assertJsonEqual(t, []byte(variant51), marshaled) +} + +func TestOneOfWithDiscriminator_FullImplicitMapping(t *testing.T) { + const variant4 = `{"discriminator": "OneOfVariant4", "name": "456"}` + const variant5 = `{"discriminator": "OneOfVariant5", "id": 654}` + var dst OneOfObject5 + + err := json.Unmarshal([]byte(variant4), &dst) + assert.NoError(t, err) + discriminator, err := dst.Discriminator() + require.NoError(t, err) + assert.Equal(t, "OneOfVariant4", discriminator) + v4, err := dst.ValueByDiscriminator() + require.NoError(t, err) + assert.Equal(t, OneOfVariant4{Discriminator: "OneOfVariant4", Name: "456"}, v4) + + err = json.Unmarshal([]byte(variant5), &dst) + require.NoError(t, err) + discriminator, err = dst.Discriminator() + require.NoError(t, err) + assert.Equal(t, "OneOfVariant5", discriminator) + v5, err := dst.ValueByDiscriminator() + require.NoError(t, err) + assert.Equal(t, OneOfVariant5{Discriminator: "OneOfVariant5", Id: 654}, v5) + + // discriminator value will be filled by the generated code + err = dst.FromOneOfVariant4(OneOfVariant4{Name: "456"}) + require.NoError(t, err) + marshaled, err := json.Marshal(dst) + require.NoError(t, err) + assertJsonEqual(t, []byte(variant4), marshaled) + + err = dst.FromOneOfVariant5(OneOfVariant5{Id: 654}) + require.NoError(t, err) + marshaled, err = json.Marshal(dst) + require.NoError(t, err) + assertJsonEqual(t, []byte(variant5), marshaled) +} + func TestOneOfWithFixedProperties(t *testing.T) { const variant1 = "{\"type\": \"v1\", \"name\": \"123\"}" const variant6 = "{\"type\": \"v6\", \"values\": [1, 2, 3]}" diff --git a/pkg/codegen/schema.go b/pkg/codegen/schema.go index 18430240db..bcb27d5fb7 100644 --- a/pkg/codegen/schema.go +++ b/pkg/codegen/schema.go @@ -1,6 +1,7 @@ package codegen import ( + "errors" "fmt" "strings" @@ -759,15 +760,30 @@ func generateUnion(outSchema *Schema, elements openapi3.SchemaRefs, discriminato } if discriminator != nil { + if len(discriminator.Mapping) != 0 && element.Ref == "" { + return errors.New("ambiguous discriminator.mapping: please replace inlined object with $ref") + } + + // Explicit mapping. + var mapped bool for k, v := range discriminator.Mapping { if v == element.Ref { outSchema.Discriminator.Mapping[k] = elementSchema.GoType + mapped = true break } } + // Implicit mapping. + if !mapped { + outSchema.Discriminator.Mapping[RefPathToObjName(element.Ref)] = elementSchema.GoType + } } outSchema.UnionElements = append(outSchema.UnionElements, UnionElement(elementSchema.GoType)) } + if (outSchema.Discriminator != nil) && len(outSchema.Discriminator.Mapping) != len(elements) { + return errors.New("discriminator: not all schemas were mapped") + } + return nil } diff --git a/pkg/codegen/utils.go b/pkg/codegen/utils.go index 7526ee3bf1..b85d161945 100644 --- a/pkg/codegen/utils.go +++ b/pkg/codegen/utils.go @@ -291,6 +291,23 @@ func StringInArray(str string, array []string) bool { return false } +// RefPathToObjName returns the name of referenced object without changes. +// +// #/components/schemas/Foo -> Foo +// #/components/parameters/Bar -> Bar +// #/components/responses/baz_baz -> baz_baz +// document.json#/Foo -> Foo +// http://deepmap.com/schemas/document.json#/objObj -> objObj +// +// Does not check refPath correctness. +func RefPathToObjName(refPath string) string { + parts := strings.Split(refPath, "/") + if len(parts) > 0 { + return parts[len(parts)-1] + } + return "" +} + // RefPathToGoType takes a $ref value and converts it to a Go typename. // #/components/schemas/Foo -> Foo // #/components/parameters/Bar -> Bar diff --git a/pkg/codegen/utils_test.go b/pkg/codegen/utils_test.go index 518e287f5e..32708f0692 100644 --- a/pkg/codegen/utils_test.go +++ b/pkg/codegen/utils_test.go @@ -432,3 +432,17 @@ func TestSchemaNameToTypeName(t *testing.T) { assert.Equal(t, want, SchemaNameToTypeName(in)) } } + +func TestRefPathToObjName(t *testing.T) { + t.Parallel() + + for in, want := range map[string]string{ + "#/components/schemas/Foo": "Foo", + "#/components/parameters/Bar": "Bar", + "#/components/responses/baz_baz": "baz_baz", + "document.json#/Foo": "Foo", + "http://deepmap.com/schemas/document.json#/objObj": "objObj", + } { + assert.Equal(t, want, RefPathToObjName(in)) + } +}