Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 165 additions & 1 deletion configure_data_plane.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ import (
"strings"
"sync"
"syscall"
"bytes"
"reflect"
"fmt"

"github.com/getkin/kin-openapi/openapi2"
"github.com/getkin/kin-openapi/openapi2conv"
Expand Down Expand Up @@ -86,6 +89,156 @@ func SetServerStartedCallback(callFunc func()) {
serverStartedCallback = callFunc
}

func strictUnknownFieldsCheck(body []byte, target any) error {
if len(bytes.TrimSpace(body)) == 0 {
return nil
}

// Raw/plain config endpoint decodes body into string.
// Do not apply JSON strict check to string targets.
if _, ok := target.(*string); ok {
return nil
}

var raw any

dec := json.NewDecoder(bytes.NewReader(body))
dec.UseNumber()

if err := dec.Decode(&raw); err != nil {
return err
}

return checkUnknownFields(raw, reflect.TypeOf(target), "")
}

func checkUnknownFields(raw any, targetType reflect.Type, path string) error {
if targetType == nil {
return nil
}

for targetType.Kind() == reflect.Pointer {
targetType = targetType.Elem()
}

switch rawValue := raw.(type) {
case map[string]any:
switch targetType.Kind() {
case reflect.Struct:
allowed := jsonFieldNames(targetType)

for key, value := range rawValue {
fieldType, ok := allowed[key]
if !ok {
if path == "" {
return fmt.Errorf("unknown field %q", key)
}

return fmt.Errorf("unknown field %q at %s", key, path)
}

childPath := key
if path != "" {
childPath = path + "." + key
}

if err := checkUnknownFields(value, fieldType, childPath); err != nil {
return err
}
}

case reflect.Map:
elemType := targetType.Elem()

for key, value := range rawValue {
childPath := key
if path != "" {
childPath = path + "." + key
}

if err := checkUnknownFields(value, elemType, childPath); err != nil {
return err
}
}
}

case []any:
elemType := targetType

for elemType.Kind() == reflect.Pointer {
elemType = elemType.Elem()
}

if elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Array {
elemType = elemType.Elem()
}

for index, item := range rawValue {
childPath := fmt.Sprintf("%s[%d]", path, index)

if err := checkUnknownFields(item, elemType, childPath); err != nil {
return err
}
}
}

return nil
}

func jsonFieldNames(t reflect.Type) map[string]reflect.Type {
fields := make(map[string]reflect.Type)

for t.Kind() == reflect.Pointer {
t = t.Elem()
}

if t.Kind() != reflect.Struct {
return fields
}

for i := 0; i < t.NumField(); i++ {
field := t.Field(i)

// skip unexported non-embedded fields
if field.PkgPath != "" && !field.Anonymous {
continue
}

fieldType := field.Type
for fieldType.Kind() == reflect.Pointer {
fieldType = fieldType.Elem()
}

tag := field.Tag.Get("json")
name := strings.Split(tag, ",")[0]

if tag == "-" {
continue
}

// Embedded struct without explicit json name:
// type Resolver struct { ResolverBase }
// Need to flatten ResolverBase fields into parent object.
if field.Anonymous && (name == "" || name == field.Name) {
if fieldType.Kind() == reflect.Struct {
for embeddedName, embeddedType := range jsonFieldNames(fieldType) {
fields[embeddedName] = embeddedType
}
continue
}
}

if name == "" {
name = field.Name
}

fields[name] = field.Type
}

return fields
}


func configureFlags(api *operations.DataPlaneAPI) {
cfg := dataplaneapi_config.Get()

Expand Down Expand Up @@ -169,9 +322,20 @@ func configureAPI(api *operations.DataPlaneAPI) http.Handler { //nolint:cyclop,m
api.Logger = log.Printf

api.JSONConsumer = runtime.ConsumerFunc(func(reader io.Reader, data any) error {
body, err := io.ReadAll(reader)
if err != nil {
return err
}

if err := strictUnknownFieldsCheck(body, data); err != nil {
log.Warningf("STRICT JSON UNKNOWN FIELD ERROR: %v", err)
return err
}

json := jsoniter.ConfigCompatibleWithStandardLibrary
dec := json.NewDecoder(reader)
dec := json.NewDecoder(bytes.NewReader(body))
dec.UseNumber() // preserve number formats

return dec.Decode(data)
})

Expand Down