Skip to content

Commit 3624b64

Browse files
Add checker, ast, and type-provider support for JSON names (#1283)
* Add checker, ast, and type-provider support for JSON names * Support backwards compatible JSON field / proto field resolution
1 parent 6ee79cf commit 3624b64

21 files changed

Lines changed: 809 additions & 166 deletions

cel/env.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1052,9 +1052,10 @@ func (p *interopCELTypeProvider) FindStructFieldType(structType, fieldName strin
10521052
return nil, false
10531053
}
10541054
return &types.FieldType{
1055-
Type: t,
1056-
IsSet: ft.IsSet,
1057-
GetFrom: ft.GetFrom,
1055+
Type: t,
1056+
IsSet: ft.IsSet,
1057+
GetFrom: ft.GetFrom,
1058+
IsJSONField: ft.IsJSONField,
10581059
}, true
10591060
}
10601061
return nil, false

checker/checker.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ func Check(parsed *ast.AST, source common.Source, env *Env) (*ast.AST, *common.E
7171
// check() deletes some nodes while rewriting the AST. For example the Select operand is
7272
// deleted when a variable reference is replaced with a Ident expression.
7373
c.AST.ClearUnusedIDs()
74+
if env.jsonFieldNames {
75+
c.AST.SourceInfo().AddExtension(
76+
ast.NewExtension("json_name", ast.NewExtensionVersion(1, 1), ast.ComponentRuntime),
77+
)
78+
}
7479
return c.AST, errs
7580
}
7681

@@ -718,6 +723,9 @@ func (c *checker) lookupFieldType(exprID int64, structType, fieldName string) (*
718723
}
719724

720725
if ft, found := c.env.provider.FindStructFieldType(structType, fieldName); found {
726+
if c.env.jsonFieldNames && !ft.IsJSONField {
727+
c.errors.undefinedField(exprID, c.locationByID(exprID), fieldName)
728+
}
721729
return ft.Type, found
722730
}
723731

checker/checker_test.go

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2413,6 +2413,42 @@ _&&_(_==_(list~type(list(dyn))^list,
24132413
@result~bool^@result)~bool`,
24142414
outType: types.BoolType,
24152415
},
2416+
{
2417+
in: `TestAllTypes{?singleInt32: {}.?i}`,
2418+
container: "google.expr.proto2.test",
2419+
env: testEnv{optionalSyntax: true, jsonFieldNames: true},
2420+
out: `google.expr.proto2.test.TestAllTypes{
2421+
?singleInt32:_?._(
2422+
{}~map(dyn, int),
2423+
"i"
2424+
)~optional_type(int)^select_optional_field
2425+
}~google.expr.proto2.test.TestAllTypes^google.expr.proto2.test.TestAllTypes`,
2426+
outType: types.NewObjectType(
2427+
"google.expr.proto2.test.TestAllTypes",
2428+
),
2429+
},
2430+
{
2431+
in: `TestAllTypes{?singleInt32: {'i': 20}.?i}.singleInt32`,
2432+
container: "google.expr.proto2.test",
2433+
env: testEnv{optionalSyntax: true, jsonFieldNames: true},
2434+
out: `google.expr.proto2.test.TestAllTypes{
2435+
?singleInt32:_?._(
2436+
{
2437+
"i"~string:20~int
2438+
}~map(string, int),
2439+
"i"
2440+
)~optional_type(int)^select_optional_field
2441+
}~google.expr.proto2.test.TestAllTypes^google.expr.proto2.test.TestAllTypes.singleInt32~int`,
2442+
outType: types.IntType,
2443+
},
2444+
{
2445+
in: `TestAllTypes{singleInt32: 1, single_bool: true}.singleInt32`,
2446+
container: "google.expr.proto2.test",
2447+
env: testEnv{optionalSyntax: true, jsonFieldNames: true},
2448+
err: `ERROR: <input>:1:41: undefined field 'single_bool'
2449+
| TestAllTypes{singleInt32: 1, single_bool: true}.singleInt32
2450+
| ........................................^`,
2451+
},
24162452
}
24172453
}
24182454

@@ -2470,6 +2506,7 @@ type testEnv struct {
24702506
functions []*decls.FunctionDecl
24712507
variadicASTs bool
24722508
optionalSyntax bool
2509+
jsonFieldNames bool
24732510
}
24742511

24752512
func TestCheck(t *testing.T) {
@@ -2505,9 +2542,12 @@ func TestCheck(t *testing.T) {
25052542
t.Fatalf("Unexpected parse errors: %v", errors.ToDisplayString())
25062543
}
25072544

2508-
reg, err := types.NewRegistry(&proto2pb.TestAllTypes{}, &proto3pb.TestAllTypes{})
2545+
reg, err := types.NewProtoRegistry(
2546+
types.JSONFieldNames(tc.env.jsonFieldNames),
2547+
types.ProtoTypes(&proto2pb.TestAllTypes{}, &proto3pb.TestAllTypes{}),
2548+
)
25092549
if err != nil {
2510-
t.Fatalf("types.NewRegistry() failed: %v", err)
2550+
t.Fatalf("types.NewProtoRegistry() failed: %v", err)
25112551
}
25122552
if tc.env.optionalSyntax {
25132553
if err := reg.RegisterType(types.OptionalType); err != nil {
@@ -2522,6 +2562,9 @@ func TestCheck(t *testing.T) {
25222562
if len(tc.opts) != 0 {
25232563
opts = tc.opts
25242564
}
2565+
if tc.env.jsonFieldNames {
2566+
opts = append(opts, JSONFieldNames(true))
2567+
}
25252568
env, err := NewEnv(cont, reg, opts...)
25262569
if err != nil {
25272570
t.Fatalf("NewEnv(cont, reg) failed: %v", err)

checker/env.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ type Env struct {
7474
declarations *Scopes
7575
aggLitElemType aggregateLiteralElementType
7676
filteredOverloadIDs map[string]struct{}
77+
jsonFieldNames bool
7778
}
7879

7980
// NewEnv returns a new *Env with the given parameters.
@@ -104,6 +105,7 @@ func NewEnv(container *containers.Container, provider types.Provider, opts ...Op
104105
declarations: declarations,
105106
aggLitElemType: aggLitElemType,
106107
filteredOverloadIDs: filteredOverloadIDs,
108+
jsonFieldNames: envOptions.jsonFieldNames,
107109
}, nil
108110
}
109111

checker/options.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ type options struct {
1818
crossTypeNumericComparisons bool
1919
homogeneousAggregateLiterals bool
2020
validatedDeclarations *Scopes
21+
jsonFieldNames bool
2122
}
2223

2324
// Option is a functional option for configuring the type-checker
@@ -40,3 +41,11 @@ func ValidatedDeclarations(env *Env) Option {
4041
return nil
4142
}
4243
}
44+
45+
// JSONFieldNames enables the use of json names instead of the standard protobuf snake_case field names
46+
func JSONFieldNames(enabled bool) Option {
47+
return func(opts *options) error {
48+
opts.jsonFieldNames = enabled
49+
return nil
50+
}
51+
}

common/ast/ast.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,11 @@ func CopySourceInfo(info *SourceInfo) *SourceInfo {
231231
for id, call := range info.macroCalls {
232232
callsCopy[id] = defaultFactory.CopyExpr(call)
233233
}
234+
var extCopy []Extension
235+
if len(info.extensions) > 0 {
236+
extCopy = make([]Extension, len(info.extensions))
237+
copy(extCopy, info.extensions)
238+
}
234239
return &SourceInfo{
235240
syntax: info.syntax,
236241
desc: info.desc,
@@ -239,6 +244,7 @@ func CopySourceInfo(info *SourceInfo) *SourceInfo {
239244
baseCol: info.baseCol,
240245
offsetRanges: rangesCopy,
241246
macroCalls: callsCopy,
247+
extensions: extCopy,
242248
}
243249
}
244250

@@ -252,6 +258,9 @@ type SourceInfo struct {
252258
baseCol int32
253259
offsetRanges map[int64]OffsetRange
254260
macroCalls map[int64]Expr
261+
262+
// extensions indicate versioned optional features which affect the execution of one or more CEL component.
263+
extensions []Extension
255264
}
256265

257266
// RenumberIDs performs an in-place update of the expression IDs within the SourceInfo.
@@ -420,6 +429,23 @@ func (s *SourceInfo) ComputeOffsetAbsolute(line, col int32) int32 {
420429
return offset + col
421430
}
422431

432+
// Extensions returns the set of extensions present in the source.
433+
func (s *SourceInfo) Extensions() []Extension {
434+
var extensions []Extension
435+
if s == nil {
436+
return extensions
437+
}
438+
return s.extensions
439+
}
440+
441+
// AddExtension adds an extension record into the SourceInfo.
442+
func (s *SourceInfo) AddExtension(ext Extension) {
443+
if s == nil {
444+
return
445+
}
446+
s.extensions = append(s.extensions, ext)
447+
}
448+
423449
// OffsetRange captures the start and stop positions of a section of text in the input expression.
424450
type OffsetRange struct {
425451
Start int32
@@ -489,6 +515,53 @@ func (r *ReferenceInfo) Equals(other *ReferenceInfo) bool {
489515
return true
490516
}
491517

518+
// NewExtension creates an Extension to be recorded on the SourceInfo.
519+
func NewExtension(id string, version ExtensionVersion, components ...ExtensionComponent) Extension {
520+
return Extension{
521+
ID: id,
522+
Version: version,
523+
Components: components,
524+
}
525+
}
526+
527+
// Extension represents a versioned, optional feature present in the AST that affects CEL component behavior.
528+
type Extension struct {
529+
// ID indicates the unique name of the extension.
530+
ID string
531+
// Version indicates the major / minor version.
532+
Version ExtensionVersion
533+
// Components enumerates the CEL components affected by the feature.
534+
Components []ExtensionComponent
535+
}
536+
537+
// NewExtensionVersion creates a new extension version with a major, minor version.
538+
func NewExtensionVersion(major, minor int64) ExtensionVersion {
539+
return ExtensionVersion{Major: major, Minor: minor}
540+
}
541+
542+
// ExtensionVersion represents a semantic version with a major and minor number.
543+
type ExtensionVersion struct {
544+
// Major version of the extension.
545+
// All versions with the same major number are expected to be compatible with all minor version changes.
546+
Major int64
547+
548+
// Minor version of the extension which indicates that some small non-semantic change has been made to
549+
// the extension.
550+
Minor int64
551+
}
552+
553+
// ExtensionComponent indicates which CEL component is affected.
554+
type ExtensionComponent int
555+
556+
const (
557+
// ComponentParser means the feature affects expression parsing.
558+
ComponentParser ExtensionComponent = iota + 1
559+
// ComponentTypeChecker means the feature affects type-checking.
560+
ComponentTypeChecker
561+
// ComponentRuntime alters program planning or evaluation of the AST.
562+
ComponentRuntime
563+
)
564+
492565
type maxIDVisitor struct {
493566
maxID int64
494567
*baseVisitor

common/ast/ast_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"reflect"
2121
"testing"
2222

23+
"github.com/google/cel-go/checker"
2324
"github.com/google/cel-go/common"
2425
"github.com/google/cel-go/common/ast"
2526
"github.com/google/cel-go/common/overloads"
@@ -83,6 +84,55 @@ func TestASTCopy(t *testing.T) {
8384
}
8485
}
8586

87+
func TestASTJsonNames(t *testing.T) {
88+
tests := []string{
89+
`google.expr.proto3.test.TestAllTypes{}`,
90+
`google.expr.proto3.test.TestAllTypes{repeatedInt32: [1, 2]}`,
91+
`google.expr.proto3.test.TestAllTypes{singleInt32: 2}.singleInt32 == 2`,
92+
}
93+
94+
for _, tst := range tests {
95+
checked := mustTypeCheck(t, tst, checker.JSONFieldNames(true), types.JSONFieldNames(true))
96+
copyChecked := ast.Copy(checked)
97+
if !reflect.DeepEqual(copyChecked.Expr(), checked.Expr()) {
98+
t.Errorf("Copy() got expr %v, wanted %v", copyChecked.Expr(), checked.Expr())
99+
}
100+
if !reflect.DeepEqual(copyChecked.SourceInfo(), checked.SourceInfo()) {
101+
t.Errorf("Copy() got source info %v, wanted %v", copyChecked.SourceInfo(), checked.SourceInfo())
102+
}
103+
copyParsed := ast.Copy(ast.NewAST(checked.Expr(), checked.SourceInfo()))
104+
if !reflect.DeepEqual(copyParsed.Expr(), checked.Expr()) {
105+
t.Errorf("Copy() got expr %v, wanted %v", copyParsed.Expr(), checked.Expr())
106+
}
107+
if !reflect.DeepEqual(copyParsed.SourceInfo(), checked.SourceInfo()) {
108+
t.Errorf("Copy() got source info %v, wanted %v", copyParsed.SourceInfo(), checked.SourceInfo())
109+
}
110+
checkedPB, err := ast.ToProto(checked)
111+
if err != nil {
112+
t.Errorf("ast.ToProto() failed: %v", err)
113+
}
114+
copyCheckedPB, err := ast.ToProto(copyChecked)
115+
if err != nil {
116+
t.Errorf("ast.ToProto() failed: %v", err)
117+
}
118+
if !proto.Equal(checkedPB, copyCheckedPB) {
119+
t.Errorf("Copy() produced different proto results, got %v, wanted %v",
120+
prototext.Format(checkedPB), prototext.Format(copyCheckedPB))
121+
}
122+
checkedRoundtrip, err := ast.ToAST(checkedPB)
123+
if err != nil {
124+
t.Errorf("ast.ToAST() failed: %v", err)
125+
}
126+
same := reflect.DeepEqual(checked.Expr(), checkedRoundtrip.Expr()) &&
127+
reflect.DeepEqual(checked.ReferenceMap(), checkedRoundtrip.ReferenceMap()) &&
128+
reflect.DeepEqual(checked.TypeMap(), checkedRoundtrip.TypeMap()) &&
129+
reflect.DeepEqual(checked.SourceInfo().MacroCalls(), checkedRoundtrip.SourceInfo().MacroCalls())
130+
if !same {
131+
t.Errorf("Roundtrip got %v, wanted %v", checkedRoundtrip, checked)
132+
}
133+
}
134+
}
135+
86136
func TestASTNilSafety(t *testing.T) {
87137
ex, err := ast.ProtoToExpr(nil)
88138
if err != nil {
@@ -184,6 +234,9 @@ func TestSourceInfoNilSafety(t *testing.T) {
184234
if len(testInfo.MacroCalls()) != 0 {
185235
t.Errorf("MacroCalls() got %v, wanted empty map", testInfo.MacroCalls())
186236
}
237+
if len(testInfo.Extensions()) != 0 {
238+
t.Errorf("Extensions() got %v, wanted empty list", testInfo.Extensions())
239+
}
187240
if call, found := testInfo.GetMacroCall(0); found {
188241
t.Errorf("GetMacroCall(0) got %v, wanted not found", call)
189242
}

common/ast/conversion.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,19 @@ import (
2727
structpb "google.golang.org/protobuf/types/known/structpb"
2828
)
2929

30+
var (
31+
pbComponentMap = map[exprpb.SourceInfo_Extension_Component]ExtensionComponent{
32+
exprpb.SourceInfo_Extension_COMPONENT_PARSER: ComponentParser,
33+
exprpb.SourceInfo_Extension_COMPONENT_TYPE_CHECKER: ComponentTypeChecker,
34+
exprpb.SourceInfo_Extension_COMPONENT_RUNTIME: ComponentRuntime,
35+
}
36+
componentPBMap = map[ExtensionComponent]exprpb.SourceInfo_Extension_Component{
37+
ComponentParser: exprpb.SourceInfo_Extension_COMPONENT_PARSER,
38+
ComponentTypeChecker: exprpb.SourceInfo_Extension_COMPONENT_TYPE_CHECKER,
39+
ComponentRuntime: exprpb.SourceInfo_Extension_COMPONENT_RUNTIME,
40+
}
41+
)
42+
3043
// ToProto converts an AST to a CheckedExpr protobouf.
3144
func ToProto(ast *AST) (*exprpb.CheckedExpr, error) {
3245
refMap := make(map[int64]*exprpb.Reference, len(ast.ReferenceMap()))
@@ -534,6 +547,25 @@ func SourceInfoToProto(info *SourceInfo) (*exprpb.SourceInfo, error) {
534547
}
535548
sourceInfo.MacroCalls[id] = call
536549
}
550+
for _, ext := range info.Extensions() {
551+
var components []exprpb.SourceInfo_Extension_Component
552+
for _, c := range ext.Components {
553+
comp, found := componentPBMap[c]
554+
if found {
555+
components = append(components, comp)
556+
}
557+
}
558+
ver := &exprpb.SourceInfo_Extension_Version{
559+
Major: ext.Version.Major,
560+
Minor: ext.Version.Minor,
561+
}
562+
pbExt := &exprpb.SourceInfo_Extension{
563+
Id: ext.ID,
564+
Version: ver,
565+
AffectedComponents: components,
566+
}
567+
sourceInfo.Extensions = append(sourceInfo.Extensions, pbExt)
568+
}
537569
return sourceInfo, nil
538570
}
539571

@@ -556,6 +588,23 @@ func ProtoToSourceInfo(info *exprpb.SourceInfo) (*SourceInfo, error) {
556588
}
557589
sourceInfo.SetMacroCall(id, call)
558590
}
591+
for _, pbExt := range info.GetExtensions() {
592+
var components []ExtensionComponent
593+
for _, c := range pbExt.GetAffectedComponents() {
594+
comp, found := pbComponentMap[*c.Enum()]
595+
if found {
596+
components = append(components, comp)
597+
}
598+
}
599+
sourceInfo.AddExtension(NewExtension(
600+
pbExt.GetId(),
601+
NewExtensionVersion(
602+
pbExt.GetVersion().GetMajor(),
603+
pbExt.GetVersion().GetMinor(),
604+
),
605+
components...,
606+
))
607+
}
559608
return sourceInfo, nil
560609
}
561610

0 commit comments

Comments
 (0)