diff --git a/configure_data_plane.go b/configure_data_plane.go index 697efdd3..54f08182 100644 --- a/configure_data_plane.go +++ b/configure_data_plane.go @@ -29,6 +29,9 @@ import ( "strings" "sync" "syscall" + "bytes" + "reflect" + "fmt" "github.com/getkin/kin-openapi/openapi2" "github.com/getkin/kin-openapi/openapi2conv" @@ -86,6 +89,156 @@ func SetServerStartedCallback(callFunc func()) { serverStartedCallback = callFunc } +func strictUnknownFieldsCheck(body []byte, target any) error { + if len(bytes.TrimSpace(body)) == 0 { + return nil + } + + // Raw/plain config endpoint decodes body into string. + // Do not apply JSON strict check to string targets. + if _, ok := target.(*string); ok { + return nil + } + + var raw any + + dec := json.NewDecoder(bytes.NewReader(body)) + dec.UseNumber() + + if err := dec.Decode(&raw); err != nil { + return err + } + + return checkUnknownFields(raw, reflect.TypeOf(target), "") +} + +func checkUnknownFields(raw any, targetType reflect.Type, path string) error { + if targetType == nil { + return nil + } + + for targetType.Kind() == reflect.Pointer { + targetType = targetType.Elem() + } + + switch rawValue := raw.(type) { + case map[string]any: + switch targetType.Kind() { + case reflect.Struct: + allowed := jsonFieldNames(targetType) + + for key, value := range rawValue { + fieldType, ok := allowed[key] + if !ok { + if path == "" { + return fmt.Errorf("unknown field %q", key) + } + + return fmt.Errorf("unknown field %q at %s", key, path) + } + + childPath := key + if path != "" { + childPath = path + "." + key + } + + if err := checkUnknownFields(value, fieldType, childPath); err != nil { + return err + } + } + + case reflect.Map: + elemType := targetType.Elem() + + for key, value := range rawValue { + childPath := key + if path != "" { + childPath = path + "." + key + } + + if err := checkUnknownFields(value, elemType, childPath); err != nil { + return err + } + } + } + + case []any: + elemType := targetType + + for elemType.Kind() == reflect.Pointer { + elemType = elemType.Elem() + } + + if elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Array { + elemType = elemType.Elem() + } + + for index, item := range rawValue { + childPath := fmt.Sprintf("%s[%d]", path, index) + + if err := checkUnknownFields(item, elemType, childPath); err != nil { + return err + } + } + } + + return nil +} + +func jsonFieldNames(t reflect.Type) map[string]reflect.Type { + fields := make(map[string]reflect.Type) + + for t.Kind() == reflect.Pointer { + t = t.Elem() + } + + if t.Kind() != reflect.Struct { + return fields + } + + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + + // skip unexported non-embedded fields + if field.PkgPath != "" && !field.Anonymous { + continue + } + + fieldType := field.Type + for fieldType.Kind() == reflect.Pointer { + fieldType = fieldType.Elem() + } + + tag := field.Tag.Get("json") + name := strings.Split(tag, ",")[0] + + if tag == "-" { + continue + } + + // Embedded struct without explicit json name: + // type Resolver struct { ResolverBase } + // Need to flatten ResolverBase fields into parent object. + if field.Anonymous && (name == "" || name == field.Name) { + if fieldType.Kind() == reflect.Struct { + for embeddedName, embeddedType := range jsonFieldNames(fieldType) { + fields[embeddedName] = embeddedType + } + continue + } + } + + if name == "" { + name = field.Name + } + + fields[name] = field.Type + } + + return fields +} + + func configureFlags(api *operations.DataPlaneAPI) { cfg := dataplaneapi_config.Get() @@ -169,9 +322,20 @@ func configureAPI(api *operations.DataPlaneAPI) http.Handler { //nolint:cyclop,m api.Logger = log.Printf api.JSONConsumer = runtime.ConsumerFunc(func(reader io.Reader, data any) error { + body, err := io.ReadAll(reader) + if err != nil { + return err + } + + if err := strictUnknownFieldsCheck(body, data); err != nil { + log.Warningf("STRICT JSON UNKNOWN FIELD ERROR: %v", err) + return err + } + json := jsoniter.ConfigCompatibleWithStandardLibrary - dec := json.NewDecoder(reader) + dec := json.NewDecoder(bytes.NewReader(body)) dec.UseNumber() // preserve number formats + return dec.Decode(data) })