From 71958afa46343ea2343563dd11e3bed038e39ea7 Mon Sep 17 00:00:00 2001 From: Toby Brain Date: Tue, 2 Sep 2025 22:27:47 +1000 Subject: [PATCH] Support unions with multiple mappings pointing to a single underlying type --- internal/test/issues/issue-1530/config.yaml | 4 + internal/test/issues/issue-1530/doc.go | 3 + .../test/issues/issue-1530/issue1530.gen.go | 126 ++++++++++++++++++ .../test/issues/issue-1530/issue1530.yaml | 57 ++++++++ .../test/issues/issue-1530/issue1530_test.go | 51 +++++++ pkg/codegen/schema.go | 3 +- pkg/codegen/templates/union.tmpl | 37 ++--- 7 files changed, 263 insertions(+), 18 deletions(-) create mode 100644 internal/test/issues/issue-1530/config.yaml create mode 100644 internal/test/issues/issue-1530/doc.go create mode 100644 internal/test/issues/issue-1530/issue1530.gen.go create mode 100644 internal/test/issues/issue-1530/issue1530.yaml create mode 100644 internal/test/issues/issue-1530/issue1530_test.go diff --git a/internal/test/issues/issue-1530/config.yaml b/internal/test/issues/issue-1530/config.yaml new file mode 100644 index 0000000000..3cb9f3dd04 --- /dev/null +++ b/internal/test/issues/issue-1530/config.yaml @@ -0,0 +1,4 @@ +package: issue1530 +generate: + models: true +output: issue1530.gen.go \ No newline at end of file diff --git a/internal/test/issues/issue-1530/doc.go b/internal/test/issues/issue-1530/doc.go new file mode 100644 index 0000000000..c6a0133735 --- /dev/null +++ b/internal/test/issues/issue-1530/doc.go @@ -0,0 +1,3 @@ +package issue1530 + +//go:generate go run github.com/oapi-codegen/oapi-codegen/v2/cmd/oapi-codegen --config=config.yaml issue1530.yaml diff --git a/internal/test/issues/issue-1530/issue1530.gen.go b/internal/test/issues/issue-1530/issue1530.gen.go new file mode 100644 index 0000000000..8a60251b2c --- /dev/null +++ b/internal/test/issues/issue-1530/issue1530.gen.go @@ -0,0 +1,126 @@ +// Package issue1530 provides primitives to interact with the openapi HTTP API. +// +// Code generated by github.com/oapi-codegen/oapi-codegen/v2 version v2.0.0-00010101000000-000000000000 DO NOT EDIT. +package issue1530 + +import ( + "encoding/json" + "errors" + + "github.com/oapi-codegen/runtime" +) + +// ConfigHttp defines model for ConfigHttp. +type ConfigHttp struct { + ConfigType string `json:"config_type"` + Host string `json:"host"` + Password *string `json:"password,omitempty"` + Port int `json:"port"` + User *string `json:"user,omitempty"` +} + +// ConfigSaveReq defines model for ConfigSaveReq. +type ConfigSaveReq struct { + union json.RawMessage +} + +// ConfigSsh defines model for ConfigSsh. +type ConfigSsh struct { + ConfigType string `json:"config_type"` + Host *string `json:"host,omitempty"` + Port *int `json:"port,omitempty"` + PrivateKey *string `json:"private_key,omitempty"` + User *string `json:"user,omitempty"` +} + +// PostConfigJSONRequestBody defines body for PostConfig for application/json ContentType. +type PostConfigJSONRequestBody = ConfigSaveReq + +// AsConfigHttp returns the union data inside the ConfigSaveReq as a ConfigHttp +func (t ConfigSaveReq) AsConfigHttp() (ConfigHttp, error) { + var body ConfigHttp + err := json.Unmarshal(t.union, &body) + return body, err +} + +// FromConfigHttp overwrites any union data inside the ConfigSaveReq as the provided ConfigHttp +func (t *ConfigSaveReq) FromConfigHttp(v ConfigHttp) error { + b, err := json.Marshal(v) + t.union = b + return err +} + +// MergeConfigHttp performs a merge with any union data inside the ConfigSaveReq, using the provided ConfigHttp +func (t *ConfigSaveReq) MergeConfigHttp(v ConfigHttp) error { + b, err := json.Marshal(v) + if err != nil { + return err + } + + merged, err := runtime.JSONMerge(t.union, b) + t.union = merged + return err +} + +// AsConfigSsh returns the union data inside the ConfigSaveReq as a ConfigSsh +func (t ConfigSaveReq) AsConfigSsh() (ConfigSsh, error) { + var body ConfigSsh + err := json.Unmarshal(t.union, &body) + return body, err +} + +// FromConfigSsh overwrites any union data inside the ConfigSaveReq as the provided ConfigSsh +func (t *ConfigSaveReq) FromConfigSsh(v ConfigSsh) error { + b, err := json.Marshal(v) + t.union = b + return err +} + +// MergeConfigSsh performs a merge with any union data inside the ConfigSaveReq, using the provided ConfigSsh +func (t *ConfigSaveReq) MergeConfigSsh(v ConfigSsh) error { + b, err := json.Marshal(v) + if err != nil { + return err + } + + merged, err := runtime.JSONMerge(t.union, b) + t.union = merged + return err +} + +func (t ConfigSaveReq) Discriminator() (string, error) { + var discriminator struct { + Discriminator string `json:"config_type"` + } + err := json.Unmarshal(t.union, &discriminator) + return discriminator.Discriminator, err +} + +func (t ConfigSaveReq) ValueByDiscriminator() (interface{}, error) { + discriminator, err := t.Discriminator() + if err != nil { + return nil, err + } + switch discriminator { + case "another_server": + return t.AsConfigHttp() + case "apache_server": + return t.AsConfigHttp() + case "ssh_server": + return t.AsConfigSsh() + case "web_server": + return t.AsConfigHttp() + default: + return nil, errors.New("unknown discriminator value: " + discriminator) + } +} + +func (t ConfigSaveReq) MarshalJSON() ([]byte, error) { + b, err := t.union.MarshalJSON() + return b, err +} + +func (t *ConfigSaveReq) UnmarshalJSON(b []byte) error { + err := t.union.UnmarshalJSON(b) + return err +} diff --git a/internal/test/issues/issue-1530/issue1530.yaml b/internal/test/issues/issue-1530/issue1530.yaml new file mode 100644 index 0000000000..222d074b7b --- /dev/null +++ b/internal/test/issues/issue-1530/issue1530.yaml @@ -0,0 +1,57 @@ +paths: + /config: + post: + summary: Save configuration + requestBody: + content: + application/json: + schema: + $ref: "#/components/schemas/ConfigSaveReq" + responses: + "200": + description: Configuration saved successfully +components: + schemas: + ConfigHttp: + type: object + properties: + config_type: + type: string + host: + type: string + port: + type: integer + user: + type: string + password: + type: string + required: + - config_type + - host + - port + ConfigSaveReq: + oneOf: + - $ref: "#/components/schemas/ConfigHttp" + - $ref: "#/components/schemas/ConfigSsh" + discriminator: + propertyName: config_type + mapping: + ssh_server: "#/components/schemas/ConfigSsh" + apache_server: "#/components/schemas/ConfigHttp" + web_server: "#/components/schemas/ConfigHttp" + another_server: "#/components/schemas/ConfigHttp" + ConfigSsh: + type: object + properties: + config_type: + type: string + host: + type: string + port: + type: integer + user: + type: string + private_key: + type: string + required: + - config_type \ No newline at end of file diff --git a/internal/test/issues/issue-1530/issue1530_test.go b/internal/test/issues/issue-1530/issue1530_test.go new file mode 100644 index 0000000000..a23f7ed2d8 --- /dev/null +++ b/internal/test/issues/issue-1530/issue1530_test.go @@ -0,0 +1,51 @@ +package issue1530_test + +import ( + "testing" + + issue1530 "github.com/oapi-codegen/oapi-codegen/v2/internal/test/issues/issue-1530" + "github.com/stretchr/testify/require" +) + +func TestIssue1530(t *testing.T) { + httpConfigTypes := []string{ + "another_server", + "apache_server", + "web_server", + } + + for _, configType := range httpConfigTypes { + t.Run("http-"+configType, func(t *testing.T) { + saveReq := issue1530.ConfigSaveReq{} + err := saveReq.FromConfigHttp(issue1530.ConfigHttp{ + ConfigType: configType, + Host: "example.com", + }) + require.NoError(t, err) + + cfg, err := saveReq.AsConfigHttp() + require.NoError(t, err) + require.Equal(t, configType, cfg.ConfigType) + + cfgByDiscriminator, err := saveReq.ValueByDiscriminator() + require.NoError(t, err) + require.Equal(t, cfg, cfgByDiscriminator) + }) + } + + t.Run("ssh", func(t *testing.T) { + saveReq := issue1530.ConfigSaveReq{} + err := saveReq.FromConfigSsh(issue1530.ConfigSsh{ + ConfigType: "ssh_server", + }) + require.NoError(t, err) + + cfg, err := saveReq.AsConfigSsh() + require.NoError(t, err) + require.Equal(t, "ssh_server", cfg.ConfigType) + + cfgByDiscriminator, err := saveReq.ValueByDiscriminator() + require.NoError(t, err) + require.Equal(t, cfg, cfgByDiscriminator) + }) +} diff --git a/pkg/codegen/schema.go b/pkg/codegen/schema.go index 1d03fbdd0e..6c1e517b23 100644 --- a/pkg/codegen/schema.go +++ b/pkg/codegen/schema.go @@ -900,7 +900,6 @@ func generateUnion(outSchema *Schema, elements openapi3.SchemaRefs, discriminato if v == element.Ref { outSchema.Discriminator.Mapping[k] = elementSchema.GoType mapped = true - break } } // Implicit mapping. @@ -911,7 +910,7 @@ func generateUnion(outSchema *Schema, elements openapi3.SchemaRefs, discriminato outSchema.UnionElements = append(outSchema.UnionElements, UnionElement(elementSchema.GoType)) } - if (outSchema.Discriminator != nil) && len(outSchema.Discriminator.Mapping) != len(elements) { + if (outSchema.Discriminator != nil) && len(outSchema.Discriminator.Mapping) < len(elements) { return errors.New("discriminator: not all schemas were mapped") } diff --git a/pkg/codegen/templates/union.tmpl b/pkg/codegen/templates/union.tmpl index c0385f82e3..272b9c9730 100644 --- a/pkg/codegen/templates/union.tmpl +++ b/pkg/codegen/templates/union.tmpl @@ -2,6 +2,7 @@ {{$typeName := .TypeName -}} {{$discriminator := .Schema.Discriminator}} {{$properties := .Schema.Properties -}} + {{$numberOfUnionTypes := len .Schema.UnionElements -}} {{range .Schema.UnionElements}} {{$element := . -}} // As{{ .Method }} returns the union data inside the {{$typeName}} as a {{.}} @@ -14,16 +15,18 @@ // From{{ .Method }} overwrites any union data inside the {{$typeName}} as the provided {{.}} func (t *{{$typeName}}) From{{ .Method }} (v {{.}}) error { {{if $discriminator -}} - {{range $value, $type := $discriminator.Mapping -}} - {{if eq $type $element -}} - {{$hasProperty := false -}} - {{range $properties -}} - {{if eq .GoFieldName $discriminator.PropertyName -}} - t.{{$discriminator.PropertyName}} = "{{$value}}" - {{$hasProperty = true -}} + {{if eq $numberOfUnionTypes (len $discriminator.Mapping) -}} + {{range $value, $type := $discriminator.Mapping -}} + {{if eq $type $element -}} + {{$hasProperty := false -}} + {{range $properties -}} + {{if eq .GoFieldName $discriminator.PropertyName -}} + t.{{$discriminator.PropertyName}} = "{{$value}}" + {{$hasProperty = true -}} + {{end -}} {{end -}} + {{if not $hasProperty}}v.{{$discriminator.PropertyName}} = "{{$value}}"{{end}} {{end -}} - {{if not $hasProperty}}v.{{$discriminator.PropertyName}} = "{{$value}}"{{end}} {{end -}} {{end -}} {{end -}} @@ -35,16 +38,18 @@ // Merge{{ .Method }} performs a merge with any union data inside the {{$typeName}}, using the provided {{.}} func (t *{{$typeName}}) Merge{{ .Method }} (v {{.}}) error { {{if $discriminator -}} - {{range $value, $type := $discriminator.Mapping -}} - {{if eq $type $element -}} - {{$hasProperty := false -}} - {{range $properties -}} - {{if eq .GoFieldName $discriminator.PropertyName -}} - t.{{$discriminator.PropertyName}} = "{{$value}}" - {{$hasProperty = true -}} + {{if eq $numberOfUnionTypes (len $discriminator.Mapping) -}} + {{range $value, $type := $discriminator.Mapping -}} + {{if eq $type $element -}} + {{$hasProperty := false -}} + {{range $properties -}} + {{if eq .GoFieldName $discriminator.PropertyName -}} + t.{{$discriminator.PropertyName}} = "{{$value}}" + {{$hasProperty = true -}} + {{end -}} {{end -}} + {{if not $hasProperty}}v.{{$discriminator.PropertyName}} = "{{$value}}"{{end}} {{end -}} - {{if not $hasProperty}}v.{{$discriminator.PropertyName}} = "{{$value}}"{{end}} {{end -}} {{end -}} {{end -}}