diff --git a/internal/test/issues/issue-2232/config.yaml b/internal/test/issues/issue-2232/config.yaml new file mode 100644 index 0000000000..6368de1d4d --- /dev/null +++ b/internal/test/issues/issue-2232/config.yaml @@ -0,0 +1,5 @@ +package: issue2232 +output: issue2232.gen.go +generate: + std-http-server: true + models: true diff --git a/internal/test/issues/issue-2232/generate.go b/internal/test/issues/issue-2232/generate.go new file mode 100644 index 0000000000..f6767c10e1 --- /dev/null +++ b/internal/test/issues/issue-2232/generate.go @@ -0,0 +1,3 @@ +package issue2232 + +//go:generate go run github.com/oapi-codegen/oapi-codegen/v2/cmd/oapi-codegen --config=config.yaml spec.yaml diff --git a/internal/test/issues/issue-2232/issue2232.gen.go b/internal/test/issues/issue-2232/issue2232.gen.go new file mode 100644 index 0000000000..e795e7ad72 --- /dev/null +++ b/internal/test/issues/issue-2232/issue2232.gen.go @@ -0,0 +1,260 @@ +//go:build go1.22 + +// Package issue2232 provides primitives to interact with the openapi HTTP API. +// +// Code generated by github.com/oapi-codegen/oapi-codegen/v2 version v2.0.0-00010101000000-000000000000 DO NOT EDIT. +package issue2232 + +import ( + "fmt" + "net/http" + + "github.com/oapi-codegen/runtime" +) + +// Defines values for GetEndpointParamsEnvParamLevel. +const ( + GetEndpointParamsEnvParamLevelDev GetEndpointParamsEnvParamLevel = "dev" + GetEndpointParamsEnvParamLevelLive GetEndpointParamsEnvParamLevel = "live" +) + +// Valid indicates whether the value is a known member of the GetEndpointParamsEnvParamLevel enum. +func (e GetEndpointParamsEnvParamLevel) Valid() bool { + switch e { + case GetEndpointParamsEnvParamLevelDev: + return true + case GetEndpointParamsEnvParamLevelLive: + return true + default: + return false + } +} + +// Defines values for GetEndpointParamsEnvSchemaLevel. +const ( + GetEndpointParamsEnvSchemaLevelDev GetEndpointParamsEnvSchemaLevel = "dev" + GetEndpointParamsEnvSchemaLevelLive GetEndpointParamsEnvSchemaLevel = "live" +) + +// Valid indicates whether the value is a known member of the GetEndpointParamsEnvSchemaLevel enum. +func (e GetEndpointParamsEnvSchemaLevel) Valid() bool { + switch e { + case GetEndpointParamsEnvSchemaLevelDev: + return true + case GetEndpointParamsEnvSchemaLevelLive: + return true + default: + return false + } +} + +// GetEndpointParams defines parameters for GetEndpoint. +type GetEndpointParams struct { + EnvParamLevel GetEndpointParamsEnvParamLevel `form:"env_param_level" json:"env_param_level" validate:"required,oneof=dev live"` + EnvSchemaLevel GetEndpointParamsEnvSchemaLevel `form:"env_schema_level" json:"env_schema_level" validate:"required,oneof=dev live"` + Limit *int `form:"limit,omitempty" json:"limit,omitempty" validate:"min=0,max=100"` +} + +// GetEndpointParamsEnvParamLevel defines parameters for GetEndpoint. +type GetEndpointParamsEnvParamLevel string + +// GetEndpointParamsEnvSchemaLevel defines parameters for GetEndpoint. +type GetEndpointParamsEnvSchemaLevel string + +// ServerInterface represents all server handlers. +type ServerInterface interface { + + // (GET /v1/endpoint) + GetEndpoint(w http.ResponseWriter, r *http.Request, params GetEndpointParams) +} + +// ServerInterfaceWrapper converts contexts to parameters. +type ServerInterfaceWrapper struct { + Handler ServerInterface + HandlerMiddlewares []MiddlewareFunc + ErrorHandlerFunc func(w http.ResponseWriter, r *http.Request, err error) +} + +type MiddlewareFunc func(http.Handler) http.Handler + +// GetEndpoint operation middleware +func (siw *ServerInterfaceWrapper) GetEndpoint(w http.ResponseWriter, r *http.Request) { + + var err error + + // Parameter object where we will unmarshal all parameters from the context + var params GetEndpointParams + + // ------------- Required query parameter "env_param_level" ------------- + + if paramValue := r.URL.Query().Get("env_param_level"); paramValue != "" { + + } else { + siw.ErrorHandlerFunc(w, r, &RequiredParamError{ParamName: "env_param_level"}) + return + } + + err = runtime.BindQueryParameter("form", true, true, "env_param_level", r.URL.Query(), ¶ms.EnvParamLevel) + if err != nil { + siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "env_param_level", Err: err}) + return + } + + // ------------- Required query parameter "env_schema_level" ------------- + + if paramValue := r.URL.Query().Get("env_schema_level"); paramValue != "" { + + } else { + siw.ErrorHandlerFunc(w, r, &RequiredParamError{ParamName: "env_schema_level"}) + return + } + + err = runtime.BindQueryParameter("form", true, true, "env_schema_level", r.URL.Query(), ¶ms.EnvSchemaLevel) + if err != nil { + siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "env_schema_level", Err: err}) + return + } + + // ------------- Optional query parameter "limit" ------------- + + err = runtime.BindQueryParameter("form", true, false, "limit", r.URL.Query(), ¶ms.Limit) + if err != nil { + siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "limit", Err: err}) + return + } + + handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + siw.Handler.GetEndpoint(w, r, params) + })) + + for _, middleware := range siw.HandlerMiddlewares { + handler = middleware(handler) + } + + handler.ServeHTTP(w, r) +} + +type UnescapedCookieParamError struct { + ParamName string + Err error +} + +func (e *UnescapedCookieParamError) Error() string { + return fmt.Sprintf("error unescaping cookie parameter '%s'", e.ParamName) +} + +func (e *UnescapedCookieParamError) Unwrap() error { + return e.Err +} + +type UnmarshalingParamError struct { + ParamName string + Err error +} + +func (e *UnmarshalingParamError) Error() string { + return fmt.Sprintf("Error unmarshaling parameter %s as JSON: %s", e.ParamName, e.Err.Error()) +} + +func (e *UnmarshalingParamError) Unwrap() error { + return e.Err +} + +type RequiredParamError struct { + ParamName string +} + +func (e *RequiredParamError) Error() string { + return fmt.Sprintf("Query argument %s is required, but not found", e.ParamName) +} + +type RequiredHeaderError struct { + ParamName string + Err error +} + +func (e *RequiredHeaderError) Error() string { + return fmt.Sprintf("Header parameter %s is required, but not found", e.ParamName) +} + +func (e *RequiredHeaderError) Unwrap() error { + return e.Err +} + +type InvalidParamFormatError struct { + ParamName string + Err error +} + +func (e *InvalidParamFormatError) Error() string { + return fmt.Sprintf("Invalid format for parameter %s: %s", e.ParamName, e.Err.Error()) +} + +func (e *InvalidParamFormatError) Unwrap() error { + return e.Err +} + +type TooManyValuesForParamError struct { + ParamName string + Count int +} + +func (e *TooManyValuesForParamError) Error() string { + return fmt.Sprintf("Expected one value for %s, got %d", e.ParamName, e.Count) +} + +// Handler creates http.Handler with routing matching OpenAPI spec. +func Handler(si ServerInterface) http.Handler { + return HandlerWithOptions(si, StdHTTPServerOptions{}) +} + +// ServeMux is an abstraction of http.ServeMux. +type ServeMux interface { + HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) + ServeHTTP(w http.ResponseWriter, r *http.Request) +} + +type StdHTTPServerOptions struct { + BaseURL string + BaseRouter ServeMux + Middlewares []MiddlewareFunc + ErrorHandlerFunc func(w http.ResponseWriter, r *http.Request, err error) +} + +// HandlerFromMux creates http.Handler with routing matching OpenAPI spec based on the provided mux. +func HandlerFromMux(si ServerInterface, m ServeMux) http.Handler { + return HandlerWithOptions(si, StdHTTPServerOptions{ + BaseRouter: m, + }) +} + +func HandlerFromMuxWithBaseURL(si ServerInterface, m ServeMux, baseURL string) http.Handler { + return HandlerWithOptions(si, StdHTTPServerOptions{ + BaseURL: baseURL, + BaseRouter: m, + }) +} + +// HandlerWithOptions creates http.Handler with additional options +func HandlerWithOptions(si ServerInterface, options StdHTTPServerOptions) http.Handler { + m := options.BaseRouter + + if m == nil { + m = http.NewServeMux() + } + if options.ErrorHandlerFunc == nil { + options.ErrorHandlerFunc = func(w http.ResponseWriter, r *http.Request, err error) { + http.Error(w, err.Error(), http.StatusBadRequest) + } + } + + wrapper := ServerInterfaceWrapper{ + Handler: si, + HandlerMiddlewares: options.Middlewares, + ErrorHandlerFunc: options.ErrorHandlerFunc, + } + + m.HandleFunc("GET "+options.BaseURL+"/v1/endpoint", wrapper.GetEndpoint) + + return m +} diff --git a/internal/test/issues/issue-2232/issue2232_test.go b/internal/test/issues/issue-2232/issue2232_test.go new file mode 100644 index 0000000000..a8ae51836e --- /dev/null +++ b/internal/test/issues/issue-2232/issue2232_test.go @@ -0,0 +1,41 @@ +package issue2232 + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestExtraTagsOnQueryParams verifies that x-oapi-codegen-extra-tags is applied +// to query parameter struct fields regardless of whether the extension is placed +// at the parameter level or at the schema level within the parameter. +// This is a regression test for https://github.com/oapi-codegen/oapi-codegen/issues/2232 +func TestExtraTagsOnQueryParams(t *testing.T) { + paramType := reflect.TypeOf(GetEndpointParams{}) + + t.Run("parameter-level extension", func(t *testing.T) { + field, ok := paramType.FieldByName("EnvParamLevel") + require.True(t, ok, "field EnvParamLevel should exist") + + assert.Equal(t, `required,oneof=dev live`, field.Tag.Get("validate"), + "x-oapi-codegen-extra-tags at parameter level should produce validate tag") + }) + + t.Run("schema-level extension", func(t *testing.T) { + field, ok := paramType.FieldByName("EnvSchemaLevel") + require.True(t, ok, "field EnvSchemaLevel should exist") + + assert.Equal(t, `required,oneof=dev live`, field.Tag.Get("validate"), + "x-oapi-codegen-extra-tags at schema level within a parameter should produce validate tag") + }) + + t.Run("schema-level extension on optional param", func(t *testing.T) { + field, ok := paramType.FieldByName("Limit") + require.True(t, ok, "field Limit should exist") + + assert.Equal(t, `min=0,max=100`, field.Tag.Get("validate"), + "x-oapi-codegen-extra-tags at schema level within an optional parameter should produce validate tag") + }) +} diff --git a/internal/test/issues/issue-2232/spec.yaml b/internal/test/issues/issue-2232/spec.yaml new file mode 100644 index 0000000000..0fea5cba83 --- /dev/null +++ b/internal/test/issues/issue-2232/spec.yaml @@ -0,0 +1,46 @@ +openapi: "3.0.3" +info: + title: test + version: 1.0.0 +paths: + /v1/endpoint: + get: + operationId: GetEndpoint + parameters: + - name: env_param_level + in: query + required: true + schema: + type: string + enum: + - dev + - live + x-oapi-codegen-extra-tags: + validate: "required,oneof=dev live" + - name: env_schema_level + in: query + required: true + schema: + type: string + enum: + - dev + - live + x-oapi-codegen-extra-tags: + validate: "required,oneof=dev live" + - name: limit + in: query + required: false + schema: + type: integer + x-oapi-codegen-extra-tags: + validate: "min=0,max=100" + responses: + "200": + description: Success + content: + application/json: + schema: + type: object + properties: + message: + type: string diff --git a/pkg/codegen/operations.go b/pkg/codegen/operations.go index 1743d270ea..b928882499 100644 --- a/pkg/codegen/operations.go +++ b/pkg/codegen/operations.go @@ -925,13 +925,24 @@ func GenerateParamsTypes(op OperationDefinition) []TypeDefinition { Schema: param.Schema, }) } + // Merge extensions from the schema level and the parameter level. + // Parameter-level extensions take precedence over schema-level ones. + extensions := make(map[string]any) + if param.Spec.Schema != nil && param.Spec.Schema.Value != nil { + for k, v := range param.Spec.Schema.Value.Extensions { + extensions[k] = v + } + } + for k, v := range param.Spec.Extensions { + extensions[k] = v + } prop := Property{ Description: param.Spec.Description, JsonFieldName: param.ParamName, Required: param.Required, Schema: pSchema, NeedsFormTag: param.Style() == "form", - Extensions: param.Spec.Extensions, + Extensions: extensions, } s.Properties = append(s.Properties, prop) }