@@ -20,6 +20,7 @@ import (
2020 "reflect"
2121 "testing"
2222
23+ "github.com/google/cel-go/checker"
2324 "github.com/google/cel-go/common"
2425 "github.com/google/cel-go/common/ast"
2526 "github.com/google/cel-go/common/overloads"
@@ -83,6 +84,55 @@ func TestASTCopy(t *testing.T) {
8384 }
8485}
8586
87+ func TestASTJsonNames (t * testing.T ) {
88+ tests := []string {
89+ `google.expr.proto3.test.TestAllTypes{}` ,
90+ `google.expr.proto3.test.TestAllTypes{repeatedInt32: [1, 2]}` ,
91+ `google.expr.proto3.test.TestAllTypes{singleInt32: 2}.singleInt32 == 2` ,
92+ }
93+
94+ for _ , tst := range tests {
95+ checked := mustTypeCheck (t , tst , checker .JSONFieldNames (true ), types .JSONFieldNames (true ))
96+ copyChecked := ast .Copy (checked )
97+ if ! reflect .DeepEqual (copyChecked .Expr (), checked .Expr ()) {
98+ t .Errorf ("Copy() got expr %v, wanted %v" , copyChecked .Expr (), checked .Expr ())
99+ }
100+ if ! reflect .DeepEqual (copyChecked .SourceInfo (), checked .SourceInfo ()) {
101+ t .Errorf ("Copy() got source info %v, wanted %v" , copyChecked .SourceInfo (), checked .SourceInfo ())
102+ }
103+ copyParsed := ast .Copy (ast .NewAST (checked .Expr (), checked .SourceInfo ()))
104+ if ! reflect .DeepEqual (copyParsed .Expr (), checked .Expr ()) {
105+ t .Errorf ("Copy() got expr %v, wanted %v" , copyParsed .Expr (), checked .Expr ())
106+ }
107+ if ! reflect .DeepEqual (copyParsed .SourceInfo (), checked .SourceInfo ()) {
108+ t .Errorf ("Copy() got source info %v, wanted %v" , copyParsed .SourceInfo (), checked .SourceInfo ())
109+ }
110+ checkedPB , err := ast .ToProto (checked )
111+ if err != nil {
112+ t .Errorf ("ast.ToProto() failed: %v" , err )
113+ }
114+ copyCheckedPB , err := ast .ToProto (copyChecked )
115+ if err != nil {
116+ t .Errorf ("ast.ToProto() failed: %v" , err )
117+ }
118+ if ! proto .Equal (checkedPB , copyCheckedPB ) {
119+ t .Errorf ("Copy() produced different proto results, got %v, wanted %v" ,
120+ prototext .Format (checkedPB ), prototext .Format (copyCheckedPB ))
121+ }
122+ checkedRoundtrip , err := ast .ToAST (checkedPB )
123+ if err != nil {
124+ t .Errorf ("ast.ToAST() failed: %v" , err )
125+ }
126+ same := reflect .DeepEqual (checked .Expr (), checkedRoundtrip .Expr ()) &&
127+ reflect .DeepEqual (checked .ReferenceMap (), checkedRoundtrip .ReferenceMap ()) &&
128+ reflect .DeepEqual (checked .TypeMap (), checkedRoundtrip .TypeMap ()) &&
129+ reflect .DeepEqual (checked .SourceInfo ().MacroCalls (), checkedRoundtrip .SourceInfo ().MacroCalls ())
130+ if ! same {
131+ t .Errorf ("Roundtrip got %v, wanted %v" , checkedRoundtrip , checked )
132+ }
133+ }
134+ }
135+
86136func TestASTNilSafety (t * testing.T ) {
87137 ex , err := ast .ProtoToExpr (nil )
88138 if err != nil {
@@ -184,6 +234,9 @@ func TestSourceInfoNilSafety(t *testing.T) {
184234 if len (testInfo .MacroCalls ()) != 0 {
185235 t .Errorf ("MacroCalls() got %v, wanted empty map" , testInfo .MacroCalls ())
186236 }
237+ if len (testInfo .Extensions ()) != 0 {
238+ t .Errorf ("Extensions() got %v, wanted empty list" , testInfo .Extensions ())
239+ }
187240 if call , found := testInfo .GetMacroCall (0 ); found {
188241 t .Errorf ("GetMacroCall(0) got %v, wanted not found" , call )
189242 }
0 commit comments