From 84b5c5d60bab86c1d758df72a7a66fedf6d8bd97 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 5 Mar 2026 10:53:01 -0800 Subject: [PATCH 01/66] Use structured types in YAML export. Updates Environment export to render types in variable and function declarations as structured maps instead of a string of the formatted name. PiperOrigin-RevId: 879147753 --- .../cel/bundle/CelEnvironmentExporter.java | 9 ++- .../bundle/CelEnvironmentExporterTest.java | 78 +++++++++++++++++++ .../CelEnvironmentYamlSerializerTest.java | 62 +++++++++++++++ .../test/resources/environment/dump_env.yaml | 35 +++++++++ 4 files changed, 182 insertions(+), 2 deletions(-) diff --git a/bundle/src/main/java/dev/cel/bundle/CelEnvironmentExporter.java b/bundle/src/main/java/dev/cel/bundle/CelEnvironmentExporter.java index f86787090..d233fd36f 100644 --- a/bundle/src/main/java/dev/cel/bundle/CelEnvironmentExporter.java +++ b/bundle/src/main/java/dev/cel/bundle/CelEnvironmentExporter.java @@ -40,9 +40,9 @@ import dev.cel.common.CelVarDecl; import dev.cel.common.internal.EnvVisitable; import dev.cel.common.internal.EnvVisitor; +import dev.cel.common.types.CelKind; import dev.cel.common.types.CelProtoTypes; import dev.cel.common.types.CelType; -import dev.cel.common.types.CelTypes; import dev.cel.compiler.CelCompiler; import dev.cel.extensions.CelExtensionLibrary; import dev.cel.extensions.CelExtensions; @@ -484,7 +484,12 @@ private CelEnvironment.OverloadDecl toCelEnvOverloadDecl(CelOverloadDecl overloa } private CelEnvironment.TypeDecl toCelEnvTypeDecl(CelType type) { - return CelEnvironment.TypeDecl.create(CelTypes.format(type)); + return CelEnvironment.TypeDecl.newBuilder() + .setName(type.name()) + .setIsTypeParam(type.kind() == CelKind.TYPE_PARAM) + .addParams( + type.parameters().stream().map(this::toCelEnvTypeDecl).collect(toImmutableList())) + .build(); } /** Wrapper for CelOverloadDecl, associating it with the corresponding function name. */ diff --git a/bundle/src/test/java/dev/cel/bundle/CelEnvironmentExporterTest.java b/bundle/src/test/java/dev/cel/bundle/CelEnvironmentExporterTest.java index f70f1d466..10b9dee8e 100644 --- a/bundle/src/test/java/dev/cel/bundle/CelEnvironmentExporterTest.java +++ b/bundle/src/test/java/dev/cel/bundle/CelEnvironmentExporterTest.java @@ -36,8 +36,10 @@ import dev.cel.common.CelOptions; import dev.cel.common.CelOverloadDecl; import dev.cel.common.CelVarDecl; +import dev.cel.common.types.ListType; import dev.cel.common.types.OpaqueType; import dev.cel.common.types.SimpleType; +import dev.cel.common.types.TypeParamType; import dev.cel.extensions.CelExtensions; import java.net.URL; import java.util.HashSet; @@ -176,6 +178,20 @@ public void customFunctions() { "math.isFinite", CelOverloadDecl.newGlobalOverload( "math_isFinite_int64", SimpleType.BOOL, SimpleType.INT)), + CelFunctionDecl.newFunctionDeclaration( + "zipGeneric", + CelOverloadDecl.newGlobalOverload( + "zip_list_list", + ListType.create(ListType.create(TypeParamType.create("T"))), + ListType.create(TypeParamType.create("T")), + ListType.create(TypeParamType.create("T")))), + CelFunctionDecl.newFunctionDeclaration( + "zip", + CelOverloadDecl.newGlobalOverload( + "zip_list_int_list_int", + ListType.create(ListType.create(SimpleType.INT)), + ListType.create(SimpleType.INT), + ListType.create(SimpleType.INT))), CelFunctionDecl.newFunctionDeclaration( "addWeeks", CelOverloadDecl.newMemberOverload( @@ -207,6 +223,68 @@ public void customFunctions() { .setTarget(TypeDecl.create("google.protobuf.Timestamp")) .setArguments(ImmutableList.of(TypeDecl.create("int"))) .setReturnType(TypeDecl.create("bool")) + .build())), + FunctionDecl.create( + "zipGeneric", + ImmutableSet.of( + OverloadDecl.newBuilder() + .setId("zip_list_list") + .setArguments( + ImmutableList.of( + TypeDecl.newBuilder() + .setName("list") + .addParams( + TypeDecl.newBuilder() + .setName("T") + .setIsTypeParam(true) + .build()) + .build(), + TypeDecl.newBuilder() + .setName("list") + .addParams( + TypeDecl.newBuilder() + .setName("T") + .setIsTypeParam(true) + .build()) + .build())) + .setReturnType( + TypeDecl.newBuilder() + .setName("list") + .addParams( + TypeDecl.newBuilder() + .setName("list") + .addParams( + TypeDecl.newBuilder() + .setName("T") + .setIsTypeParam(true) + .build()) + .build()) + .build()) + .build())), + FunctionDecl.create( + "zip", + ImmutableSet.of( + OverloadDecl.newBuilder() + .setId("zip_list_int_list_int") + .setArguments( + ImmutableList.of( + TypeDecl.newBuilder() + .setName("list") + .addParams(TypeDecl.create("int")) + .build(), + TypeDecl.newBuilder() + .setName("list") + .addParams(TypeDecl.create("int")) + .build())) + .setReturnType( + TypeDecl.newBuilder() + .setName("list") + .addParams( + TypeDecl.newBuilder() + .setName("list") + .addParams(TypeDecl.create("int")) + .build()) + .build()) .build()))); // Random-check some standard functions: we don't want to see them explicitly defined. diff --git a/bundle/src/test/java/dev/cel/bundle/CelEnvironmentYamlSerializerTest.java b/bundle/src/test/java/dev/cel/bundle/CelEnvironmentYamlSerializerTest.java index 0235cb2f4..aad72a578 100644 --- a/bundle/src/test/java/dev/cel/bundle/CelEnvironmentYamlSerializerTest.java +++ b/bundle/src/test/java/dev/cel/bundle/CelEnvironmentYamlSerializerTest.java @@ -106,6 +106,68 @@ public void toYaml_success() throws Exception { .setReturnType( TypeDecl.newBuilder().setName("V").setIsTypeParam(true).build()) .build())), + FunctionDecl.create( + "zip", + ImmutableSet.of( + OverloadDecl.newBuilder() + .setId("zip_list_int_list_int") + .setArguments( + ImmutableList.of( + TypeDecl.newBuilder() + .setName("list") + .addParams(TypeDecl.create("int")) + .build(), + TypeDecl.newBuilder() + .setName("list") + .addParams(TypeDecl.create("int")) + .build())) + .setReturnType( + TypeDecl.newBuilder() + .setName("list") + .addParams( + TypeDecl.newBuilder() + .setName("list") + .addParams(TypeDecl.create("int")) + .build()) + .build()) + .build())), + FunctionDecl.create( + "zipGeneric", + ImmutableSet.of( + OverloadDecl.newBuilder() + .setId("zip_list_list") + .setArguments( + ImmutableList.of( + TypeDecl.newBuilder() + .setName("list") + .addParams( + TypeDecl.newBuilder() + .setName("T") + .setIsTypeParam(true) + .build()) + .build(), + TypeDecl.newBuilder() + .setName("list") + .addParams( + TypeDecl.newBuilder() + .setName("T") + .setIsTypeParam(true) + .build()) + .build())) + .setReturnType( + TypeDecl.newBuilder() + .setName("list") + .addParams( + TypeDecl.newBuilder() + .setName("list") + .addParams( + TypeDecl.newBuilder() + .setName("T") + .setIsTypeParam(true) + .build()) + .build()) + .build()) + .build())), FunctionDecl.create( "coalesce", ImmutableSet.of( diff --git a/testing/src/test/resources/environment/dump_env.yaml b/testing/src/test/resources/environment/dump_env.yaml index 1e7e7880b..a5ed23753 100644 --- a/testing/src/test/resources/environment/dump_env.yaml +++ b/testing/src/test/resources/environment/dump_env.yaml @@ -61,6 +61,41 @@ functions: return: type_name: V is_type_param: true +- name: zip + overloads: + - id: zip_list_int_list_int + args: + - type_name: list + params: + - type_name: int + - type_name: list + params: + - type_name: int + return: + type_name: list + params: + - type_name: list + params: + - type_name: int +- name: zipGeneric + overloads: + - id: zip_list_list + args: + - type_name: list + params: + - type_name: T + is_type_param: true + - type_name: list + params: + - type_name: T + is_type_param: true + return: + type_name: list + params: + - type_name: list + params: + - type_name: T + is_type_param: true - name: coalesce overloads: - id: coalesce_null_int From a5a1ef1c75c06574e1f9548cf576f2edf2a6d5e7 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 10 Mar 2026 14:20:54 -0700 Subject: [PATCH 02/66] Add test cases around type-checking gp.NullValue google.protobuf.NullValue is represented as an enum (int in CEL), but is interpreted to mean a null literal when set as the alternative in google.protobuf.Value. It is not normally referenced directly, but should behave as an int when it is. PiperOrigin-RevId: 881621568 --- .../java/dev/cel/checker/ExprCheckerTest.java | 65 +++++++++++++++++++ .../resources/jsonTypeNullAccess.baseline | 54 +++++++++++++++ .../jsonTypeNullConstruction.baseline | 35 ++++++++++ .../test/resources/jsonValueTypes.baseline | 43 +++++++++++- .../dev/cel/testing/BaseInterpreterTest.java | 17 +++++ 5 files changed, 212 insertions(+), 2 deletions(-) create mode 100644 checker/src/test/resources/jsonTypeNullAccess.baseline create mode 100644 checker/src/test/resources/jsonTypeNullConstruction.baseline diff --git a/checker/src/test/java/dev/cel/checker/ExprCheckerTest.java b/checker/src/test/java/dev/cel/checker/ExprCheckerTest.java index d5d5d9a3a..846201d32 100644 --- a/checker/src/test/java/dev/cel/checker/ExprCheckerTest.java +++ b/checker/src/test/java/dev/cel/checker/ExprCheckerTest.java @@ -517,6 +517,71 @@ public void jsonType() throws Exception { runTest(); } + @Test + public void jsonTypeNullConstruction() throws Exception { + // Ok + source = "google.protobuf.Value{null_value: google.protobuf.NullValue.NULL_VALUE}"; + runTest(); + + // Error + source = "google.protobuf.Value{null_value: null}"; + runTest(); + + // Ok + source = "cel.expr.conformance.proto3.TestAllTypes{single_value: null}"; + runTest(); + + // Ok but not expected (int coerced to double/json number 0.0) + source = + "cel.expr.conformance.proto3.TestAllTypes{single_value:" + + " google.protobuf.NullValue.NULL_VALUE}"; + runTest(); + + // Error + source = "cel.expr.conformance.proto3.TestAllTypes{null_value: null}"; + runTest(); + + // Ok + source = + "cel.expr.conformance.proto3.TestAllTypes{null_value:" + + " google.protobuf.NullValue.NULL_VALUE}"; + runTest(); + } + + @Test + public void jsonTypeNullAccess() throws Exception { + source = "google.protobuf.Value{null_value: google.protobuf.NullValue.NULL_VALUE} == null"; + runTest(); + + source = "cel.expr.conformance.proto3.TestAllTypes{single_value: null}.single_value == null"; + runTest(); + + source = + "cel.expr.conformance.proto3.TestAllTypes{single_value:" + + " google.protobuf.NullValue.NULL_VALUE}.single_value == null"; + runTest(); + + // Error + source = + "cel.expr.conformance.proto3.TestAllTypes{null_value:" + + " google.protobuf.NullValue.NULL_VALUE}.null_value == null"; + runTest(); + + // Ok + source = + "cel.expr.conformance.proto3.TestAllTypes{null_value:" + + " google.protobuf.NullValue.NULL_VALUE}.null_value == 0"; + runTest(); + + // Error + source = "google.protobuf.NullValue.NULL_VALUE == null"; + runTest(); + + // Ok + source = "google.protobuf.NullValue.NULL_VALUE == 0"; + runTest(); + } + // Call Style and User Functions // ============================= diff --git a/checker/src/test/resources/jsonTypeNullAccess.baseline b/checker/src/test/resources/jsonTypeNullAccess.baseline new file mode 100644 index 000000000..834b8fde8 --- /dev/null +++ b/checker/src/test/resources/jsonTypeNullAccess.baseline @@ -0,0 +1,54 @@ +Source: google.protobuf.Value{null_value: google.protobuf.NullValue.NULL_VALUE} == null +=====> +_==_( + google.protobuf.Value{ + null_value:google.protobuf.NullValue.NULL_VALUE~int^google.protobuf.NullValue.NULL_VALUE + }~dyn^google.protobuf.Value, + null~null +)~bool^equals + +Source: cel.expr.conformance.proto3.TestAllTypes{single_value: null}.single_value == null +=====> +_==_( + cel.expr.conformance.proto3.TestAllTypes{ + single_value:null~null + }~cel.expr.conformance.proto3.TestAllTypes^cel.expr.conformance.proto3.TestAllTypes.single_value~dyn, + null~null +)~bool^equals + +Source: cel.expr.conformance.proto3.TestAllTypes{single_value: google.protobuf.NullValue.NULL_VALUE}.single_value == null +=====> +_==_( + cel.expr.conformance.proto3.TestAllTypes{ + single_value:google.protobuf.NullValue.NULL_VALUE~int^google.protobuf.NullValue.NULL_VALUE + }~cel.expr.conformance.proto3.TestAllTypes^cel.expr.conformance.proto3.TestAllTypes.single_value~dyn, + null~null +)~bool^equals + +Source: cel.expr.conformance.proto3.TestAllTypes{null_value: google.protobuf.NullValue.NULL_VALUE}.null_value == null +=====> +ERROR: test_location:1:103: found no matching overload for '_==_' applied to '(int, null)' (candidates: (%A0, %A0)) + | cel.expr.conformance.proto3.TestAllTypes{null_value: google.protobuf.NullValue.NULL_VALUE}.null_value == null + | ......................................................................................................^ + +Source: cel.expr.conformance.proto3.TestAllTypes{null_value: google.protobuf.NullValue.NULL_VALUE}.null_value == 0 +=====> +_==_( + cel.expr.conformance.proto3.TestAllTypes{ + null_value:google.protobuf.NullValue.NULL_VALUE~int^google.protobuf.NullValue.NULL_VALUE + }~cel.expr.conformance.proto3.TestAllTypes^cel.expr.conformance.proto3.TestAllTypes.null_value~int, + 0~int +)~bool^equals + +Source: google.protobuf.NullValue.NULL_VALUE == null +=====> +ERROR: test_location:1:38: found no matching overload for '_==_' applied to '(int, null)' (candidates: (%A0, %A0)) + | google.protobuf.NullValue.NULL_VALUE == null + | .....................................^ + +Source: google.protobuf.NullValue.NULL_VALUE == 0 +=====> +_==_( + google.protobuf.NullValue.NULL_VALUE~int^google.protobuf.NullValue.NULL_VALUE, + 0~int +)~bool^equals \ No newline at end of file diff --git a/checker/src/test/resources/jsonTypeNullConstruction.baseline b/checker/src/test/resources/jsonTypeNullConstruction.baseline new file mode 100644 index 000000000..5b9b211a8 --- /dev/null +++ b/checker/src/test/resources/jsonTypeNullConstruction.baseline @@ -0,0 +1,35 @@ +Source: google.protobuf.Value{null_value: google.protobuf.NullValue.NULL_VALUE} +=====> +google.protobuf.Value{ + null_value:google.protobuf.NullValue.NULL_VALUE~int^google.protobuf.NullValue.NULL_VALUE +}~dyn^google.protobuf.Value + +Source: google.protobuf.Value{null_value: null} +=====> +ERROR: test_location:1:33: expected type of field 'null_value' is 'int' but provided type is 'null' + | google.protobuf.Value{null_value: null} + | ................................^ + +Source: cel.expr.conformance.proto3.TestAllTypes{single_value: null} +=====> +cel.expr.conformance.proto3.TestAllTypes{ + single_value:null~null +}~cel.expr.conformance.proto3.TestAllTypes^cel.expr.conformance.proto3.TestAllTypes + +Source: cel.expr.conformance.proto3.TestAllTypes{single_value: google.protobuf.NullValue.NULL_VALUE} +=====> +cel.expr.conformance.proto3.TestAllTypes{ + single_value:google.protobuf.NullValue.NULL_VALUE~int^google.protobuf.NullValue.NULL_VALUE +}~cel.expr.conformance.proto3.TestAllTypes^cel.expr.conformance.proto3.TestAllTypes + +Source: cel.expr.conformance.proto3.TestAllTypes{null_value: null} +=====> +ERROR: test_location:1:52: expected type of field 'null_value' is 'int' but provided type is 'null' + | cel.expr.conformance.proto3.TestAllTypes{null_value: null} + | ...................................................^ + +Source: cel.expr.conformance.proto3.TestAllTypes{null_value: google.protobuf.NullValue.NULL_VALUE} +=====> +cel.expr.conformance.proto3.TestAllTypes{ + null_value:google.protobuf.NullValue.NULL_VALUE~int^google.protobuf.NullValue.NULL_VALUE +}~cel.expr.conformance.proto3.TestAllTypes^cel.expr.conformance.proto3.TestAllTypes \ No newline at end of file diff --git a/runtime/src/test/resources/jsonValueTypes.baseline b/runtime/src/test/resources/jsonValueTypes.baseline index ff406cf94..cc840b24b 100644 --- a/runtime/src/test/resources/jsonValueTypes.baseline +++ b/runtime/src/test/resources/jsonValueTypes.baseline @@ -53,6 +53,46 @@ bindings: {x=single_value { } result: true +Source: google.protobuf.Value{string_value: 'hello'} == 'hello' +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {} +result: true + +Source: google.protobuf.Value{number_value: 1.1} == 1.1 +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {} +result: true + +Source: google.protobuf.Value{null_value: google.protobuf.NullValue.NULL_VALUE} == null +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {} +result: true + +Source: TestAllTypes{null_value: google.protobuf.NullValue.NULL_VALUE}.null_value == 0 +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {} +result: true + +Source: TestAllTypes{null_value: google.protobuf.NullValue.NULL_VALUE}.null_value != dyn(null) +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {} +result: true + Source: x.single_value[0] == [['hello'], -1.1][0] declare x { value cel.expr.conformance.proto3.TestAllTypes @@ -148,5 +188,4 @@ bindings: {x=single_struct { } } } -result: {hello=val} - +result: {hello=val} \ No newline at end of file diff --git a/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java b/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java index 144ada5a8..b3c9af423 100644 --- a/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java +++ b/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java @@ -1850,6 +1850,23 @@ public void jsonValueTypes() { source = "x.single_value == 'hello'"; runTest(ImmutableMap.of("x", xString)); + // json manual construction + source = "google.protobuf.Value{string_value: 'hello'} == 'hello'"; + runTest(); + + source = "google.protobuf.Value{number_value: 1.1} == 1.1"; + runTest(); + + source = "google.protobuf.Value{null_value: google.protobuf.NullValue.NULL_VALUE} == null"; + runTest(); + + // NULL_VALUE is not the same as null. + source = "TestAllTypes{null_value: google.protobuf.NullValue.NULL_VALUE}.null_value == 0"; + runTest(); + source = + "TestAllTypes{null_value: google.protobuf.NullValue.NULL_VALUE}.null_value != dyn(null)"; + runTest(); + // JSON list equality. TestAllTypes xList = TestAllTypes.newBuilder() From 62c1f111d0b9d02f44185dd9a720bab768528916 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Thu, 12 Mar 2026 11:08:12 -0700 Subject: [PATCH 03/66] Null assignability fix for repeated and map fields PiperOrigin-RevId: 882683464 --- .../dev/cel/common/internal/ProtoAdapter.java | 51 +++++++++++++++++-- .../test/resources/nullAssignability.baseline | 29 +++++++++++ .../dev/cel/testing/BaseInterpreterTest.java | 27 ++++++++++ 3 files changed, 102 insertions(+), 5 deletions(-) diff --git a/common/src/main/java/dev/cel/common/internal/ProtoAdapter.java b/common/src/main/java/dev/cel/common/internal/ProtoAdapter.java index b6648a5b8..7e3910433 100644 --- a/common/src/main/java/dev/cel/common/internal/ProtoAdapter.java +++ b/common/src/main/java/dev/cel/common/internal/ProtoAdapter.java @@ -222,9 +222,7 @@ public Optional adaptValueToFieldType( throw new IllegalArgumentException("Unsupported field type"); } - String typeFullName = fieldDescriptor.getMessageType().getFullName(); - if (!WellKnownProto.ANY_VALUE.typeName().equals(typeFullName) - && !WellKnownProto.JSON_VALUE.typeName().equals(typeFullName)) { + if (!isFieldAnyOrJson(fieldDescriptor)) { return Optional.empty(); } } @@ -242,7 +240,11 @@ public Optional adaptValueToFieldType( getDefaultValueForMaybeMessage(keyDescriptor), valueDescriptor.getLiteType(), getDefaultValueForMaybeMessage(valueDescriptor)); + boolean isValueAnyOrJson = isFieldAnyOrJson(valueDescriptor); for (Map.Entry entry : ((Map) fieldValue).entrySet()) { + if (!isValueAnyOrJson && entry.getValue() instanceof NullValue) { + continue; + } mapEntries.add( protoMapEntry.toBuilder() .setKey(keyConverter.backwardConverter().convert(entry.getKey())) @@ -252,15 +254,54 @@ public Optional adaptValueToFieldType( return Optional.of(mapEntries); } if (fieldDescriptor.isRepeated()) { + List listValue = (List) fieldValue; + + if (!isFieldAnyOrJson(fieldDescriptor)) { + listValue = filterOutNullValues(listValue); + } + return Optional.of( - AdaptingTypes.adaptingList( - (List) fieldValue, fieldToValueConverter(fieldDescriptor).reverse())); + AdaptingTypes.adaptingList(listValue, fieldToValueConverter(fieldDescriptor).reverse())); } return Optional.of( fieldToValueConverter(fieldDescriptor).backwardConverter().convert(fieldValue)); } + private static List filterOutNullValues(List originalList) { + List filteredList = null; + + for (int i = 0; i < originalList.size(); i++) { + Object elem = originalList.get(i); + + if (elem instanceof NullValue) { + if (filteredList == null) { + filteredList = new ArrayList<>(originalList.size() - 1); + if (i > 0) { + filteredList.addAll(originalList.subList(0, i)); + } + } + } else if (filteredList != null) { + filteredList.add(elem); + } + } + + // Return the original list if no nulls were found to avoid unnecessary allocations + return filteredList != null ? filteredList : originalList; + } + + private static boolean isFieldAnyOrJson(FieldDescriptor fieldDescriptor) { + if (!fieldDescriptor.getType().equals(FieldDescriptor.Type.MESSAGE)) { + return false; + } + + String typeFullName = fieldDescriptor.getMessageType().getFullName(); + + return WellKnownProto.getByTypeName(typeFullName) + .map(wkp -> wkp.equals(WellKnownProto.ANY_VALUE) || wkp.equals(WellKnownProto.JSON_VALUE)) + .orElse(false); + } + @SuppressWarnings("rawtypes") private BidiConverter fieldToValueConverter(FieldDescriptor fieldDescriptor) { switch (fieldDescriptor.getType()) { diff --git a/runtime/src/test/resources/nullAssignability.baseline b/runtime/src/test/resources/nullAssignability.baseline index 47b9c7a0d..b60f434ea 100644 --- a/runtime/src/test/resources/nullAssignability.baseline +++ b/runtime/src/test/resources/nullAssignability.baseline @@ -33,3 +33,32 @@ Source: has(TestAllTypes{single_timestamp: null}.single_timestamp) bindings: {} result: false +Source: TestAllTypes{repeated_timestamp: [timestamp(1), null]}.repeated_timestamp == [timestamp(1)] +=====> +bindings: {} +result: true + +Source: TestAllTypes{map_bool_timestamp: {true: null, false: timestamp(1)}}.map_bool_timestamp == {false: timestamp(1)} +=====> +bindings: {} +result: true + +Source: TestAllTypes{repeated_any: [1, null]}.repeated_any == [1, null] +=====> +bindings: {} +result: true + +Source: TestAllTypes{map_bool_any: {true: null, false: 1}}.map_bool_any == {true: null, false: 1} +=====> +bindings: {} +result: true + +Source: TestAllTypes{repeated_value: [google.protobuf.Value{bool_value: true}, null]}.repeated_value == [true, null] +=====> +bindings: {} +result: true + +Source: TestAllTypes{map_bool_value: {true: null, false: google.protobuf.Value{bool_value: true}}}.map_bool_value == {true: null, false: true} +=====> +bindings: {} +result: true \ No newline at end of file diff --git a/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java b/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java index b3c9af423..f3c1cf398 100644 --- a/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java +++ b/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java @@ -2162,6 +2162,33 @@ public void nullAssignability() throws Exception { source = "has(TestAllTypes{single_timestamp: null}.single_timestamp)"; runTest(); + + source = + "TestAllTypes{repeated_timestamp: [timestamp(1), null]}.repeated_timestamp ==" + + " [timestamp(1)]"; + runTest(); + + source = + "TestAllTypes{map_bool_timestamp: {true: null, false: timestamp(1)}}.map_bool_timestamp ==" + + " {false: timestamp(1)}"; + runTest(); + + source = "TestAllTypes{repeated_any: [1, null]}.repeated_any == [1, null]"; + runTest(); + + source = + "TestAllTypes{map_bool_any: {true: null, false: 1}}.map_bool_any == {true: null, false: 1}"; + runTest(); + + source = + "TestAllTypes{repeated_value: [google.protobuf.Value{bool_value: true}," + + " null]}.repeated_value == [true, null]"; + runTest(); + + source = + "TestAllTypes{map_bool_value: {true: null, false: google.protobuf.Value{bool_value:" + + " true}}}.map_bool_value == {true: null, false: true}"; + runTest(); } @Test From 86ddee361843e561bc66e28a4beae742be42d490 Mon Sep 17 00:00:00 2001 From: Salman Muin Kayser Chishti <13schishti@gmail.com> Date: Fri, 13 Mar 2026 07:54:27 +0000 Subject: [PATCH 04/66] Upgrade GitHub Actions to latest versions Signed-off-by: Salman Muin Kayser Chishti <13schishti@gmail.com> --- .github/workflows/workflow.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/workflow.yml b/.github/workflows/workflow.yml index 4b206e3ec..b172788c3 100644 --- a/.github/workflows/workflow.yml +++ b/.github/workflows/workflow.yml @@ -82,7 +82,7 @@ jobs: uses: actions/checkout@v6 - name: Get changed files id: changed_file - uses: tj-actions/changed-files@v46 + uses: tj-actions/changed-files@v47 with: files: publish/cel_version.bzl - name: Setup Bazel From 4f6a571bfed4efa533d7ee6c3e02fb43fb4900cf Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Tue, 17 Mar 2026 14:39:44 -0700 Subject: [PATCH 05/66] Support partial evaluation via unknowns in planner PiperOrigin-RevId: 885217146 --- .../test/java/dev/cel/extensions/BUILD.bazel | 2 + .../extensions/CelOptionalLibraryTest.java | 49 +- runtime/BUILD.bazel | 14 + .../dev/cel/runtime/AccumulatedUnknowns.java | 25 +- .../src/main/java/dev/cel/runtime/BUILD.bazel | 45 +- .../java/dev/cel/runtime/CelRuntimeImpl.java | 5 + .../java/dev/cel/runtime/CelUnknownSet.java | 2 +- .../java/dev/cel/runtime/InterpreterUtil.java | 17 +- .../java/dev/cel/runtime/LiteProgramImpl.java | 6 + .../java/dev/cel/runtime/PartialVars.java | 70 +++ .../main/java/dev/cel/runtime/Program.java | 3 + .../java/dev/cel/runtime/ProgramImpl.java | 8 + .../dev/cel/runtime/planner/Attribute.java | 2 +- .../java/dev/cel/runtime/planner/BUILD.bazel | 22 +- .../java/dev/cel/runtime/planner/EvalAnd.java | 139 +++--- .../cel/runtime/planner/EvalAttribute.java | 4 +- .../cel/runtime/planner/EvalConditional.java | 5 +- .../cel/runtime/planner/EvalCreateList.java | 12 + .../cel/runtime/planner/EvalCreateMap.java | 62 ++- .../cel/runtime/planner/EvalCreateStruct.java | 22 +- .../dev/cel/runtime/planner/EvalFold.java | 4 + .../dev/cel/runtime/planner/EvalHelpers.java | 156 +++--- .../runtime/planner/EvalLateBoundCall.java | 8 + .../cel/runtime/planner/EvalOptionalOr.java | 5 + .../runtime/planner/EvalOptionalOrValue.java | 5 + .../planner/EvalOptionalSelectField.java | 9 + .../java/dev/cel/runtime/planner/EvalOr.java | 139 +++--- .../cel/runtime/planner/EvalVarArgsCall.java | 8 + .../cel/runtime/planner/ExecutionFrame.java | 15 +- .../cel/runtime/planner/MaybeAttribute.java | 4 +- .../cel/runtime/planner/MissingAttribute.java | 2 +- .../runtime/planner/NamespacedAttribute.java | 75 ++- .../cel/runtime/planner/PlannedProgram.java | 32 +- .../runtime/planner/RelativeAttribute.java | 8 +- .../src/test/java/dev/cel/runtime/BUILD.bazel | 9 + .../cel/runtime/PlannerInterpreterTest.java | 255 +++++++++- .../java/dev/cel/runtime/planner/BUILD.bazel | 2 + .../runtime/planner/ProgramPlannerTest.java | 33 ++ .../planner_unknownFieldSelection.baseline | 111 +++++ .../planner_unknownResultSet_errors.baseline | 81 +++ .../planner_unknownResultSet_success.baseline | 461 ++++++++++++++++++ .../src/test/resources/unknownField.baseline | 2 +- .../src/main/java/dev/cel/testing/BUILD.bazel | 2 +- .../dev/cel/testing/BaseInterpreterTest.java | 46 +- 44 files changed, 1665 insertions(+), 321 deletions(-) create mode 100644 runtime/src/main/java/dev/cel/runtime/PartialVars.java create mode 100644 runtime/src/test/resources/planner_unknownFieldSelection.baseline create mode 100644 runtime/src/test/resources/planner_unknownResultSet_errors.baseline create mode 100644 runtime/src/test/resources/planner_unknownResultSet_success.baseline diff --git a/extensions/src/test/java/dev/cel/extensions/BUILD.bazel b/extensions/src/test/java/dev/cel/extensions/BUILD.bazel index 48915fd02..a9dbfaca2 100644 --- a/extensions/src/test/java/dev/cel/extensions/BUILD.bazel +++ b/extensions/src/test/java/dev/cel/extensions/BUILD.bazel @@ -38,6 +38,8 @@ java_library( "//runtime:interpreter_util", "//runtime:lite_runtime", "//runtime:lite_runtime_factory", + "//runtime:partial_vars", + "//runtime:unknown_attributes", "@cel_spec//proto/cel/expr/conformance/proto2:test_all_types_java_proto", "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto", "@cel_spec//proto/cel/expr/conformance/test:simple_java_proto", diff --git a/extensions/src/test/java/dev/cel/extensions/CelOptionalLibraryTest.java b/extensions/src/test/java/dev/cel/extensions/CelOptionalLibraryTest.java index 24e9d6d86..ab412fb39 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelOptionalLibraryTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelOptionalLibraryTest.java @@ -49,10 +49,12 @@ import dev.cel.expr.conformance.proto3.TestAllTypes.NestedMessage; import dev.cel.parser.CelMacro; import dev.cel.parser.CelStandardMacro; +import dev.cel.runtime.CelAttributePattern; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.CelFunctionBinding; import dev.cel.runtime.CelRuntime; import dev.cel.runtime.InterpreterUtil; +import dev.cel.runtime.PartialVars; import java.time.Duration; import java.time.Instant; import java.util.List; @@ -897,14 +899,12 @@ public void optionalIndex_onMap_returnsOptionalValue() throws Exception { @TestParameters("{source: '{?x: x}'}") public void optionalIndex_onMapWithUnknownInput_returnsUnknownResult(String source) throws Exception { - if (testMode.equals(TestMode.PLANNER_CHECKED) || testMode.equals(TestMode.PLANNER_PARSE_ONLY)) { - // TODO: Uncomment once unknowns is implemented - return; - } Cel cel = newCelBuilder().addVar("x", OptionalType.create(SimpleType.INT)).build(); CelAbstractSyntaxTree ast = compile(cel, source); - Object result = cel.createProgram(ast).eval(); + Object result = + cel.createProgram(ast) + .eval(PartialVars.of(CelAttributePattern.fromQualifiedIdentifier("x"))); assertThat(InterpreterUtil.isUnknown(result)).isTrue(); } @@ -987,10 +987,6 @@ public void optionalIndex_onOptionalList_returnsOptionalValue() throws Exception @Test public void optionalIndex_onListWithUnknownInput_returnsUnknownResult() throws Exception { - if (testMode.equals(TestMode.PLANNER_CHECKED) || testMode.equals(TestMode.PLANNER_PARSE_ONLY)) { - // TODO: Uncomment once unknowns is implemented - return; - } Cel cel = newCelBuilder() .addVar("x", OptionalType.create(SimpleType.INT)) @@ -998,7 +994,9 @@ public void optionalIndex_onListWithUnknownInput_returnsUnknownResult() throws E .build(); CelAbstractSyntaxTree ast = compile(cel, "[?x]"); - Object result = cel.createProgram(ast).eval(); + Object result = + cel.createProgram(ast) + .eval(PartialVars.of(CelAttributePattern.fromQualifiedIdentifier("x"))); assertThat(InterpreterUtil.isUnknown(result)).isTrue(); } @@ -1017,6 +1015,29 @@ public void traditionalIndex_onOptionalList_returnsOptionalEmpty() throws Except assertThat(result).isEqualTo(Optional.empty()); } + @Test + public void optionalFieldSelect_fieldMarkedUnknown_returnsUnknownSet() throws Exception { + if (testMode.equals(TestMode.LEGACY_CHECKED)) { + // This case is not possible to setup for legacy runtime + return; + } + + Cel cel = + newCelBuilder() + .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())) + .build(); + CelAbstractSyntaxTree ast = compile(cel, "msg.?single_int32"); + + Object result = + cel.createProgram(ast) + .eval( + PartialVars.of( + ImmutableMap.of("msg", TestAllTypes.newBuilder().setSingleInt32(42).build()), + CelAttributePattern.fromQualifiedIdentifier("msg.single_int32"))); + + assertThat(InterpreterUtil.isUnknown(result)).isTrue(); + } + @Test // LHS @TestParameters("{expression: 'optx.or(optional.of(1))'}") @@ -1026,10 +1047,6 @@ public void traditionalIndex_onOptionalList_returnsOptionalEmpty() throws Except @TestParameters("{expression: 'optional.none().orValue(optx)'}") public void optionalChainedFunctions_lhsIsUnknown_returnsUnknown(String expression) throws Exception { - if (testMode.equals(TestMode.PLANNER_CHECKED) || testMode.equals(TestMode.PLANNER_PARSE_ONLY)) { - // TODO: Uncomment once unknowns is implemented - return; - } Cel cel = newCelBuilder() .addVar("optx", OptionalType.create(SimpleType.INT)) @@ -1037,7 +1054,9 @@ public void optionalChainedFunctions_lhsIsUnknown_returnsUnknown(String expressi .build(); CelAbstractSyntaxTree ast = compile(cel, expression); - Object result = cel.createProgram(ast).eval(); + Object result = + cel.createProgram(ast) + .eval(PartialVars.of(CelAttributePattern.fromQualifiedIdentifier("optx"))); assertThat(InterpreterUtil.isUnknown(result)).isTrue(); } diff --git a/runtime/BUILD.bazel b/runtime/BUILD.bazel index 55ee241a0..3e183d236 100644 --- a/runtime/BUILD.bazel +++ b/runtime/BUILD.bazel @@ -351,3 +351,17 @@ java_library( "//runtime/src/main/java/dev/cel/runtime:runtime_planner_impl", ], ) + +java_library( + name = "accumulated_unknowns", + visibility = ["//:internal"], + exports = [ + "//runtime/src/main/java/dev/cel/runtime:accumulated_unknowns", + ], +) + +java_library( + name = "partial_vars", + visibility = ["//:internal"], + exports = ["//runtime/src/main/java/dev/cel/runtime:partial_vars"], +) diff --git a/runtime/src/main/java/dev/cel/runtime/AccumulatedUnknowns.java b/runtime/src/main/java/dev/cel/runtime/AccumulatedUnknowns.java index d27de2da2..d4d54c71f 100644 --- a/runtime/src/main/java/dev/cel/runtime/AccumulatedUnknowns.java +++ b/runtime/src/main/java/dev/cel/runtime/AccumulatedUnknowns.java @@ -15,18 +15,23 @@ package dev.cel.runtime; import com.google.errorprone.annotations.CanIgnoreReturnValue; +import dev.cel.common.annotations.Internal; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.HashSet; import java.util.Set; +import org.jspecify.annotations.Nullable; /** * An internal representation used for fast accumulation of unknown expr IDs and attributes. For * safety, this object should never be returned as an evaluated result and instead be adapted into * an immutable CelUnknownSet. + * + *

CEL Library Internals. Do Not Use. */ -final class AccumulatedUnknowns { +@Internal +public final class AccumulatedUnknowns { private static final int MAX_UNKNOWN_ATTRIBUTE_SIZE = 500_000; private final Set exprIds; private final Set attributes; @@ -39,8 +44,21 @@ Set attributes() { return attributes; } + /** + * Evaluates if the right hand side is an accumulated unknown, and if so, merges it into the + * accumulator. + */ + public static @Nullable AccumulatedUnknowns maybeMerge( + @Nullable AccumulatedUnknowns accumulator, Object newValue) { + if (newValue instanceof AccumulatedUnknowns) { + AccumulatedUnknowns newUnknowns = (AccumulatedUnknowns) newValue; + return accumulator == null ? newUnknowns : accumulator.merge(newUnknowns); + } + return accumulator; + } + @CanIgnoreReturnValue - AccumulatedUnknowns merge(AccumulatedUnknowns arg) { + public AccumulatedUnknowns merge(AccumulatedUnknowns arg) { enforceMaxAttributeSize(this.attributes, arg.attributes); this.exprIds.addAll(arg.exprIds); this.attributes.addAll(arg.attributes); @@ -55,7 +73,8 @@ static AccumulatedUnknowns create(Collection ids) { return create(ids, new ArrayList<>()); } - static AccumulatedUnknowns create(Collection exprIds, Collection attributes) { + public static AccumulatedUnknowns create( + Collection exprIds, Collection attributes) { return new AccumulatedUnknowns(new HashSet<>(exprIds), new HashSet<>(attributes)); } diff --git a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel index 10dca9ece..2681c17de 100644 --- a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel +++ b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel @@ -826,6 +826,7 @@ java_library( ":evaluation_listener", ":function_binding", ":function_resolver", + ":partial_vars", ":program", ":proto_message_runtime_equality", ":runtime", @@ -938,6 +939,7 @@ java_library( ":function_resolver", ":interpretable", ":interpreter", + ":partial_vars", ":program", ":proto_message_activation_factory", ":runtime_equality", @@ -955,7 +957,6 @@ java_library( "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", "@maven//:com_google_protobuf_protobuf_java", - "@maven//:org_jspecify_jspecify", ], ) @@ -1014,6 +1015,7 @@ java_library( ":evaluation_exception", ":function_resolver", ":interpretable", + ":partial_vars", ":program", ":variable_resolver", "//:auto_value", @@ -1029,6 +1031,7 @@ cel_android_library( ":evaluation_exception", ":function_resolver_android", ":interpretable_android", + ":partial_vars_android", ":program_android", ":variable_resolver", "//:auto_value", @@ -1199,6 +1202,7 @@ java_library( ":unknown_attributes", "//common/annotations", "@maven//:com_google_errorprone_error_prone_annotations", + "@maven//:com_google_guava_guava", "@maven//:org_jspecify_jspecify", ], ) @@ -1214,6 +1218,7 @@ cel_android_library( "//common/annotations", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:org_jspecify_jspecify", + "@maven_android//:com_google_guava_guava", ], ) @@ -1273,10 +1278,13 @@ java_library( java_library( name = "accumulated_unknowns", srcs = ["AccumulatedUnknowns.java"], - visibility = ["//visibility:private"], + tags = [ + ], deps = [ ":unknown_attributes", + "//common/annotations", "@maven//:com_google_errorprone_error_prone_annotations", + "@maven//:org_jspecify_jspecify", ], ) @@ -1286,7 +1294,9 @@ cel_android_library( visibility = ["//visibility:private"], deps = [ ":unknown_attributes_android", + "//common/annotations", "@maven//:com_google_errorprone_error_prone_annotations", + "@maven//:org_jspecify_jspecify", ], ) @@ -1318,6 +1328,34 @@ cel_android_library( ], ) +java_library( + name = "partial_vars", + srcs = ["PartialVars.java"], + tags = [ + ], + deps = [ + ":variable_resolver", + "//:auto_value", + "//runtime:unknown_attributes", + "@maven//:com_google_errorprone_error_prone_annotations", + "@maven//:com_google_guava_guava", + ], +) + +cel_android_library( + name = "partial_vars_android", + srcs = ["PartialVars.java"], + tags = [ + ], + deps = [ + ":variable_resolver", + "//:auto_value", + "//runtime:unknown_attributes_android", + "@maven//:com_google_errorprone_error_prone_annotations", + "@maven_android//:com_google_guava_guava", + ], +) + java_library( name = "program", srcs = ["Program.java"], @@ -1326,6 +1364,7 @@ java_library( deps = [ ":evaluation_exception", ":function_resolver", + ":partial_vars", ":variable_resolver", "@maven//:com_google_errorprone_error_prone_annotations", ], @@ -1339,8 +1378,8 @@ cel_android_library( deps = [ ":evaluation_exception", ":function_resolver_android", + ":partial_vars_android", ":variable_resolver", - "//:auto_value", "@maven//:com_google_errorprone_error_prone_annotations", ], ) diff --git a/runtime/src/main/java/dev/cel/runtime/CelRuntimeImpl.java b/runtime/src/main/java/dev/cel/runtime/CelRuntimeImpl.java index 346b25ae9..cab2c666e 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelRuntimeImpl.java +++ b/runtime/src/main/java/dev/cel/runtime/CelRuntimeImpl.java @@ -134,6 +134,11 @@ public Object eval( return program.eval(resolver, lateBoundFunctionResolver); } + @Override + public Object eval(PartialVars partialVars) throws CelEvaluationException { + return program.eval(partialVars); + } + @Override public Object trace(CelEvaluationListener listener) throws CelEvaluationException { throw new UnsupportedOperationException("Trace is not yet supported."); diff --git a/runtime/src/main/java/dev/cel/runtime/CelUnknownSet.java b/runtime/src/main/java/dev/cel/runtime/CelUnknownSet.java index c7f1d0c91..62d975f93 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelUnknownSet.java +++ b/runtime/src/main/java/dev/cel/runtime/CelUnknownSet.java @@ -59,7 +59,7 @@ static CelUnknownSet create(Iterable unknownExprIds) { return create(ImmutableSet.of(), ImmutableSet.copyOf(unknownExprIds)); } - static CelUnknownSet create( + public static CelUnknownSet create( ImmutableSet attributes, ImmutableSet unknownExprIds) { return new AutoValue_CelUnknownSet(attributes, unknownExprIds); } diff --git a/runtime/src/main/java/dev/cel/runtime/InterpreterUtil.java b/runtime/src/main/java/dev/cel/runtime/InterpreterUtil.java index f84897ac2..73607cefd 100644 --- a/runtime/src/main/java/dev/cel/runtime/InterpreterUtil.java +++ b/runtime/src/main/java/dev/cel/runtime/InterpreterUtil.java @@ -14,6 +14,7 @@ package dev.cel.runtime; +import com.google.common.collect.ImmutableSet; import com.google.errorprone.annotations.CheckReturnValue; import dev.cel.common.annotations.Internal; import org.jspecify.annotations.Nullable; @@ -55,12 +56,12 @@ public static boolean isUnknown(Object obj) { return obj instanceof CelUnknownSet; } - static boolean isAccumulatedUnknowns(Object obj) { + public static boolean isAccumulatedUnknowns(Object obj) { return obj instanceof AccumulatedUnknowns; } /** If the argument is {@link CelUnknownSet}, adapts it into {@link AccumulatedUnknowns} */ - static Object maybeAdaptToAccumulatedUnknowns(Object val) { + public static Object maybeAdaptToAccumulatedUnknowns(Object val) { if (!(val instanceof CelUnknownSet)) { return val; } @@ -68,10 +69,20 @@ static Object maybeAdaptToAccumulatedUnknowns(Object val) { return adaptToAccumulatedUnknowns((CelUnknownSet) val); } - static AccumulatedUnknowns adaptToAccumulatedUnknowns(CelUnknownSet unknowns) { + public static AccumulatedUnknowns adaptToAccumulatedUnknowns(CelUnknownSet unknowns) { return AccumulatedUnknowns.create(unknowns.unknownExprIds(), unknowns.attributes()); } + public static Object maybeAdaptToCelUnknownSet(Object val) { + if (!(val instanceof AccumulatedUnknowns)) { + return val; + } + + AccumulatedUnknowns unknowns = (AccumulatedUnknowns) val; + return CelUnknownSet.create( + ImmutableSet.copyOf(unknowns.attributes()), ImmutableSet.copyOf(unknowns.exprIds())); + } + /** * Enforces strictness on both lhs/rhs arguments from logical operators (i.e: intentionally throws * an appropriate exception when {@link Throwable} is encountered as part of evaluated result. diff --git a/runtime/src/main/java/dev/cel/runtime/LiteProgramImpl.java b/runtime/src/main/java/dev/cel/runtime/LiteProgramImpl.java index 5e57f497b..af8c1a6d0 100644 --- a/runtime/src/main/java/dev/cel/runtime/LiteProgramImpl.java +++ b/runtime/src/main/java/dev/cel/runtime/LiteProgramImpl.java @@ -52,6 +52,12 @@ public Object eval(CelVariableResolver resolver) throws CelEvaluationException { throw new UnsupportedOperationException("To be implemented"); } + @Override + public Object eval(PartialVars partialVars) throws CelEvaluationException { + // TODO: Wire in program planner + throw new UnsupportedOperationException("To be implemented"); + } + static Program plan(Interpretable interpretable) { return new AutoValue_LiteProgramImpl(interpretable); } diff --git a/runtime/src/main/java/dev/cel/runtime/PartialVars.java b/runtime/src/main/java/dev/cel/runtime/PartialVars.java new file mode 100644 index 000000000..1cd081040 --- /dev/null +++ b/runtime/src/main/java/dev/cel/runtime/PartialVars.java @@ -0,0 +1,70 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.runtime; + +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableList; +import java.util.Map; +import java.util.Optional; + +/** + * A holder for a {@link CelVariableResolver} and a set of {@link CelAttributePattern}s that + * indicate variables or parts of variables whose value are not yet known. + */ +@AutoValue +public abstract class PartialVars { + + /** The resolver to use for resolving evaluation variables. */ + public abstract CelVariableResolver resolver(); + + /** + * A list of attribute patterns specifying which missing attribute paths should be tracked as + * unknown values. + */ + public abstract ImmutableList unknowns(); + + /** Constructs a new {@code PartialVars} from one or more {@link CelAttributePattern}s. */ + public static PartialVars of(CelAttributePattern... unknownAttributes) { + return of((unused) -> Optional.empty(), ImmutableList.copyOf(unknownAttributes)); + } + + /** + * Constructs a new {@code PartialVars} from a {@link CelVariableResolver} and a list of {@link + * CelAttributePattern}s. + */ + public static PartialVars of( + CelVariableResolver resolver, Iterable unknownAttributes) { + return new AutoValue_PartialVars(resolver, ImmutableList.copyOf(unknownAttributes)); + } + + /** + * Constructs a new {@code PartialVars} from a map of variables and an array of {@link + * CelAttributePattern}s. + */ + public static PartialVars of(Map variables, CelAttributePattern... unknownAttributes) { + return of( + (name) -> variables.containsKey(name) ? Optional.of(variables.get(name)) : Optional.empty(), + unknownAttributes); + } + + /** + * Constructs a new {@code PartialVars} from a {@link CelVariableResolver} and an array of {@link + * CelAttributePattern}s. + */ + public static PartialVars of( + CelVariableResolver resolver, CelAttributePattern... unknownAttributes) { + return of(resolver, ImmutableList.copyOf(unknownAttributes)); + } +} diff --git a/runtime/src/main/java/dev/cel/runtime/Program.java b/runtime/src/main/java/dev/cel/runtime/Program.java index c0982f1f8..e808a373c 100644 --- a/runtime/src/main/java/dev/cel/runtime/Program.java +++ b/runtime/src/main/java/dev/cel/runtime/Program.java @@ -43,4 +43,7 @@ Object eval(Map mapValue, CelFunctionResolver lateBoundFunctionResolv */ Object eval(CelVariableResolver resolver, CelFunctionResolver lateBoundFunctionResolver) throws CelEvaluationException; + + /** Evaluate a compiled program with unknown attribute patterns {@code partialVars}. */ + Object eval(PartialVars partialVars) throws CelEvaluationException; } diff --git a/runtime/src/main/java/dev/cel/runtime/ProgramImpl.java b/runtime/src/main/java/dev/cel/runtime/ProgramImpl.java index d0e64429b..c9f4d083b 100644 --- a/runtime/src/main/java/dev/cel/runtime/ProgramImpl.java +++ b/runtime/src/main/java/dev/cel/runtime/ProgramImpl.java @@ -60,6 +60,14 @@ public Object eval(Map mapValue, CelFunctionResolver lateBoundFunctio return evalInternal(Activation.copyOf(mapValue), lateBoundFunctionResolver); } + @Override + public Object eval(PartialVars partialVars) throws CelEvaluationException { + return evalInternal( + UnknownContext.create(partialVars.resolver(), partialVars.unknowns()), + /* lateBoundFunctionResolver= */ Optional.empty(), + /* listener= */ Optional.empty()); + } + @Override public Object trace(CelEvaluationListener listener) throws CelEvaluationException { return evalInternal(Activation.EMPTY, listener); diff --git a/runtime/src/main/java/dev/cel/runtime/planner/Attribute.java b/runtime/src/main/java/dev/cel/runtime/planner/Attribute.java index cc011ed34..90165c1ac 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/Attribute.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/Attribute.java @@ -20,7 +20,7 @@ /** Represents a resolvable symbol or path (such as a variable or a field selection). */ @Immutable interface Attribute { - Object resolve(GlobalResolver ctx, ExecutionFrame frame); + Object resolve(long exprId, GlobalResolver ctx, ExecutionFrame frame); Attribute addQualifier(Qualifier qualifier); } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel b/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel index 6561e4e5c..3c18b192f 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel +++ b/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel @@ -78,6 +78,8 @@ java_library( "//runtime:evaluation_exception_builder", "//runtime:function_resolver", "//runtime:interpretable", + "//runtime:interpreter_util", + "//runtime:partial_vars", "//runtime:program", "//runtime:resolved_overload", "//runtime:variable_resolver", @@ -128,7 +130,11 @@ java_library( "//common/types", "//common/types:type_providers", "//common/values", + "//runtime:accumulated_unknowns", "//runtime:interpretable", + "//runtime:interpreter_util", + "//runtime:partial_vars", + "//runtime:unknown_attributes", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", "@maven//:org_jspecify_jspecify", @@ -181,7 +187,6 @@ java_library( ":qualifier", "//runtime:interpretable", "@maven//:com_google_errorprone_error_prone_annotations", - "@maven//:com_google_guava_guava", ], ) @@ -235,6 +240,7 @@ java_library( ":execution_frame", ":planned_interpretable", "//common/values", + "//runtime:accumulated_unknowns", "//runtime:evaluation_exception", "//runtime:interpretable", "//runtime:resolved_overload", @@ -250,6 +256,7 @@ java_library( ":planned_interpretable", "//common/exceptions:overload_not_found", "//common/values", + "//runtime:accumulated_unknowns", "//runtime:evaluation_exception", "//runtime:interpretable", "//runtime:resolved_overload", @@ -265,6 +272,7 @@ java_library( ":execution_frame", ":planned_interpretable", "//common/values", + "//runtime:accumulated_unknowns", "//runtime:interpretable", "@maven//:com_google_guava_guava", ], @@ -278,6 +286,7 @@ java_library( ":execution_frame", ":planned_interpretable", "//common/values", + "//runtime:accumulated_unknowns", "//runtime:interpretable", "@maven//:com_google_guava_guava", ], @@ -289,6 +298,7 @@ java_library( deps = [ ":execution_frame", ":planned_interpretable", + "//runtime:accumulated_unknowns", "//runtime:evaluation_exception", "//runtime:interpretable", "@maven//:com_google_guava_guava", @@ -299,11 +309,13 @@ java_library( name = "eval_create_struct", srcs = ["EvalCreateStruct.java"], deps = [ + ":eval_helpers", ":execution_frame", ":planned_interpretable", "//common/types:type_providers", "//common/values", "//common/values:cel_value_provider", + "//runtime:accumulated_unknowns", "//runtime:evaluation_exception", "//runtime:interpretable", "@maven//:com_google_errorprone_error_prone_annotations", @@ -318,6 +330,7 @@ java_library( ":eval_helpers", ":execution_frame", ":planned_interpretable", + "//runtime:accumulated_unknowns", "//runtime:evaluation_exception", "//runtime:interpretable", "@maven//:com_google_errorprone_error_prone_annotations", @@ -329,11 +342,13 @@ java_library( name = "eval_create_map", srcs = ["EvalCreateMap.java"], deps = [ + ":eval_helpers", ":execution_frame", ":localized_evaluation_exception", ":planned_interpretable", "//common/exceptions:duplicate_key", "//common/exceptions:invalid_argument", + "//runtime:accumulated_unknowns", "//runtime:evaluation_exception", "//runtime:interpretable", "@maven//:com_google_errorprone_error_prone_annotations", @@ -348,6 +363,7 @@ java_library( ":activation_wrapper", ":execution_frame", ":planned_interpretable", + "//runtime:accumulated_unknowns", "//runtime:concatenated_list_view", "//runtime:evaluation_exception", "//runtime:interpretable", @@ -365,6 +381,7 @@ java_library( "//common/exceptions:iteration_budget_exceeded", "//runtime:evaluation_exception", "//runtime:function_resolver", + "//runtime:partial_vars", "//runtime:resolved_overload", ], ) @@ -424,6 +441,7 @@ java_library( ":execution_frame", ":planned_interpretable", "//common/exceptions:overload_not_found", + "//runtime:accumulated_unknowns", "//runtime:interpretable", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", @@ -438,6 +456,7 @@ java_library( ":execution_frame", ":planned_interpretable", "//common/exceptions:overload_not_found", + "//runtime:accumulated_unknowns", "//runtime:interpretable", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", @@ -452,6 +471,7 @@ java_library( ":execution_frame", ":planned_interpretable", "//common/values", + "//runtime:accumulated_unknowns", "//runtime:interpretable", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalAnd.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalAnd.java index 763f8faba..eb7406071 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalAnd.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalAnd.java @@ -1,66 +1,73 @@ -// Copyright 2025 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package dev.cel.runtime.planner; - -import static dev.cel.runtime.planner.EvalHelpers.evalNonstrictly; - -import com.google.common.base.Preconditions; -import dev.cel.common.values.ErrorValue; -import dev.cel.runtime.GlobalResolver; - -final class EvalAnd extends PlannedInterpretable { - - @SuppressWarnings("Immutable") - private final PlannedInterpretable[] args; - - @Override - public Object eval(GlobalResolver resolver, ExecutionFrame frame) { - ErrorValue errorValue = null; - for (PlannedInterpretable arg : args) { - Object argVal = evalNonstrictly(arg, resolver, frame); - if (argVal instanceof Boolean) { - // Short-circuit on false - if (!((boolean) argVal)) { - return false; - } - } else if (argVal instanceof ErrorValue) { - errorValue = (ErrorValue) argVal; - } else { - // TODO: Handle unknowns - errorValue = - ErrorValue.create( - arg.exprId(), - new IllegalArgumentException( - String.format("Expected boolean value, found: %s", argVal))); - } - } - - if (errorValue != null) { - return errorValue; - } - - return true; - } - - static EvalAnd create(long exprId, PlannedInterpretable[] args) { - return new EvalAnd(exprId, args); - } - - private EvalAnd(long exprId, PlannedInterpretable[] args) { - super(exprId); - Preconditions.checkArgument(args.length == 2); - this.args = args; - } -} +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.runtime.planner; + +import static dev.cel.runtime.planner.EvalHelpers.evalNonstrictly; + +import com.google.common.base.Preconditions; +import dev.cel.common.values.ErrorValue; +import dev.cel.runtime.AccumulatedUnknowns; +import dev.cel.runtime.GlobalResolver; + +final class EvalAnd extends PlannedInterpretable { + + @SuppressWarnings("Immutable") + private final PlannedInterpretable[] args; + + @Override + public Object eval(GlobalResolver resolver, ExecutionFrame frame) { + ErrorValue errorValue = null; + AccumulatedUnknowns unknowns = null; + for (PlannedInterpretable arg : args) { + Object argVal = evalNonstrictly(arg, resolver, frame); + if (argVal instanceof Boolean) { + // Short-circuit on false + if (!((boolean) argVal)) { + return false; + } + } else if (argVal instanceof ErrorValue) { + errorValue = (ErrorValue) argVal; + } else if (argVal instanceof AccumulatedUnknowns) { + unknowns = AccumulatedUnknowns.maybeMerge(unknowns, argVal); + } else { + errorValue = + ErrorValue.create( + arg.exprId(), + new IllegalArgumentException( + String.format("Expected boolean value, found: %s", argVal))); + } + } + + if (unknowns != null) { + return unknowns; + } + + if (errorValue != null) { + return errorValue; + } + + return true; + } + + static EvalAnd create(long exprId, PlannedInterpretable[] args) { + return new EvalAnd(exprId, args); + } + + private EvalAnd(long exprId, PlannedInterpretable[] args) { + super(exprId); + Preconditions.checkArgument(args.length == 2); + this.args = args; + } +} diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalAttribute.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalAttribute.java index fdd7ad2a3..a0a95c47a 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalAttribute.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalAttribute.java @@ -24,9 +24,9 @@ final class EvalAttribute extends InterpretableAttribute { @Override public Object eval(GlobalResolver resolver, ExecutionFrame frame) { - Object resolved = attr.resolve(resolver, frame); + Object resolved = attr.resolve(exprId(), resolver, frame); if (resolved instanceof MissingAttribute) { - ((MissingAttribute) resolved).resolve(resolver, frame); + ((MissingAttribute) resolved).resolve(exprId(), resolver, frame); } return resolved; diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalConditional.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalConditional.java index 74482d629..3be1f016a 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalConditional.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalConditional.java @@ -15,6 +15,7 @@ package dev.cel.runtime.planner; import com.google.common.base.Preconditions; +import dev.cel.runtime.AccumulatedUnknowns; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.GlobalResolver; @@ -28,8 +29,10 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEval PlannedInterpretable condition = args[0]; PlannedInterpretable truthy = args[1]; PlannedInterpretable falsy = args[2]; - // TODO: Handle unknowns Object condResult = condition.eval(resolver, frame); + if (condResult instanceof AccumulatedUnknowns) { + return condResult; + } if (!(condResult instanceof Boolean)) { throw new IllegalArgumentException( String.format("Expected boolean value, found :%s", condResult)); diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalCreateList.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalCreateList.java index 773272ea3..bae1e9302 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalCreateList.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalCreateList.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.Immutable; +import dev.cel.runtime.AccumulatedUnknowns; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.GlobalResolver; import java.util.Optional; @@ -32,9 +33,15 @@ final class EvalCreateList extends PlannedInterpretable { @Override public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { ImmutableList.Builder builder = ImmutableList.builderWithExpectedSize(values.length); + AccumulatedUnknowns unknowns = null; for (int i = 0; i < values.length; i++) { Object element = EvalHelpers.evalStrictly(values[i], resolver, frame); + if (element instanceof AccumulatedUnknowns) { + unknowns = AccumulatedUnknowns.maybeMerge(unknowns, element); + continue; + } + if (isOptional[i]) { if (!(element instanceof Optional)) { throw new IllegalArgumentException( @@ -51,6 +58,11 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEval builder.add(element); } + + if (unknowns != null) { + return unknowns; + } + return builder.build(); } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalCreateMap.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalCreateMap.java index f6f73e842..1e1b831bb 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalCreateMap.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalCreateMap.java @@ -21,6 +21,7 @@ import com.google.errorprone.annotations.Immutable; import dev.cel.common.exceptions.CelDuplicateKeyException; import dev.cel.common.exceptions.CelInvalidArgumentException; +import dev.cel.runtime.AccumulatedUnknowns; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.GlobalResolver; import java.util.HashSet; @@ -46,38 +47,49 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEval ImmutableMap.Builder builder = ImmutableMap.builderWithExpectedSize(keys.length); HashSet keysSeen = Sets.newHashSetWithExpectedSize(keys.length); + AccumulatedUnknowns unknowns = null; for (int i = 0; i < keys.length; i++) { PlannedInterpretable keyInterpretable = keys[i]; Object key = keyInterpretable.eval(resolver, frame); - if (!(key instanceof String - || key instanceof Long - || key instanceof UnsignedLong - || key instanceof Boolean)) { - throw new LocalizedEvaluationException( - new CelInvalidArgumentException("Unsupported key type: " + key), - keyInterpretable.exprId()); - } - boolean isDuplicate = !keysSeen.add(key); - if (!isDuplicate) { - if (key instanceof Long) { - long longVal = (Long) key; - if (longVal >= 0) { - isDuplicate = keysSeen.contains(UnsignedLong.valueOf(longVal)); + if (key instanceof AccumulatedUnknowns) { + unknowns = AccumulatedUnknowns.maybeMerge(unknowns, key); + } else { + if (!(key instanceof String + || key instanceof Long + || key instanceof UnsignedLong + || key instanceof Boolean)) { + throw new LocalizedEvaluationException( + new CelInvalidArgumentException("Unsupported key type: " + key), + keyInterpretable.exprId()); + } + + boolean isDuplicate = !keysSeen.add(key); + if (!isDuplicate) { + if (key instanceof Long) { + long longVal = (Long) key; + if (longVal >= 0) { + isDuplicate = keysSeen.contains(UnsignedLong.valueOf(longVal)); + } + } else if (key instanceof UnsignedLong) { + UnsignedLong ulongVal = (UnsignedLong) key; + isDuplicate = keysSeen.contains(ulongVal.longValue()); } - } else if (key instanceof UnsignedLong) { - UnsignedLong ulongVal = (UnsignedLong) key; - isDuplicate = keysSeen.contains(ulongVal.longValue()); } - } - if (isDuplicate) { - throw new LocalizedEvaluationException( - CelDuplicateKeyException.of(key), keyInterpretable.exprId()); + if (isDuplicate) { + throw new LocalizedEvaluationException( + CelDuplicateKeyException.of(key), keyInterpretable.exprId()); + } } - Object val = values[i].eval(resolver, frame); + Object val = EvalHelpers.evalStrictly(values[i], resolver, frame); + + unknowns = AccumulatedUnknowns.maybeMerge(unknowns, val); + if (unknowns != null) { + continue; + } if (isOptional[i]) { if (!(val instanceof Optional)) { @@ -94,13 +106,15 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEval continue; } val = opt.get(); - } else { - System.out.println(); } builder.put(key, val); } + if (unknowns != null) { + return unknowns; + } + return builder.buildOrThrow(); } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalCreateStruct.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalCreateStruct.java index 4edc87b79..cdeb0c574 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalCreateStruct.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalCreateStruct.java @@ -14,14 +14,15 @@ package dev.cel.runtime.planner; +import com.google.common.collect.Maps; import com.google.errorprone.annotations.Immutable; import dev.cel.common.types.CelType; import dev.cel.common.values.CelValueProvider; import dev.cel.common.values.StructValue; +import dev.cel.runtime.AccumulatedUnknowns; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.GlobalResolver; import java.util.Collections; -import java.util.HashMap; import java.util.Map; import java.util.Optional; @@ -45,17 +46,22 @@ final class EvalCreateStruct extends PlannedInterpretable { @Override public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { - Map fieldValues = new HashMap<>(); + Map fieldValues = Maps.newHashMapWithExpectedSize(keys.length); + AccumulatedUnknowns unknowns = null; for (int i = 0; i < keys.length; i++) { - Object value = values[i].eval(resolver, frame); + Object value = EvalHelpers.evalStrictly(values[i], resolver, frame); + + if (value instanceof AccumulatedUnknowns) { + unknowns = AccumulatedUnknowns.maybeMerge(unknowns, value); + continue; + } if (isOptional[i]) { if (!(value instanceof Optional)) { throw new IllegalArgumentException( String.format( - "Cannot initialize optional entry 'single_double_wrapper' from non-optional value" - + " %s", - value)); + "Cannot initialize optional entry '%s' from non-optional value" + " %s", + keys[i], value)); } Optional opt = (Optional) value; @@ -71,6 +77,10 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEval fieldValues.put(keys[i], value); } + if (unknowns != null) { + return unknowns; + } + // Either a primitive (wrappers) or a struct is produced Object value = valueProvider diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalFold.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalFold.java index 3545ee4f7..197db42ad 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalFold.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalFold.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.Immutable; +import dev.cel.runtime.AccumulatedUnknowns; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.ConcatenatedListView; import dev.cel.runtime.GlobalResolver; @@ -73,6 +74,9 @@ private EvalFold( @Override public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { Object iterRangeRaw = iterRange.eval(resolver, frame); + if (iterRangeRaw instanceof AccumulatedUnknowns) { + return iterRangeRaw; + } Folder folder = new Folder(resolver, accuVar, iterVar, iterVar2); folder.accuVal = maybeWrapAccumulator(accuInit.eval(folder, frame)); diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalHelpers.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalHelpers.java index 92d234acc..38b060b92 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalHelpers.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalHelpers.java @@ -1,78 +1,78 @@ -// Copyright 2025 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package dev.cel.runtime.planner; - -import com.google.common.base.Joiner; -import dev.cel.common.CelErrorCode; -import dev.cel.common.exceptions.CelRuntimeException; -import dev.cel.common.values.CelValueConverter; -import dev.cel.common.values.ErrorValue; -import dev.cel.runtime.CelEvaluationException; -import dev.cel.runtime.CelResolvedOverload; -import dev.cel.runtime.GlobalResolver; - -final class EvalHelpers { - - static Object evalNonstrictly( - PlannedInterpretable interpretable, GlobalResolver resolver, ExecutionFrame frame) { - try { - return interpretable.eval(resolver, frame); - } catch (LocalizedEvaluationException e) { - // Intercept the localized exception to get a more specific expr ID for error reporting - // Example: foo [1] && strict_err [2] -> ID 2 is propagated. - return ErrorValue.create(e.exprId(), e); - } catch (Exception e) { - return ErrorValue.create(interpretable.exprId(), e); - } - } - - static Object evalStrictly( - PlannedInterpretable interpretable, GlobalResolver resolver, ExecutionFrame frame) { - try { - return interpretable.eval(resolver, frame); - } catch (LocalizedEvaluationException e) { - // Already localized - propagate as-is to preserve inner expression ID - throw e; - } catch (CelRuntimeException e) { - // Wrap with current interpretable's location - throw new LocalizedEvaluationException(e, interpretable.exprId()); - } catch (Exception e) { - // Wrap generic exceptions with location - throw new LocalizedEvaluationException( - e, CelErrorCode.INTERNAL_ERROR, interpretable.exprId()); - } - } - - static Object dispatch( - CelResolvedOverload overload, CelValueConverter valueConverter, Object[] args) - throws CelEvaluationException { - try { - Object result = overload.getDefinition().apply(args); - return valueConverter.maybeUnwrap(valueConverter.toRuntimeValue(result)); - } catch (CelRuntimeException e) { - // Function dispatch failure that's already been handled -- just propagate. - throw e; - } catch (RuntimeException e) { - // Unexpected function dispatch failure. - throw new IllegalArgumentException( - String.format( - "Function '%s' failed with arg(s) '%s'", - overload.getOverloadId(), Joiner.on(", ").join(args)), - e); - } - } - - private EvalHelpers() {} -} +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.runtime.planner; + +import com.google.common.base.Joiner; +import dev.cel.common.CelErrorCode; +import dev.cel.common.exceptions.CelRuntimeException; +import dev.cel.common.values.CelValueConverter; +import dev.cel.common.values.ErrorValue; +import dev.cel.runtime.CelEvaluationException; +import dev.cel.runtime.CelResolvedOverload; +import dev.cel.runtime.GlobalResolver; + +final class EvalHelpers { + + static Object evalNonstrictly( + PlannedInterpretable interpretable, GlobalResolver resolver, ExecutionFrame frame) { + try { + return interpretable.eval(resolver, frame); + } catch (LocalizedEvaluationException e) { + // Intercept the localized exception to get a more specific expr ID for error reporting + // Example: foo [1] && strict_err [2] -> ID 2 is propagated. + return ErrorValue.create(e.exprId(), e); + } catch (Exception e) { + return ErrorValue.create(interpretable.exprId(), e); + } + } + + static Object evalStrictly( + PlannedInterpretable interpretable, GlobalResolver resolver, ExecutionFrame frame) { + try { + return interpretable.eval(resolver, frame); + } catch (LocalizedEvaluationException e) { + // Already localized - propagate as-is to preserve inner expression ID + throw e; + } catch (CelRuntimeException e) { + // Wrap with current interpretable's location + throw new LocalizedEvaluationException(e, interpretable.exprId()); + } catch (Exception e) { + // Wrap generic exceptions with location + throw new LocalizedEvaluationException( + e, CelErrorCode.INTERNAL_ERROR, interpretable.exprId()); + } + } + + static Object dispatch( + CelResolvedOverload overload, CelValueConverter valueConverter, Object[] args) + throws CelEvaluationException { + try { + Object result = overload.getDefinition().apply(args); + return valueConverter.maybeUnwrap(valueConverter.toRuntimeValue(result)); + } catch (CelRuntimeException e) { + // Function dispatch failure that's already been handled -- just propagate. + throw e; + } catch (RuntimeException e) { + // Unexpected function dispatch failure. + throw new IllegalArgumentException( + String.format( + "Function '%s' failed with arg(s) '%s'", + overload.getOverloadId(), Joiner.on(", ").join(args)), + e); + } + } + + private EvalHelpers() {} +} diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalLateBoundCall.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalLateBoundCall.java index a22ba8e94..cdee878ee 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalLateBoundCall.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalLateBoundCall.java @@ -19,6 +19,7 @@ import com.google.common.collect.ImmutableList; import dev.cel.common.exceptions.CelOverloadNotFoundException; import dev.cel.common.values.CelValueConverter; +import dev.cel.runtime.AccumulatedUnknowns; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.CelResolvedOverload; import dev.cel.runtime.GlobalResolver; @@ -36,10 +37,17 @@ final class EvalLateBoundCall extends PlannedInterpretable { @Override public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { Object[] argVals = new Object[args.length]; + AccumulatedUnknowns unknowns = null; for (int i = 0; i < args.length; i++) { PlannedInterpretable arg = args[i]; // Late bound functions are assumed to be strict. argVals[i] = evalStrictly(arg, resolver, frame); + + unknowns = AccumulatedUnknowns.maybeMerge(unknowns, argVals[i]); + } + + if (unknowns != null) { + return unknowns; } CelResolvedOverload resolvedOverload = diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalOptionalOr.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalOptionalOr.java index 70009d567..5ad1933d7 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalOptionalOr.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalOptionalOr.java @@ -17,6 +17,7 @@ import com.google.common.base.Preconditions; import com.google.errorprone.annotations.Immutable; import dev.cel.common.exceptions.CelOverloadNotFoundException; +import dev.cel.runtime.AccumulatedUnknowns; import dev.cel.runtime.GlobalResolver; import java.util.Optional; @@ -29,6 +30,10 @@ final class EvalOptionalOr extends PlannedInterpretable { public Object eval(GlobalResolver resolver, ExecutionFrame frame) { Object lhsValue = EvalHelpers.evalStrictly(lhs, resolver, frame); + if (lhsValue instanceof AccumulatedUnknowns) { + return lhsValue; + } + if (!(lhsValue instanceof Optional)) { throw new CelOverloadNotFoundException("or"); } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalOptionalOrValue.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalOptionalOrValue.java index 7a4940c7c..6634d60f6 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalOptionalOrValue.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalOptionalOrValue.java @@ -17,6 +17,7 @@ import com.google.common.base.Preconditions; import com.google.errorprone.annotations.Immutable; import dev.cel.common.exceptions.CelOverloadNotFoundException; +import dev.cel.runtime.AccumulatedUnknowns; import dev.cel.runtime.GlobalResolver; import java.util.Optional; @@ -28,6 +29,10 @@ final class EvalOptionalOrValue extends PlannedInterpretable { @Override public Object eval(GlobalResolver resolver, ExecutionFrame frame) { Object lhsValue = EvalHelpers.evalStrictly(lhs, resolver, frame); + if (lhsValue instanceof AccumulatedUnknowns) { + return lhsValue; + } + if (!(lhsValue instanceof Optional)) { throw new CelOverloadNotFoundException("orValue"); } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalOptionalSelectField.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalOptionalSelectField.java index bc14149f3..8887aa697 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalOptionalSelectField.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalOptionalSelectField.java @@ -18,6 +18,7 @@ import com.google.errorprone.annotations.Immutable; import dev.cel.common.values.CelValueConverter; import dev.cel.common.values.SelectableValue; +import dev.cel.runtime.AccumulatedUnknowns; import dev.cel.runtime.GlobalResolver; import java.util.Map; import java.util.Optional; @@ -42,6 +43,10 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) { } Object runtimeOperandValue = celValueConverter.toRuntimeValue(operandValue); + if (runtimeOperandValue instanceof AccumulatedUnknowns) { + return runtimeOperandValue; + } + boolean hasField = false; if (runtimeOperandValue instanceof SelectableValue) { @@ -62,6 +67,10 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) { return resultValue; } + if (resultValue instanceof AccumulatedUnknowns) { + return resultValue; + } + return Optional.of(resultValue); } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalOr.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalOr.java index 22fc56a7f..bc19ed81a 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalOr.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalOr.java @@ -1,66 +1,73 @@ -// Copyright 2025 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package dev.cel.runtime.planner; - -import static dev.cel.runtime.planner.EvalHelpers.evalNonstrictly; - -import com.google.common.base.Preconditions; -import dev.cel.common.values.ErrorValue; -import dev.cel.runtime.GlobalResolver; - -final class EvalOr extends PlannedInterpretable { - - @SuppressWarnings("Immutable") - private final PlannedInterpretable[] args; - - @Override - public Object eval(GlobalResolver resolver, ExecutionFrame frame) { - ErrorValue errorValue = null; - for (PlannedInterpretable arg : args) { - Object argVal = evalNonstrictly(arg, resolver, frame); - if (argVal instanceof Boolean) { - // Short-circuit on true - if (((boolean) argVal)) { - return true; - } - } else if (argVal instanceof ErrorValue) { - errorValue = (ErrorValue) argVal; - } else { - // TODO: Handle unknowns - errorValue = - ErrorValue.create( - arg.exprId(), - new IllegalArgumentException( - String.format("Expected boolean value, found: %s", argVal))); - } - } - - if (errorValue != null) { - return errorValue; - } - - return false; - } - - static EvalOr create(long exprId, PlannedInterpretable[] args) { - return new EvalOr(exprId, args); - } - - private EvalOr(long exprId, PlannedInterpretable[] args) { - super(exprId); - Preconditions.checkArgument(args.length == 2); - this.args = args; - } -} +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.runtime.planner; + +import static dev.cel.runtime.planner.EvalHelpers.evalNonstrictly; + +import com.google.common.base.Preconditions; +import dev.cel.common.values.ErrorValue; +import dev.cel.runtime.AccumulatedUnknowns; +import dev.cel.runtime.GlobalResolver; + +final class EvalOr extends PlannedInterpretable { + + @SuppressWarnings("Immutable") + private final PlannedInterpretable[] args; + + @Override + public Object eval(GlobalResolver resolver, ExecutionFrame frame) { + ErrorValue errorValue = null; + AccumulatedUnknowns unknowns = null; + for (PlannedInterpretable arg : args) { + Object argVal = evalNonstrictly(arg, resolver, frame); + if (argVal instanceof Boolean) { + // Short-circuit on true + if (((boolean) argVal)) { + return true; + } + } else if (argVal instanceof ErrorValue) { + errorValue = (ErrorValue) argVal; + } else if (argVal instanceof AccumulatedUnknowns) { + unknowns = AccumulatedUnknowns.maybeMerge(unknowns, argVal); + } else { + errorValue = + ErrorValue.create( + arg.exprId(), + new IllegalArgumentException( + String.format("Expected boolean value, found: %s", argVal))); + } + } + + if (unknowns != null) { + return unknowns; + } + + if (errorValue != null) { + return errorValue; + } + + return false; + } + + static EvalOr create(long exprId, PlannedInterpretable[] args) { + return new EvalOr(exprId, args); + } + + private EvalOr(long exprId, PlannedInterpretable[] args) { + super(exprId); + Preconditions.checkArgument(args.length == 2); + this.args = args; + } +} diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalVarArgsCall.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalVarArgsCall.java index 9f14f8bf9..eb8745632 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalVarArgsCall.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalVarArgsCall.java @@ -18,6 +18,7 @@ import static dev.cel.runtime.planner.EvalHelpers.evalStrictly; import dev.cel.common.values.CelValueConverter; +import dev.cel.runtime.AccumulatedUnknowns; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.CelResolvedOverload; import dev.cel.runtime.GlobalResolver; @@ -34,12 +35,19 @@ final class EvalVarArgsCall extends PlannedInterpretable { @Override public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { Object[] argVals = new Object[args.length]; + AccumulatedUnknowns unknowns = null; for (int i = 0; i < args.length; i++) { PlannedInterpretable arg = args[i]; argVals[i] = resolvedOverload.isStrict() ? evalStrictly(arg, resolver, frame) : evalNonstrictly(arg, resolver, frame); + + unknowns = AccumulatedUnknowns.maybeMerge(unknowns, argVals[i]); + } + + if (unknowns != null) { + return unknowns; } return EvalHelpers.dispatch(resolvedOverload, celValueConverter, argVals); diff --git a/runtime/src/main/java/dev/cel/runtime/planner/ExecutionFrame.java b/runtime/src/main/java/dev/cel/runtime/planner/ExecutionFrame.java index 80ee4b318..e29c68dd8 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/ExecutionFrame.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/ExecutionFrame.java @@ -19,6 +19,7 @@ import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.CelFunctionResolver; import dev.cel.runtime.CelResolvedOverload; +import dev.cel.runtime.PartialVars; import java.util.Collection; import java.util.Optional; @@ -27,6 +28,7 @@ final class ExecutionFrame { private final int comprehensionIterationLimit; private final CelFunctionResolver functionResolver; + private final PartialVars partialVars; private int iterationCount; Optional findOverload( @@ -47,12 +49,19 @@ void incrementIterations() { } } - static ExecutionFrame create(CelFunctionResolver functionResolver, CelOptions celOptions) { - return new ExecutionFrame(functionResolver, celOptions.comprehensionMaxIterations()); + static ExecutionFrame create( + CelFunctionResolver functionResolver, PartialVars partialVars, CelOptions celOptions) { + return new ExecutionFrame( + functionResolver, partialVars, celOptions.comprehensionMaxIterations()); } - private ExecutionFrame(CelFunctionResolver functionResolver, int limit) { + Optional partialVars() { + return Optional.ofNullable(partialVars); + } + + private ExecutionFrame(CelFunctionResolver functionResolver, PartialVars partialVars, int limit) { this.comprehensionIterationLimit = limit; this.functionResolver = functionResolver; + this.partialVars = partialVars; } } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/MaybeAttribute.java b/runtime/src/main/java/dev/cel/runtime/planner/MaybeAttribute.java index 40a9f6203..1506eb180 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/MaybeAttribute.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/MaybeAttribute.java @@ -28,10 +28,10 @@ final class MaybeAttribute implements Attribute { private final ImmutableList attributes; @Override - public Object resolve(GlobalResolver ctx, ExecutionFrame frame) { + public Object resolve(long exprId, GlobalResolver ctx, ExecutionFrame frame) { MissingAttribute maybeError = null; for (NamespacedAttribute attr : attributes) { - Object value = attr.resolve(ctx, frame); + Object value = attr.resolve(exprId, ctx, frame); if (value == null) { continue; } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/MissingAttribute.java b/runtime/src/main/java/dev/cel/runtime/planner/MissingAttribute.java index 02b04781c..b7fb8ad72 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/MissingAttribute.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/MissingAttribute.java @@ -25,7 +25,7 @@ final class MissingAttribute implements Attribute { private final Kind kind; @Override - public Object resolve(GlobalResolver ctx, ExecutionFrame frame) { + public Object resolve(long exprId, GlobalResolver ctx, ExecutionFrame frame) { switch (kind) { case ATTRIBUTE_NOT_FOUND: throw CelAttributeNotFoundException.forMissingAttributes(missingAttributes); diff --git a/runtime/src/main/java/dev/cel/runtime/planner/NamespacedAttribute.java b/runtime/src/main/java/dev/cel/runtime/planner/NamespacedAttribute.java index cc8ca1d97..ed37eada1 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/NamespacedAttribute.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/NamespacedAttribute.java @@ -15,6 +15,7 @@ package dev.cel.runtime.planner; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.errorprone.annotations.Immutable; import dev.cel.common.types.CelType; @@ -23,14 +24,21 @@ import dev.cel.common.types.SimpleType; import dev.cel.common.types.TypeType; import dev.cel.common.values.CelValueConverter; +import dev.cel.runtime.AccumulatedUnknowns; +import dev.cel.runtime.CelAttribute; +import dev.cel.runtime.CelAttributePattern; import dev.cel.runtime.GlobalResolver; +import dev.cel.runtime.InterpreterUtil; +import dev.cel.runtime.PartialVars; +import java.util.Map; import java.util.NoSuchElementException; +import java.util.Optional; import org.jspecify.annotations.Nullable; @Immutable final class NamespacedAttribute implements Attribute { private final boolean disambiguateNames; - private final ImmutableSet namespacedNames; + private final ImmutableMap candidateAttributes; private final ImmutableList qualifiers; private final CelValueConverter celValueConverter; private final CelTypeProvider typeProvider; @@ -40,11 +48,11 @@ ImmutableList qualifiers() { } ImmutableSet candidateVariableNames() { - return namespacedNames; + return candidateAttributes.keySet(); } @Override - public Object resolve(GlobalResolver ctx, ExecutionFrame frame) { + public Object resolve(long exprId, GlobalResolver ctx, ExecutionFrame frame) { GlobalResolver inputVars = ctx; // Unwrap any local activations to ensure that we reach the variables provided as input // to the expression in the event that we need to disambiguate between global and local @@ -53,13 +61,33 @@ public Object resolve(GlobalResolver ctx, ExecutionFrame frame) { inputVars = unwrapToNonLocal(ctx); } - for (String name : namespacedNames) { + for (Map.Entry entry : candidateAttributes.entrySet()) { + String name = entry.getKey(); + CelAttribute attr = entry.getValue(); + GlobalResolver resolver = ctx; if (disambiguateNames) { resolver = inputVars; } Object value = resolver.resolve(name); + value = InterpreterUtil.maybeAdaptToAccumulatedUnknowns(value); + + PartialVars partialVars = frame.partialVars().orElse(null); + + if (partialVars != null) { + ImmutableList patterns = partialVars.unknowns(); + for (Qualifier qualifier : qualifiers) { + attr = attr.qualify(CelAttribute.Qualifier.fromGeneric(qualifier.value())); + } + + CelAttributePattern partialMatch = findPartialMatchingPattern(attr, patterns).orElse(null); + if (partialMatch != null) { + return AccumulatedUnknowns.create( + ImmutableList.of(exprId), ImmutableList.of(partialMatch.simplify(attr))); + } + } + if (value != null) { return applyQualifiers(value, celValueConverter, qualifiers); } @@ -71,7 +99,7 @@ public Object resolve(GlobalResolver ctx, ExecutionFrame frame) { } } - return MissingAttribute.newMissingAttribute(namespacedNames); + return MissingAttribute.newMissingAttribute(candidateAttributes.keySet()); } private @Nullable Object findIdent(String name) { @@ -131,10 +159,17 @@ private GlobalResolver unwrapToNonLocal(GlobalResolver resolver) { @Override public NamespacedAttribute addQualifier(Qualifier qualifier) { + ImmutableMap.Builder attributesBuilder = ImmutableMap.builder(); + CelAttribute.Qualifier celQualifier = CelAttribute.Qualifier.fromGeneric(qualifier.value()); + + for (Map.Entry entry : candidateAttributes.entrySet()) { + attributesBuilder.put(entry.getKey(), entry.getValue().qualify(celQualifier)); + } + return new NamespacedAttribute( typeProvider, celValueConverter, - namespacedNames, + attributesBuilder.buildOrThrow(), disambiguateNames, ImmutableList.builder().addAll(qualifiers).add(qualifier).build()); } @@ -150,37 +185,49 @@ private static Object applyQualifiers( return celValueConverter.maybeUnwrap(obj); } + private static Optional findPartialMatchingPattern( + CelAttribute attr, ImmutableList patterns) { + for (CelAttributePattern pattern : patterns) { + if (pattern.isPartialMatch(attr)) { + return Optional.of(pattern); + } + } + return Optional.empty(); + } + static NamespacedAttribute create( CelTypeProvider typeProvider, CelValueConverter celValueConverter, ImmutableSet namespacedNames) { - ImmutableSet.Builder namesBuilder = ImmutableSet.builder(); + ImmutableMap.Builder attributesBuilder = ImmutableMap.builder(); boolean disambiguateNames = false; + for (String name : namespacedNames) { + String baseName = name; if (name.startsWith(".")) { disambiguateNames = true; - namesBuilder.add(name.substring(1)); - } else { - namesBuilder.add(name); + baseName = name.substring(1); } + attributesBuilder.put(baseName, CelAttribute.fromQualifiedIdentifier(baseName)); } + return new NamespacedAttribute( typeProvider, celValueConverter, - namesBuilder.build(), + attributesBuilder.buildOrThrow(), disambiguateNames, ImmutableList.of()); } - NamespacedAttribute( + private NamespacedAttribute( CelTypeProvider typeProvider, CelValueConverter celValueConverter, - ImmutableSet namespacedNames, + ImmutableMap candidateAttributes, boolean disambiguateNames, ImmutableList qualifiers) { this.typeProvider = typeProvider; this.celValueConverter = celValueConverter; - this.namespacedNames = namespacedNames; + this.candidateAttributes = candidateAttributes; this.disambiguateNames = disambiguateNames; this.qualifiers = qualifiers; } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/PlannedProgram.java b/runtime/src/main/java/dev/cel/runtime/planner/PlannedProgram.java index 8b419cab2..34fc34b50 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/PlannedProgram.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/PlannedProgram.java @@ -26,6 +26,8 @@ import dev.cel.runtime.CelResolvedOverload; import dev.cel.runtime.CelVariableResolver; import dev.cel.runtime.GlobalResolver; +import dev.cel.runtime.InterpreterUtil; +import dev.cel.runtime.PartialVars; import dev.cel.runtime.Program; import java.util.Collection; import java.util.Map; @@ -58,47 +60,61 @@ public Optional findOverloadMatchingArgs( @Override public Object eval() throws CelEvaluationException { - return evalOrThrow(interpretable(), GlobalResolver.EMPTY, EMPTY_FUNCTION_RESOLVER); + return evalOrThrow(interpretable(), GlobalResolver.EMPTY, EMPTY_FUNCTION_RESOLVER, null); } @Override public Object eval(Map mapValue) throws CelEvaluationException { - return evalOrThrow(interpretable(), Activation.copyOf(mapValue), EMPTY_FUNCTION_RESOLVER); + return evalOrThrow(interpretable(), Activation.copyOf(mapValue), EMPTY_FUNCTION_RESOLVER, null); } @Override public Object eval(Map mapValue, CelFunctionResolver lateBoundFunctionResolver) throws CelEvaluationException { - return evalOrThrow(interpretable(), Activation.copyOf(mapValue), lateBoundFunctionResolver); + return evalOrThrow( + interpretable(), Activation.copyOf(mapValue), lateBoundFunctionResolver, null); } @Override public Object eval(CelVariableResolver resolver) throws CelEvaluationException { return evalOrThrow( - interpretable(), (name) -> resolver.find(name).orElse(null), EMPTY_FUNCTION_RESOLVER); + interpretable(), (name) -> resolver.find(name).orElse(null), EMPTY_FUNCTION_RESOLVER, null); } @Override public Object eval(CelVariableResolver resolver, CelFunctionResolver lateBoundFunctionResolver) throws CelEvaluationException { return evalOrThrow( - interpretable(), (name) -> resolver.find(name).orElse(null), lateBoundFunctionResolver); + interpretable(), + (name) -> resolver.find(name).orElse(null), + lateBoundFunctionResolver, + null); + } + + @Override + public Object eval(PartialVars partialVars) throws CelEvaluationException { + return evalOrThrow( + interpretable(), + (name) -> partialVars.resolver().find(name).orElse(null), + EMPTY_FUNCTION_RESOLVER, + partialVars); } private Object evalOrThrow( PlannedInterpretable interpretable, GlobalResolver resolver, - CelFunctionResolver functionResolver) + CelFunctionResolver functionResolver, + PartialVars partialVars) throws CelEvaluationException { try { - ExecutionFrame frame = ExecutionFrame.create(functionResolver, options()); + ExecutionFrame frame = ExecutionFrame.create(functionResolver, partialVars, options()); Object evalResult = interpretable.eval(resolver, frame); if (evalResult instanceof ErrorValue) { ErrorValue errorValue = (ErrorValue) evalResult; throw newCelEvaluationException(errorValue.exprId(), errorValue.value()); } - return evalResult; + return InterpreterUtil.maybeAdaptToCelUnknownSet(evalResult); } catch (RuntimeException e) { throw newCelEvaluationException(interpretable.exprId(), e); } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/RelativeAttribute.java b/runtime/src/main/java/dev/cel/runtime/planner/RelativeAttribute.java index b3d83c390..1ab2fa3e7 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/RelativeAttribute.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/RelativeAttribute.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.Immutable; import dev.cel.common.values.CelValueConverter; +import dev.cel.runtime.AccumulatedUnknowns; import dev.cel.runtime.GlobalResolver; /** @@ -31,15 +32,18 @@ final class RelativeAttribute implements Attribute { private final ImmutableList qualifiers; @Override - public Object resolve(GlobalResolver ctx, ExecutionFrame frame) { + public Object resolve(long exprId, GlobalResolver ctx, ExecutionFrame frame) { Object obj = EvalHelpers.evalStrictly(operand, ctx, frame); + if (obj instanceof AccumulatedUnknowns) { + return obj; + } + obj = celValueConverter.toRuntimeValue(obj); for (Qualifier qualifier : qualifiers) { obj = qualifier.qualify(obj); } - // TODO: Handle unknowns return celValueConverter.maybeUnwrap(obj); } diff --git a/runtime/src/test/java/dev/cel/runtime/BUILD.bazel b/runtime/src/test/java/dev/cel/runtime/BUILD.bazel index 8a0b1f9de..577010971 100644 --- a/runtime/src/test/java/dev/cel/runtime/BUILD.bazel +++ b/runtime/src/test/java/dev/cel/runtime/BUILD.bazel @@ -130,16 +130,25 @@ java_library( srcs = [ "PlannerInterpreterTest.java", ], + resources = [ + "//runtime/testdata", + ], deps = [ "//common:cel_ast", "//common:compiler_common", "//common:container", "//common:options", + "//common/types", "//common/types:type_providers", "//extensions", "//runtime", + "//runtime:function_binding", "//runtime:runtime_experimental_factory", + "//runtime:unknown_attributes", "//testing:base_interpreter_test", + "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto", + "@maven//:com_google_guava_guava", + "@maven//:com_google_protobuf_protobuf_java", "@maven//:com_google_testparameterinjector_test_parameter_injector", "@maven//:junit_junit", ], diff --git a/runtime/src/test/java/dev/cel/runtime/PlannerInterpreterTest.java b/runtime/src/test/java/dev/cel/runtime/PlannerInterpreterTest.java index 3254855c7..2c0bec739 100644 --- a/runtime/src/test/java/dev/cel/runtime/PlannerInterpreterTest.java +++ b/runtime/src/test/java/dev/cel/runtime/PlannerInterpreterTest.java @@ -14,6 +14,8 @@ package dev.cel.runtime; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.Timestamp; import com.google.testing.junit.testparameterinjector.TestParameter; import com.google.testing.junit.testparameterinjector.TestParameterInjector; import dev.cel.common.CelAbstractSyntaxTree; @@ -21,8 +23,15 @@ import dev.cel.common.CelOptions; import dev.cel.common.CelValidationException; import dev.cel.common.types.CelTypeProvider; +import dev.cel.common.types.ListType; +import dev.cel.common.types.SimpleType; +import dev.cel.common.types.StructTypeReference; +import dev.cel.expr.conformance.proto3.TestAllTypes; import dev.cel.extensions.CelExtensions; import dev.cel.testing.BaseInterpreterTest; +import java.util.Arrays; +import java.util.Objects; +import org.junit.Test; import org.junit.runner.RunWith; /** Interpreter tests using ProgramPlanner */ @@ -37,7 +46,8 @@ protected CelRuntimeBuilder newBaseRuntimeBuilder(CelOptions celOptions) { .addLateBoundFunctions("record") .setOptions(celOptions) .addLibraries(CelExtensions.optional()) - .addFileTypes(TEST_FILE_DESCRIPTORS); + .addFileTypes(TEST_FILE_DESCRIPTORS) + .addMessageTypes(TestAllTypes.getDescriptor()); } @Override @@ -70,26 +80,247 @@ protected CelAbstractSyntaxTree prepareTest(CelTypeProvider typeProvider) { } } + @Override + public void optional_errors() { + if (isParseOnly) { + // Parsed-only evaluation contains function name in the + // error message instead of the function overload. + skipBaselineVerification(); + } else { + super.optional_errors(); + } + } + @Override public void unknownField() { - // TODO: Unknown support not implemented yet + // Exercised in planner_unknownFieldAccess instead skipBaselineVerification(); } @Override public void unknownResultSet() { - // TODO: Unknown support not implemented yet + // Exercised in planner_unknownResultSet_success instead skipBaselineVerification(); } - @Override - public void optional_errors() { - if (isParseOnly) { - // Parsed-only evaluation contains function name in the - // error message instead of the function overload. - skipBaselineVerification(); - } else { - super.optional_errors(); - } + @Test + public void planner_unknownFieldSelection() { + setContainer(CelContainer.ofName(TestAllTypes.getDescriptor().getFile().getPackage())); + declareVariable("x", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())); + + CelAttributePattern patternX = CelAttributePattern.fromQualifiedIdentifier("x"); + + source = "x"; + // We have the full message, but we're claiming that the attribute is unknown. + runTest(ImmutableMap.of("x", TestAllTypes.getDefaultInstance()), patternX); + // A "partially known message". The result is still an unknown. + runTest( + ImmutableMap.of("x", TestAllTypes.getDefaultInstance()), + CelAttributePattern.fromQualifiedIdentifier("x.single_int32")); + + source = "x.single_int32"; + runTest(ImmutableMap.of(), patternX); + runTest(ImmutableMap.of(), CelAttributePattern.fromQualifiedIdentifier("x.single_int32")); + + source = "x.map_int32_int64[22]"; + runTest(ImmutableMap.of(), patternX); + runTest(ImmutableMap.of(), CelAttributePattern.fromQualifiedIdentifier("x.map_int32_int64")); + + source = "x.repeated_nested_message[1]"; + runTest(ImmutableMap.of(), patternX); + runTest( + ImmutableMap.of(), + CelAttributePattern.fromQualifiedIdentifier("x.repeated_nested_message")); + + source = "x.single_nested_message.bb"; + runTest(ImmutableMap.of(), patternX); + runTest( + ImmutableMap.of(), + CelAttributePattern.fromQualifiedIdentifier("x.single_nested_message.bb")); + + source = "{1: x.single_int32}"; + runTest(ImmutableMap.of(), patternX); + runTest(ImmutableMap.of(), CelAttributePattern.fromQualifiedIdentifier("x.single_int32")); + + source = "[1, x.single_int32]"; + runTest(ImmutableMap.of(), patternX); + runTest(ImmutableMap.of(), CelAttributePattern.fromQualifiedIdentifier("x.single_int32")); + } + + @Test + public void planner_unknownResultSet_success() { + setContainer(CelContainer.ofName(TestAllTypes.getDescriptor().getFile().getPackage())); + declareVariable("x", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())); + TestAllTypes message = + TestAllTypes.newBuilder() + .setSingleString("test") + .setSingleTimestamp(Timestamp.newBuilder().setSeconds(15)) + .build(); + ImmutableMap variables = ImmutableMap.of("x", message); + CelAttributePattern unknownInt32 = + CelAttributePattern.fromQualifiedIdentifier("x.single_int32"); + CelAttributePattern unknownInt64 = + CelAttributePattern.fromQualifiedIdentifier("x.single_int64"); + + source = "x.single_int32 == 1 && true"; + runTest(variables, unknownInt32); + + source = "x.single_int32 == 1 && false"; + runTest(variables, unknownInt32); + + source = "x.single_int32 == 1 && x.single_int64 == 1"; + runTest(variables, unknownInt32, unknownInt64); + + source = "true && x.single_int32 == 1"; + runTest(variables, unknownInt32); + + source = "false && x.single_int32 == 1"; + runTest(variables, unknownInt32); + + source = "x.single_int32 == 1 || x.single_string == \"test\""; + runTest(variables, unknownInt32); + + source = "x.single_int32 == 1 || x.single_string != \"test\""; + runTest(variables, unknownInt32); + + source = "x.single_int32 == 1 || x.single_int64 == 1"; + runTest(variables, unknownInt32, unknownInt64); + + source = "true || x.single_int32 == 1"; + runTest(variables, unknownInt32); + + source = "false || x.single_int32 == 1"; + runTest(variables, unknownInt32); + + // dispatch test + declareFunction( + "f", memberOverload("f", Arrays.asList(SimpleType.INT, SimpleType.INT), SimpleType.BOOL)); + celRuntime = + newBaseRuntimeBuilder( + CelOptions.current() + .enableTimestampEpoch(true) + .enableHeterogeneousNumericComparisons(true) + .enableOptionalSyntax(true) + .comprehensionMaxIterations(1_000) + .build()) + .addFunctionBindings( + CelFunctionBinding.from("f", Integer.class, Integer.class, Objects::equals)) + .setContainer(CelContainer.ofName(TestAllTypes.getDescriptor().getFile().getPackage())) + .build(); + + source = "x.single_int32.f(1)"; + runTest(variables, unknownInt32); + + source = "1.f(x.single_int32)"; + runTest(variables, unknownInt32); + + source = "x.single_int64.f(x.single_int32)"; + runTest(variables, unknownInt32, unknownInt64); + + source = "[0, 2, 4].exists(z, z == 2 || z == x.single_int32)"; + runTest(variables, unknownInt32); + + source = "[0, 2, 4].exists(z, z == x.single_int32)"; + runTest(variables, unknownInt32); + + source = + "[0, 2, 4].exists_one(z, z == 0 || (z == 2 && z == x.single_int32) " + + "|| (z == 4 && z == x.single_int64))"; + runTest(variables, unknownInt32, unknownInt64); + + source = "[0, 2].all(z, z == 2 || z == x.single_int32)"; + runTest(variables, unknownInt32); + + source = + "[0, 2, 4].filter(z, z == 0 || (z == 2 && z == x.single_int32) " + + "|| (z == 4 && z == x.single_int64))"; + runTest(variables, unknownInt32, unknownInt64); + + source = + "[0, 2, 4].map(z, z == 0 || (z == 2 && z == x.single_int32) " + + "|| (z == 4 && z == x.single_int64))"; + runTest(variables, unknownInt32, unknownInt64); + + source = "x.single_int32 == 1 ? 1 : 2"; + runTest(variables, unknownInt32); + + source = "true ? x.single_int32 : 2"; + runTest(variables, unknownInt32); + + source = "true ? 1 : x.single_int32"; + runTest(variables, unknownInt32); + + source = "false ? x.single_int32 : 2"; + runTest(variables, unknownInt32); + + source = "false ? 1 : x.single_int32"; + runTest(variables, unknownInt32); + + source = "x.single_int64 == 1 ? x.single_int32 : x.single_int32"; + runTest(variables, unknownInt32, unknownInt64); + + source = "{x.single_int32: 2, 3: 4}"; + runTest(variables, unknownInt32); + + source = "{1: x.single_int32, 3: 4}"; + runTest(variables, unknownInt32); + + source = "{1: x.single_int32, x.single_int64: 4}"; + runTest(variables, unknownInt32, unknownInt64); + + source = "[1, x.single_int32, 3, 4]"; + runTest(variables, unknownInt32); + + source = "[1, x.single_int32, x.single_int64, 4]"; + runTest(variables, unknownInt32, unknownInt64); + + source = "TestAllTypes{single_int32: x.single_int32}.single_int32 == 2"; + runTest(variables, unknownInt32); + + source = "TestAllTypes{single_int32: x.single_int32, single_int64: x.single_int64}"; + runTest(variables, unknownInt32, unknownInt64); + + clearAllDeclarations(); + declareVariable("unknown_list", ListType.create(SimpleType.INT)); + source = "unknown_list.map(x, x)"; + runTest(variables, CelAttributePattern.fromQualifiedIdentifier("unknown_list")); + } + + @Test + public void planner_unknownResultSet_errors() { + declareVariable("x", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())); + TestAllTypes message = + TestAllTypes.newBuilder() + .setSingleString("test") + .setSingleTimestamp(Timestamp.newBuilder().setSeconds(15)) + .build(); + ImmutableMap variables = ImmutableMap.of("x", message); + CelAttributePattern unknownInt32 = + CelAttributePattern.fromQualifiedIdentifier("x.single_int32"); + + source = "x.single_int32 == 1 && x.single_timestamp <= timestamp(\"bad timestamp string\")"; + runTest(variables, unknownInt32); + + source = "x.single_timestamp <= timestamp(\"bad timestamp string\") && x.single_int32 == 1"; + runTest(variables, unknownInt32); + + source = + "x.single_timestamp <= timestamp(\"bad timestamp string\") " + + "&& x.single_timestamp > timestamp(\"another bad timestamp string\")"; + runTest(variables, unknownInt32); + + source = "x.single_int32 == 1 || x.single_timestamp <= timestamp(\"bad timestamp string\")"; + runTest(variables, unknownInt32); + + source = "x.single_timestamp <= timestamp(\"bad timestamp string\") || x.single_int32 == 1"; + runTest(variables, unknownInt32); + + source = + "x.single_timestamp <= timestamp(\"bad timestamp string\") " + + "|| x.single_timestamp > timestamp(\"another bad timestamp string\")"; + runTest(variables, unknownInt32); + + source = "x"; + runTest(ImmutableMap.of(), CelAttributePattern.fromQualifiedIdentifier("x")); } } diff --git a/runtime/src/test/java/dev/cel/runtime/planner/BUILD.bazel b/runtime/src/test/java/dev/cel/runtime/planner/BUILD.bazel index fb05b0b31..9116818dc 100644 --- a/runtime/src/test/java/dev/cel/runtime/planner/BUILD.bazel +++ b/runtime/src/test/java/dev/cel/runtime/planner/BUILD.bazel @@ -43,10 +43,12 @@ java_library( "//runtime:descriptor_type_resolver", "//runtime:dispatcher", "//runtime:function_binding", + "//runtime:partial_vars", "//runtime:program", "//runtime:runtime_equality", "//runtime:runtime_helpers", "//runtime:standard_functions", + "//runtime:unknown_attributes", "//runtime/planner:program_planner", "//runtime/standard:type", "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto", diff --git a/runtime/src/test/java/dev/cel/runtime/planner/ProgramPlannerTest.java b/runtime/src/test/java/dev/cel/runtime/planner/ProgramPlannerTest.java index 20b4e641a..de30902d3 100644 --- a/runtime/src/test/java/dev/cel/runtime/planner/ProgramPlannerTest.java +++ b/runtime/src/test/java/dev/cel/runtime/planner/ProgramPlannerTest.java @@ -65,13 +65,17 @@ import dev.cel.expr.conformance.proto3.TestAllTypes.NestedMessage; import dev.cel.extensions.CelExtensions; import dev.cel.parser.CelStandardMacro; +import dev.cel.runtime.CelAttribute; +import dev.cel.runtime.CelAttributePattern; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.CelFunctionBinding; import dev.cel.runtime.CelLateFunctionBindings; import dev.cel.runtime.CelStandardFunctions; import dev.cel.runtime.CelStandardFunctions.StandardFunction; +import dev.cel.runtime.CelUnknownSet; import dev.cel.runtime.DefaultDispatcher; import dev.cel.runtime.DescriptorTypeResolver; +import dev.cel.runtime.PartialVars; import dev.cel.runtime.Program; import dev.cel.runtime.RuntimeEquality; import dev.cel.runtime.RuntimeHelpers; @@ -946,6 +950,35 @@ public void plan_comprehension_iterationLimit_success() throws Exception { ImmutableList.of(2L, 3L), ImmutableList.of(3L, 4L), ImmutableList.of(4L, 5L))); } + @Test + public void plan_partialEval_withWildcardQualification() throws Exception { + CelCompiler compiler = + CelCompilerFactory.standardCelCompilerBuilder() + .addVar("unk", MapType.create(SimpleType.STRING, SimpleType.BOOL)) + .addVar("unk.a", SimpleType.BOOL) + .addVar("unk.b", SimpleType.BOOL) + .build(); + CelAbstractSyntaxTree ast = compile(compiler, "unk.a && unk.b && unk['c']"); + + Program program = PLANNER.plan(ast); + + CelUnknownSet result = + (CelUnknownSet) + program.eval( + PartialVars.of( + CelAttributePattern.create("unk") + .qualify(CelAttribute.Qualifier.ofWildCard()))); + + assertThat(result) + .isEqualTo( + CelUnknownSet.create( + ImmutableSet.of( + CelAttribute.create("unk"), + CelAttribute.create("unk").qualify(CelAttribute.Qualifier.ofString("a")), + CelAttribute.create("unk").qualify(CelAttribute.Qualifier.ofString("b"))), + ImmutableSet.of(2L, 5L, 7L))); + } + @Test public void localShadowIdentifier_inSelect() throws Exception { CelCompiler celCompiler = diff --git a/runtime/src/test/resources/planner_unknownFieldSelection.baseline b/runtime/src/test/resources/planner_unknownFieldSelection.baseline new file mode 100644 index 000000000..0cbc75299 --- /dev/null +++ b/runtime/src/test/resources/planner_unknownFieldSelection.baseline @@ -0,0 +1,111 @@ +Source: x +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {x=, unknown_attributes=[x]} +result: CelUnknownSet{attributes=[x], unknownExprIds=[1]} + +Source: x +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {x=, unknown_attributes=[x.single_int32]} +result: CelUnknownSet{attributes=[x], unknownExprIds=[1]} + +Source: x.single_int32 +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {unknown_attributes=[x]} +result: CelUnknownSet{attributes=[x], unknownExprIds=[2]} + +Source: x.single_int32 +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {unknown_attributes=[x.single_int32]} +result: CelUnknownSet{attributes=[x.single_int32], unknownExprIds=[2]} + +Source: x.map_int32_int64[22] +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {unknown_attributes=[x]} +result: CelUnknownSet{attributes=[x], unknownExprIds=[2]} + +Source: x.map_int32_int64[22] +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {unknown_attributes=[x.map_int32_int64]} +result: CelUnknownSet{attributes=[x.map_int32_int64], unknownExprIds=[2]} + +Source: x.repeated_nested_message[1] +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {unknown_attributes=[x]} +result: CelUnknownSet{attributes=[x], unknownExprIds=[2]} + +Source: x.repeated_nested_message[1] +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {unknown_attributes=[x.repeated_nested_message]} +result: CelUnknownSet{attributes=[x.repeated_nested_message], unknownExprIds=[2]} + +Source: x.single_nested_message.bb +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {unknown_attributes=[x]} +result: CelUnknownSet{attributes=[x], unknownExprIds=[3]} + +Source: x.single_nested_message.bb +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {unknown_attributes=[x.single_nested_message.bb]} +result: CelUnknownSet{attributes=[x.single_nested_message.bb], unknownExprIds=[3]} + +Source: {1: x.single_int32} +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {unknown_attributes=[x]} +result: CelUnknownSet{attributes=[x], unknownExprIds=[5]} + +Source: {1: x.single_int32} +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {unknown_attributes=[x.single_int32]} +result: CelUnknownSet{attributes=[x.single_int32], unknownExprIds=[5]} + +Source: [1, x.single_int32] +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {unknown_attributes=[x]} +result: CelUnknownSet{attributes=[x], unknownExprIds=[4]} + +Source: [1, x.single_int32] +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {unknown_attributes=[x.single_int32]} +result: CelUnknownSet{attributes=[x.single_int32], unknownExprIds=[4]} diff --git a/runtime/src/test/resources/planner_unknownResultSet_errors.baseline b/runtime/src/test/resources/planner_unknownResultSet_errors.baseline new file mode 100644 index 000000000..812067ddf --- /dev/null +++ b/runtime/src/test/resources/planner_unknownResultSet_errors.baseline @@ -0,0 +1,81 @@ +Source: x.single_int32 == 1 && x.single_timestamp <= timestamp("bad timestamp string") +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32]} +result: CelUnknownSet{attributes=[x.single_int32], unknownExprIds=[2]} + +Source: x.single_timestamp <= timestamp("bad timestamp string") && x.single_int32 == 1 +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32]} +result: CelUnknownSet{attributes=[x.single_int32], unknownExprIds=[8]} + +Source: x.single_timestamp <= timestamp("bad timestamp string") && x.single_timestamp > timestamp("another bad timestamp string") +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32]} +error: evaluation error at test_location:89: Text 'another bad timestamp string' could not be parsed at index 0 +error_code: BAD_FORMAT + +Source: x.single_int32 == 1 || x.single_timestamp <= timestamp("bad timestamp string") +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32]} +result: CelUnknownSet{attributes=[x.single_int32], unknownExprIds=[2]} + +Source: x.single_timestamp <= timestamp("bad timestamp string") || x.single_int32 == 1 +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32]} +result: CelUnknownSet{attributes=[x.single_int32], unknownExprIds=[8]} + +Source: x.single_timestamp <= timestamp("bad timestamp string") || x.single_timestamp > timestamp("another bad timestamp string") +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32]} +error: evaluation error at test_location:89: Text 'another bad timestamp string' could not be parsed at index 0 +error_code: BAD_FORMAT + +Source: x +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {unknown_attributes=[x]} +result: CelUnknownSet{attributes=[x], unknownExprIds=[1]} \ No newline at end of file diff --git a/runtime/src/test/resources/planner_unknownResultSet_success.baseline b/runtime/src/test/resources/planner_unknownResultSet_success.baseline new file mode 100644 index 000000000..2f2c218d0 --- /dev/null +++ b/runtime/src/test/resources/planner_unknownResultSet_success.baseline @@ -0,0 +1,461 @@ +Source: x.single_int32 == 1 && true +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32]} +result: CelUnknownSet{attributes=[x.single_int32], unknownExprIds=[2]} + +Source: x.single_int32 == 1 && false +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32]} +result: false + +Source: x.single_int32 == 1 && x.single_int64 == 1 +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32, x.single_int64]} +result: CelUnknownSet{attributes=[x.single_int32, x.single_int64], unknownExprIds=[2, 7]} + +Source: true && x.single_int32 == 1 +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32]} +result: CelUnknownSet{attributes=[x.single_int32], unknownExprIds=[4]} + +Source: false && x.single_int32 == 1 +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32]} +result: false + +Source: x.single_int32 == 1 || x.single_string == "test" +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32]} +result: true + +Source: x.single_int32 == 1 || x.single_string != "test" +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32]} +result: CelUnknownSet{attributes=[x.single_int32], unknownExprIds=[2]} + +Source: x.single_int32 == 1 || x.single_int64 == 1 +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32, x.single_int64]} +result: CelUnknownSet{attributes=[x.single_int32, x.single_int64], unknownExprIds=[2, 7]} + +Source: true || x.single_int32 == 1 +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32]} +result: true + +Source: false || x.single_int32 == 1 +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32]} +result: CelUnknownSet{attributes=[x.single_int32], unknownExprIds=[4]} + +Source: x.single_int32.f(1) +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +declare f { + function f int.(int) -> bool +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32]} +result: CelUnknownSet{attributes=[x.single_int32], unknownExprIds=[2]} + +Source: 1.f(x.single_int32) +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +declare f { + function f int.(int) -> bool +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32]} +result: CelUnknownSet{attributes=[x.single_int32], unknownExprIds=[4]} + +Source: x.single_int64.f(x.single_int32) +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +declare f { + function f int.(int) -> bool +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32, x.single_int64]} +result: CelUnknownSet{attributes=[x.single_int32, x.single_int64], unknownExprIds=[2, 5]} + +Source: [0, 2, 4].exists(z, z == 2 || z == x.single_int32) +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +declare f { + function f int.(int) -> bool +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32]} +result: true + +Source: [0, 2, 4].exists(z, z == x.single_int32) +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +declare f { + function f int.(int) -> bool +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32]} +result: CelUnknownSet{attributes=[x.single_int32], unknownExprIds=[10]} + +Source: [0, 2, 4].exists_one(z, z == 0 || (z == 2 && z == x.single_int32) || (z == 4 && z == x.single_int64)) +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +declare f { + function f int.(int) -> bool +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32, x.single_int64]} +result: CelUnknownSet{attributes=[x.single_int64], unknownExprIds=[27]} + +Source: [0, 2].all(z, z == 2 || z == x.single_int32) +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +declare f { + function f int.(int) -> bool +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32]} +result: CelUnknownSet{attributes=[x.single_int32], unknownExprIds=[13]} + +Source: [0, 2, 4].filter(z, z == 0 || (z == 2 && z == x.single_int32) || (z == 4 && z == x.single_int64)) +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +declare f { + function f int.(int) -> bool +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32, x.single_int64]} +result: CelUnknownSet{attributes=[x.single_int64], unknownExprIds=[27]} + +Source: [0, 2, 4].map(z, z == 0 || (z == 2 && z == x.single_int32) || (z == 4 && z == x.single_int64)) +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +declare f { + function f int.(int) -> bool +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32, x.single_int64]} +result: CelUnknownSet{attributes=[x.single_int32, x.single_int64], unknownExprIds=[18, 27]} + +Source: x.single_int32 == 1 ? 1 : 2 +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +declare f { + function f int.(int) -> bool +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32]} +result: CelUnknownSet{attributes=[x.single_int32], unknownExprIds=[2]} + +Source: true ? x.single_int32 : 2 +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +declare f { + function f int.(int) -> bool +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32]} +result: CelUnknownSet{attributes=[x.single_int32], unknownExprIds=[4]} + +Source: true ? 1 : x.single_int32 +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +declare f { + function f int.(int) -> bool +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32]} +result: 1 + +Source: false ? x.single_int32 : 2 +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +declare f { + function f int.(int) -> bool +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32]} +result: 2 + +Source: false ? 1 : x.single_int32 +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +declare f { + function f int.(int) -> bool +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32]} +result: CelUnknownSet{attributes=[x.single_int32], unknownExprIds=[5]} + +Source: x.single_int64 == 1 ? x.single_int32 : x.single_int32 +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +declare f { + function f int.(int) -> bool +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32, x.single_int64]} +result: CelUnknownSet{attributes=[x.single_int64], unknownExprIds=[2]} + +Source: {x.single_int32: 2, 3: 4} +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +declare f { + function f int.(int) -> bool +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32]} +result: CelUnknownSet{attributes=[x.single_int32], unknownExprIds=[4]} + +Source: {1: x.single_int32, 3: 4} +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +declare f { + function f int.(int) -> bool +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32]} +result: CelUnknownSet{attributes=[x.single_int32], unknownExprIds=[5]} + +Source: {1: x.single_int32, x.single_int64: 4} +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +declare f { + function f int.(int) -> bool +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32, x.single_int64]} +result: CelUnknownSet{attributes=[x.single_int32, x.single_int64], unknownExprIds=[5, 8]} + +Source: [1, x.single_int32, 3, 4] +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +declare f { + function f int.(int) -> bool +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32]} +result: CelUnknownSet{attributes=[x.single_int32], unknownExprIds=[4]} + +Source: [1, x.single_int32, x.single_int64, 4] +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +declare f { + function f int.(int) -> bool +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32, x.single_int64]} +result: CelUnknownSet{attributes=[x.single_int32, x.single_int64], unknownExprIds=[4, 6]} + +Source: TestAllTypes{single_int32: x.single_int32}.single_int32 == 2 +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +declare f { + function f int.(int) -> bool +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32]} +result: CelUnknownSet{attributes=[x.single_int32], unknownExprIds=[4]} + +Source: TestAllTypes{single_int32: x.single_int32, single_int64: x.single_int64} +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +declare f { + function f int.(int) -> bool +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x.single_int32, x.single_int64]} +result: CelUnknownSet{attributes=[x.single_int32, x.single_int64], unknownExprIds=[4, 7]} + +Source: unknown_list.map(x, x) +declare unknown_list { + value list(int) +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[unknown_list]} +result: CelUnknownSet{attributes=[unknown_list], unknownExprIds=[1]} \ No newline at end of file diff --git a/runtime/src/test/resources/unknownField.baseline b/runtime/src/test/resources/unknownField.baseline index c5f3c755a..8e4598bef 100644 --- a/runtime/src/test/resources/unknownField.baseline +++ b/runtime/src/test/resources/unknownField.baseline @@ -52,4 +52,4 @@ declare x { } =====> bindings: {} -result: CelUnknownSet{attributes=[], unknownExprIds=[3]} +result: CelUnknownSet{attributes=[], unknownExprIds=[3]} \ No newline at end of file diff --git a/testing/src/main/java/dev/cel/testing/BUILD.bazel b/testing/src/main/java/dev/cel/testing/BUILD.bazel index f2480a034..2ecabdf05 100644 --- a/testing/src/main/java/dev/cel/testing/BUILD.bazel +++ b/testing/src/main/java/dev/cel/testing/BUILD.bazel @@ -90,7 +90,7 @@ java_library( "//extensions:optional_library", "//runtime", "//runtime:function_binding", - "//runtime:late_function_binding", + "//runtime:partial_vars", "//runtime:unknown_attributes", "@cel_spec//proto/cel/expr:checked_java_proto", "@cel_spec//proto/cel/expr:syntax_java_proto", diff --git a/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java b/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java index f3c1cf398..69db9c9db 100644 --- a/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java +++ b/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java @@ -75,6 +75,7 @@ import dev.cel.expr.conformance.proto3.TestAllTypes.NestedEnum; import dev.cel.expr.conformance.proto3.TestAllTypes.NestedMessage; import dev.cel.extensions.CelOptionalLibrary; +import dev.cel.runtime.CelAttributePattern; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.CelFunctionBinding; import dev.cel.runtime.CelLateFunctionBindings; @@ -83,6 +84,7 @@ import dev.cel.runtime.CelRuntimeFactory; import dev.cel.runtime.CelUnknownSet; import dev.cel.runtime.CelVariableResolver; +import dev.cel.runtime.PartialVars; import dev.cel.testing.testdata.proto3.StandaloneGlobalEnum; import java.io.IOException; import java.time.Duration; @@ -193,24 +195,43 @@ private Object runTest(Map input, CelLateFunctionBindings lateFunctio return runTestInternal(input, Optional.of(lateFunctionBindings)); } + @CanIgnoreReturnValue + protected Object runTest(Map input, CelAttributePattern... patterns) { + return runTestInternal(input, Optional.empty(), patterns); + } + /** * Helper to run a test for configured instance variables. Input must be of type map or {@link * CelVariableResolver}. */ - @SuppressWarnings("unchecked") private Object runTestInternal( Object input, Optional lateFunctionBindings) { + return runTestInternal(input, lateFunctionBindings, new CelAttributePattern[0]); + } + + // Test only + @SuppressWarnings("unchecked") + private Object runTestInternal( + Object input, + Optional lateFunctionBindings, + CelAttributePattern... patterns) { CelAbstractSyntaxTree ast = compileTestCase(); if (ast == null) { // Usually indicates test was not setup correctly println("Source compilation failed"); return null; } - printBinding(input); + printBinding(input, patterns); Object result = null; try { CelRuntime.Program program = celRuntime.createProgram(ast); - if (lateFunctionBindings.isPresent()) { + if (patterns.length > 0) { + PartialVars partialVars = + input instanceof Map + ? PartialVars.of((Map) input, patterns) + : PartialVars.of((CelVariableResolver) input, patterns); + result = program.eval(partialVars); + } else if (lateFunctionBindings.isPresent()) { if (input instanceof Map) { Map map = ((Map) input); CelVariableResolver variableResolver = (name) -> Optional.ofNullable(map.get(name)); @@ -2532,17 +2553,17 @@ private static String readResourceContent(String path) throws IOException { } @SuppressWarnings("unchecked") - private void printBinding(Object input) { + private void printBinding(Object input, CelAttributePattern... patterns) { if (input instanceof Map) { Map inputMap = (Map) input; - if (inputMap.isEmpty()) { + if (inputMap.isEmpty() && patterns.length == 0) { println("bindings: {}"); return; } boolean first = true; StringBuilder sb = new StringBuilder().append("{"); - for (Map.Entry entry : ((Map) input).entrySet()) { + for (Map.Entry entry : inputMap.entrySet()) { if (!first) { sb.append(", "); } @@ -2556,10 +2577,21 @@ private void printBinding(Object input) { sb.append(UnredactedDebugFormatForTest.unredactedToString(entry.getValue())); } } + if (patterns.length > 0) { + if (!inputMap.isEmpty()) { + sb.append(", "); + } + sb.append("unknown_attributes="); + sb.append(Arrays.toString(patterns)); + } sb.append("}"); println("bindings: " + sb); } else { - println("bindings: " + input); + if (patterns.length > 0) { + println("bindings: " + input + ", unknown_attributes=" + Arrays.toString(patterns)); + } else { + println("bindings: " + input); + } } } From 486a380903332f9882ad755ffe806ad58cab6e05 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Tue, 17 Mar 2026 14:56:01 -0700 Subject: [PATCH 06/66] Allow specifying a set of optimizers to run to the policy compiler PiperOrigin-RevId: 885225062 --- .../src/main/java/dev/cel/policy/BUILD.bazel | 2 ++ .../cel/policy/CelPolicyCompilerBuilder.java | 6 ++++ .../dev/cel/policy/CelPolicyCompilerImpl.java | 36 ++++++++++--------- 3 files changed, 28 insertions(+), 16 deletions(-) diff --git a/policy/src/main/java/dev/cel/policy/BUILD.bazel b/policy/src/main/java/dev/cel/policy/BUILD.bazel index 52b0b1ba7..916f16f9b 100644 --- a/policy/src/main/java/dev/cel/policy/BUILD.bazel +++ b/policy/src/main/java/dev/cel/policy/BUILD.bazel @@ -138,6 +138,7 @@ java_library( ], deps = [ ":compiler", + "//optimizer:ast_optimizer", "@maven//:com_google_errorprone_error_prone_annotations", ], ) @@ -214,6 +215,7 @@ java_library( "//common/types", "//common/types:type_providers", "//optimizer", + "//optimizer:ast_optimizer", "//optimizer:optimization_exception", "//optimizer:optimizer_builder", "//optimizer/optimizers:common_subexpression_elimination", diff --git a/policy/src/main/java/dev/cel/policy/CelPolicyCompilerBuilder.java b/policy/src/main/java/dev/cel/policy/CelPolicyCompilerBuilder.java index 592a0120d..4089477a1 100644 --- a/policy/src/main/java/dev/cel/policy/CelPolicyCompilerBuilder.java +++ b/policy/src/main/java/dev/cel/policy/CelPolicyCompilerBuilder.java @@ -16,6 +16,8 @@ import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.CheckReturnValue; +import dev.cel.optimizer.CelAstOptimizer; +import java.util.List; /** Interface for building an instance of {@link CelPolicyCompiler} */ public interface CelPolicyCompilerBuilder { @@ -38,6 +40,10 @@ public interface CelPolicyCompilerBuilder { @CanIgnoreReturnValue CelPolicyCompilerBuilder setAstDepthLimit(int iterationLimit); + /** Configures the policy compiler to run the provided optimizers on compiled policies. */ + @CanIgnoreReturnValue + CelPolicyCompilerBuilder setOptimizers(List optimizers); + @CheckReturnValue CelPolicyCompiler build(); } diff --git a/policy/src/main/java/dev/cel/policy/CelPolicyCompilerImpl.java b/policy/src/main/java/dev/cel/policy/CelPolicyCompilerImpl.java index f6f893c1c..7841b9827 100644 --- a/policy/src/main/java/dev/cel/policy/CelPolicyCompilerImpl.java +++ b/policy/src/main/java/dev/cel/policy/CelPolicyCompilerImpl.java @@ -33,6 +33,7 @@ import dev.cel.common.formats.ValueString; import dev.cel.common.types.CelType; import dev.cel.common.types.SimpleType; +import dev.cel.optimizer.CelAstOptimizer; import dev.cel.optimizer.CelOptimizationException; import dev.cel.optimizer.CelOptimizer; import dev.cel.optimizer.CelOptimizerFactory; @@ -63,6 +64,7 @@ final class CelPolicyCompilerImpl implements CelPolicyCompiler { private final Cel cel; private final String variablesPrefix; private final int iterationLimit; + private final ImmutableList optimizers; private final Optional astDepthValidator; @Override @@ -140,19 +142,7 @@ public CelAbstractSyntaxTree compose(CelPolicy policy, CelCompiledRule compiledR } CelOptimizer astOptimizer = - CelOptimizerFactory.standardCelOptimizerBuilder(cel) - .addAstOptimizers( - ConstantFoldingOptimizer.getInstance(), - SubexpressionOptimizer.newInstance( - SubexpressionOptimizerOptions.newBuilder() - // "record" is used for recording subexpression results via - // BlueprintLateFunctionBinding. Safely eliminable, since repeated - // invocation does not change the intermediate results. - .addEliminableFunctions("record") - .populateMacroCalls(true) - .enableCelBlock(true) - .build())) - .build(); + CelOptimizerFactory.standardCelOptimizerBuilder(cel).addAstOptimizers(optimizers).build(); try { // Optimize the composed graph using const fold and CSE ast = astOptimizer.optimize(ast); @@ -339,6 +329,7 @@ static final class Builder implements CelPolicyCompilerBuilder { private final Cel cel; private String variablesPrefix; private int iterationLimit; + private ImmutableList optimizers; private Optional astDepthLimitValidator; private Builder(Cel cel) { @@ -362,7 +353,7 @@ public Builder setIterationLimit(int iterationLimit) { @Override @CanIgnoreReturnValue - public CelPolicyCompilerBuilder setAstDepthLimit(int astDepthLimit) { + public Builder setAstDepthLimit(int astDepthLimit) { if (astDepthLimit < 0) { astDepthLimitValidator = Optional.empty(); } else { @@ -371,27 +362,40 @@ public CelPolicyCompilerBuilder setAstDepthLimit(int astDepthLimit) { return this; } + @Override + public Builder setOptimizers(List optimizers) { + this.optimizers = ImmutableList.copyOf(optimizers); + return this; + } + @Override public CelPolicyCompiler build() { return new CelPolicyCompilerImpl( - cel, this.variablesPrefix, this.iterationLimit, astDepthLimitValidator); + cel, this.variablesPrefix, this.iterationLimit, this.optimizers, astDepthLimitValidator); } } static Builder newBuilder(Cel cel) { return new Builder(cel) .setVariablesPrefix(DEFAULT_VARIABLE_PREFIX) - .setIterationLimit(DEFAULT_ITERATION_LIMIT); + .setIterationLimit(DEFAULT_ITERATION_LIMIT) + .setOptimizers( + ImmutableList.of( + ConstantFoldingOptimizer.getInstance(), + SubexpressionOptimizer.newInstance( + SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build()))); } private CelPolicyCompilerImpl( Cel cel, String variablesPrefix, int iterationLimit, + ImmutableList optimizers, Optional astDepthValidator) { this.cel = checkNotNull(cel); this.variablesPrefix = checkNotNull(variablesPrefix); this.iterationLimit = iterationLimit; + this.optimizers = optimizers; this.astDepthValidator = astDepthValidator; } } From b0a3f588f5a78988a61bda0ee9a449e21848d56e Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Thu, 19 Mar 2026 10:41:18 -0700 Subject: [PATCH 07/66] Set enableTimestampEpoch by default PiperOrigin-RevId: 886258398 --- common/src/main/java/dev/cel/common/CelOptions.java | 1 + 1 file changed, 1 insertion(+) diff --git a/common/src/main/java/dev/cel/common/CelOptions.java b/common/src/main/java/dev/cel/common/CelOptions.java index 9cf9a9caa..e3bb8776e 100644 --- a/common/src/main/java/dev/cel/common/CelOptions.java +++ b/common/src/main/java/dev/cel/common/CelOptions.java @@ -177,6 +177,7 @@ public static Builder current() { .enableUnsignedComparisonAndArithmeticIsUnsigned(true) .enableUnsignedLongs(true) .enableRegexPartialMatch(true) + .enableTimestampEpoch(true) .errorOnDuplicateMapKeys(true) .evaluateCanonicalTypesToNativeValues(true) .errorOnIntWrap(true) From 3572797c1bd1f19ad1ab4ebc7fd34df8ef0e5aea Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Thu, 19 Mar 2026 15:35:29 -0700 Subject: [PATCH 08/66] Various fixes to CelEnvironment, YAML parsing and serialization Includes: - Container setting is made optional to prevent the CEL environment from overriding an already set container with an empty one in case the env is extended multiple times - Brings CEL environment YAML serialization in parity (examples, type, description) - Fixes "comprehensions" to be "two-var-comprehensions" - Partitioned newVaueString into newYamlString and newSourceString, where the former respects YAML multiline syntax PiperOrigin-RevId: 886408709 --- .../java/dev/cel/bundle/CelEnvironment.java | 57 ++++++++++++++----- .../cel/bundle/CelEnvironmentYamlParser.java | 28 +++++++++ .../bundle/CelEnvironmentYamlSerializer.java | 6 +- .../bundle/CelEnvironmentExporterTest.java | 3 +- .../dev/cel/bundle/CelEnvironmentTest.java | 45 ++++++++++++--- .../bundle/CelEnvironmentYamlParserTest.java | 56 ++++++++++-------- .../dev/cel/common/formats/ParserContext.java | 30 +++++++++- .../dev/cel/common/formats/YamlHelper.java | 2 +- .../common/formats/YamlParserContextImpl.java | 13 ++++- .../dev/cel/policy/CelPolicyYamlParser.java | 39 +++++++------ .../cel/policy/CelPolicyYamlParserTest.java | 16 +++++- .../java/dev/cel/policy/PolicyTestHelper.java | 8 +-- .../resources/environment/extended_env.yaml | 57 +++++++++++-------- 13 files changed, 259 insertions(+), 101 deletions(-) diff --git a/bundle/src/main/java/dev/cel/bundle/CelEnvironment.java b/bundle/src/main/java/dev/cel/bundle/CelEnvironment.java index 8614b87b5..b85f16cb1 100644 --- a/bundle/src/main/java/dev/cel/bundle/CelEnvironment.java +++ b/bundle/src/main/java/dev/cel/bundle/CelEnvironment.java @@ -43,6 +43,7 @@ import dev.cel.common.types.OptionalType; import dev.cel.common.types.SimpleType; import dev.cel.common.types.TypeParamType; +import dev.cel.common.types.TypeType; import dev.cel.compiler.CelCompiler; import dev.cel.compiler.CelCompilerBuilder; import dev.cel.compiler.CelCompilerLibrary; @@ -71,27 +72,28 @@ public abstract class CelEnvironment { "math", CanonicalCelExtension.MATH, "optional", CanonicalCelExtension.OPTIONAL, "protos", CanonicalCelExtension.PROTOS, + "regex", CanonicalCelExtension.REGEX, "sets", CanonicalCelExtension.SETS, "strings", CanonicalCelExtension.STRINGS, - "comprehensions", CanonicalCelExtension.COMPREHENSIONS); + "two-var-comprehensions", CanonicalCelExtension.COMPREHENSIONS); private static final ImmutableMap> LIMIT_HANDLERS = ImmutableMap.of( "cel.limit.expression_code_points", - (options, value) -> options.maxExpressionCodePointSize(value), + CelOptions.Builder::maxExpressionCodePointSize, "cel.limit.parse_error_recovery", - (options, value) -> options.maxParseErrorRecoveryLimit(value), + CelOptions.Builder::maxParseErrorRecoveryLimit, "cel.limit.parse_recursion_depth", - (options, value) -> options.maxParseRecursionDepth(value)); + CelOptions.Builder::maxParseRecursionDepth); private static final ImmutableMap FEATURE_HANDLERS = ImmutableMap.of( "cel.feature.macro_call_tracking", - (options, enabled) -> options.populateMacroCalls(enabled), + CelOptions.Builder::populateMacroCalls, "cel.feature.backtick_escape_syntax", - (options, enabled) -> options.enableQuotedIdentifierSyntax(enabled), + CelOptions.Builder::enableQuotedIdentifierSyntax, "cel.feature.cross_type_numeric_comparisons", - (options, enabled) -> options.enableHeterogeneousNumericComparisons(enabled)); + CelOptions.Builder::enableHeterogeneousNumericComparisons); /** Environment source in textual format (ex: textproto, YAML). */ public abstract Optional source(); @@ -99,10 +101,8 @@ public abstract class CelEnvironment { /** Name of the environment. */ public abstract String name(); - /** - * Container, which captures default namespace and aliases for value resolution. - */ - public abstract CelContainer container(); + /** Container, which captures default namespace and aliases for value resolution. */ + public abstract Optional container(); /** * An optional description of the environment (example: location of the file containing the config @@ -226,7 +226,6 @@ public static Builder newBuilder() { return new AutoValue_CelEnvironment.Builder() .setName("") .setDescription("") - .setContainer(CelContainer.ofName("")) .setVariables(ImmutableSet.of()) .setFunctions(ImmutableSet.of()) .setFeatures(ImmutableSet.of()) @@ -242,7 +241,6 @@ public CelCompiler extend(CelCompiler celCompiler, CelOptions celOptions) CelCompilerBuilder compilerBuilder = celCompiler .toCompilerBuilder() - .setContainer(container()) .setOptions(celOptions) .setTypeProvider(celTypeProvider) .addVarDeclarations( @@ -254,6 +252,8 @@ public CelCompiler extend(CelCompiler celCompiler, CelOptions celOptions) .map(f -> f.toCelFunctionDecl(celTypeProvider)) .collect(toImmutableList())); + container().ifPresent(compilerBuilder::setContainer); + addAllCompilerExtensions(compilerBuilder, celOptions); applyStandardLibrarySubset(compilerBuilder); @@ -416,6 +416,8 @@ public abstract static class VariableDecl { /** The type of the variable. */ public abstract TypeDecl type(); + public abstract Optional description(); + /** Builder for {@link VariableDecl}. */ @AutoValue.Builder public abstract static class Builder implements RequiredFieldsChecker { @@ -428,6 +430,8 @@ public abstract static class Builder implements RequiredFieldsChecker { public abstract VariableDecl.Builder setType(TypeDecl typeDecl); + public abstract VariableDecl.Builder setDescription(String name); + @Override public ImmutableList requiredFields() { return ImmutableList.of( @@ -459,6 +463,8 @@ public abstract static class FunctionDecl { public abstract String name(); + public abstract Optional description(); + public abstract ImmutableSet overloads(); /** Builder for {@link FunctionDecl}. */ @@ -471,6 +477,8 @@ public abstract static class Builder implements RequiredFieldsChecker { public abstract FunctionDecl.Builder setName(String name); + public abstract FunctionDecl.Builder setDescription(String description); + public abstract FunctionDecl.Builder setOverloads(ImmutableSet overloads); @Override @@ -519,6 +527,9 @@ public abstract static class OverloadDecl { /** List of function overload type values. */ public abstract ImmutableList arguments(); + /** Examples for the overload. */ + public abstract ImmutableList examples(); + /** Return type of the overload. Required. */ public abstract TypeDecl returnType(); @@ -537,8 +548,21 @@ public abstract static class Builder implements RequiredFieldsChecker { // This should stay package-private to encourage add/set methods to be used instead. abstract ImmutableList.Builder argumentsBuilder(); + abstract ImmutableList.Builder examplesBuilder(); + public abstract OverloadDecl.Builder setArguments(ImmutableList args); + @CanIgnoreReturnValue + public OverloadDecl.Builder addExamples(Iterable examples) { + this.examplesBuilder().addAll(checkNotNull(examples)); + return this; + } + + @CanIgnoreReturnValue + public OverloadDecl.Builder addExamples(String... examples) { + return addExamples(Arrays.asList(examples)); + } + @CanIgnoreReturnValue public OverloadDecl.Builder addArguments(Iterable args) { this.argumentsBuilder().addAll(checkNotNull(args)); @@ -667,6 +691,10 @@ public CelType toCelType(CelTypeProvider celTypeProvider) { CelType keyType = params().get(0).toCelType(celTypeProvider); CelType valueType = params().get(1).toCelType(celTypeProvider); return MapType.create(keyType, valueType); + case "type": + checkState( + params().size() == 1, "Expected 1 parameter for type, got %s", params().size()); + return TypeType.create(params().get(0).toCelType(celTypeProvider)); default: if (isTypeParam()) { return TypeParamType.create(name()); @@ -838,6 +866,7 @@ enum CanonicalCelExtension { SETS( (options, version) -> CelExtensions.sets(options), (options, version) -> CelExtensions.sets(options)), + REGEX((options, version) -> CelExtensions.regex(), (options, version) -> CelExtensions.regex()), LISTS((options, version) -> CelExtensions.lists(), (options, version) -> CelExtensions.lists()), COMPREHENSIONS( (options, version) -> CelExtensions.comprehensions(), @@ -1054,7 +1083,7 @@ public static OverloadSelector.Builder newBuilder() { } @FunctionalInterface - private static interface BooleanOptionConsumer { + private interface BooleanOptionConsumer { void accept(CelOptions.Builder options, boolean value); } } diff --git a/bundle/src/main/java/dev/cel/bundle/CelEnvironmentYamlParser.java b/bundle/src/main/java/dev/cel/bundle/CelEnvironmentYamlParser.java index ce8857654..f129d9f5d 100644 --- a/bundle/src/main/java/dev/cel/bundle/CelEnvironmentYamlParser.java +++ b/bundle/src/main/java/dev/cel/bundle/CelEnvironmentYamlParser.java @@ -353,6 +353,9 @@ private VariableDecl parseVariable(ParserContext ctx, Node node) { case "name": builder.setName(newString(ctx, valueNode)); break; + case "description": + builder.setDescription(newString(ctx, valueNode)); + break; case "type": if (typeDeclBuilder != null) { ctx.reportError( @@ -428,6 +431,9 @@ private FunctionDecl parseFunction(ParserContext ctx, Node node) { case "overloads": builder.setOverloads(parseOverloads(ctx, valueNode)); break; + case "description": + builder.setDescription(newString(ctx, valueNode).trim()); + break; default: ctx.reportError(keyId, String.format("Unsupported function tag: %s", keyName)); break; @@ -479,6 +485,9 @@ private static ImmutableSet parseOverloads(ParserContext ctx case "target": overloadDeclBuilder.setTarget(parseTypeDecl(ctx, valueNode)); break; + case "examples": + overloadDeclBuilder.addExamples(parseOverloadExamples(ctx, valueNode)); + break; default: ctx.reportError(keyId, String.format("Unsupported overload tag: %s", fieldName)); break; @@ -494,6 +503,25 @@ private static ImmutableSet parseOverloads(ParserContext ctx return overloadSetBuilder.build(); } + private static ImmutableList parseOverloadExamples(ParserContext ctx, Node node) { + long listValueId = ctx.collectMetadata(node); + if (!assertYamlType(ctx, listValueId, node, YamlNodeType.LIST)) { + return ImmutableList.of(); + } + SequenceNode paramsListNode = (SequenceNode) node; + ImmutableList.Builder builder = ImmutableList.builder(); + for (Node elementNode : paramsListNode.getValue()) { + long elementNodeId = ctx.collectMetadata(elementNode); + if (!assertYamlType(ctx, elementNodeId, elementNode, YamlNodeType.STRING)) { + continue; + } + + builder.add(((ScalarNode) elementNode).getValue()); + } + + return builder.build(); + } + private static ImmutableList parseOverloadArguments( ParserContext ctx, Node node) { long listValueId = ctx.collectMetadata(node); diff --git a/bundle/src/main/java/dev/cel/bundle/CelEnvironmentYamlSerializer.java b/bundle/src/main/java/dev/cel/bundle/CelEnvironmentYamlSerializer.java index 179faf2ac..9d5b4b69e 100644 --- a/bundle/src/main/java/dev/cel/bundle/CelEnvironmentYamlSerializer.java +++ b/bundle/src/main/java/dev/cel/bundle/CelEnvironmentYamlSerializer.java @@ -79,10 +79,8 @@ public Node representData(Object data) { if (!environment.description().isEmpty()) { configMap.put("description", environment.description()); } - if (!environment.container().name().isEmpty() - || !environment.container().abbreviations().isEmpty() - || !environment.container().aliases().isEmpty()) { - configMap.put("container", environment.container()); + if (environment.container().isPresent()) { + configMap.put("container", environment.container().get()); } if (!environment.extensions().isEmpty()) { configMap.put("extensions", environment.extensions().asList()); diff --git a/bundle/src/test/java/dev/cel/bundle/CelEnvironmentExporterTest.java b/bundle/src/test/java/dev/cel/bundle/CelEnvironmentExporterTest.java index 10b9dee8e..ae0de2c18 100644 --- a/bundle/src/test/java/dev/cel/bundle/CelEnvironmentExporterTest.java +++ b/bundle/src/test/java/dev/cel/bundle/CelEnvironmentExporterTest.java @@ -333,7 +333,7 @@ public void container() { CelEnvironmentExporter exporter = CelEnvironmentExporter.newBuilder().build(); CelEnvironment celEnvironment = exporter.export(cel); - CelContainer container = celEnvironment.container(); + CelContainer container = celEnvironment.container().get(); assertThat(container.name()).isEqualTo("cntnr"); assertThat(container.abbreviations()).containsExactly("foo.Bar", "baz.Qux").inOrder(); assertThat(container.aliases()).containsAtLeast("nm", "user.name", "id", "user.id").inOrder(); @@ -368,4 +368,3 @@ public void options() { CelEnvironment.Limit.create("cel.limit.parse_recursion_depth", 10)); } } - diff --git a/bundle/src/test/java/dev/cel/bundle/CelEnvironmentTest.java b/bundle/src/test/java/dev/cel/bundle/CelEnvironmentTest.java index f7eb254d7..a5a2f3e6d 100644 --- a/bundle/src/test/java/dev/cel/bundle/CelEnvironmentTest.java +++ b/bundle/src/test/java/dev/cel/bundle/CelEnvironmentTest.java @@ -28,6 +28,10 @@ import dev.cel.common.CelOptions; import dev.cel.common.CelValidationException; import dev.cel.common.CelValidationResult; +import dev.cel.common.types.CelType; +import dev.cel.common.types.CelTypeProvider; +import dev.cel.common.types.SimpleType; +import dev.cel.common.types.TypeType; import dev.cel.compiler.CelCompiler; import dev.cel.compiler.CelCompilerFactory; import dev.cel.parser.CelStandardMacro; @@ -44,9 +48,7 @@ public void newBuilder_defaults() { assertThat(environment.source()).isEmpty(); assertThat(environment.name()).isEmpty(); assertThat(environment.description()).isEmpty(); - assertThat(environment.container().name()).isEmpty(); - assertThat(environment.container().abbreviations()).isEmpty(); - assertThat(environment.container().aliases()).isEmpty(); + assertThat(environment.container()).isEmpty(); assertThat(environment.extensions()).isEmpty(); assertThat(environment.variables()).isEmpty(); assertThat(environment.functions()).isEmpty(); @@ -65,10 +67,10 @@ public void container() { .build()) .build(); - assertThat(environment.container().name()).isEqualTo("cntr"); - assertThat(environment.container().abbreviations()).containsExactly("foo.Bar", "baz.Qux"); - assertThat(environment.container().aliases()) - .containsExactly("nm", "user.name", "id", "user.id"); + CelContainer container = environment.container().get(); + assertThat(container.name()).isEqualTo("cntr"); + assertThat(container.abbreviations()).containsExactly("foo.Bar", "baz.Qux"); + assertThat(container.aliases()).containsExactly("nm", "user.name", "id", "user.id"); } @Test @@ -81,9 +83,10 @@ public void extend_allExtensions() throws Exception { ExtensionConfig.latest("math"), ExtensionConfig.latest("optional"), ExtensionConfig.latest("protos"), + ExtensionConfig.latest("regex"), ExtensionConfig.latest("sets"), ExtensionConfig.latest("strings"), - ExtensionConfig.latest("comprehensions")); + ExtensionConfig.latest("two-var-comprehensions")); CelEnvironment environment = CelEnvironment.newBuilder().addExtensions(extensionConfigs).build(); @@ -435,4 +438,30 @@ public void stdlibSubset_functionOverloadExcluded() throws Exception { result = extendedCompiler.compile("1 == 1 && 1 != 1 + 1"); assertThat(result.getErrorString()).contains("found no matching overload for '_+_'"); } + + @Test + public void typeDecl_toCelType_type() { + CelTypeProvider typeProvider = + CelCompilerFactory.standardCelCompilerBuilder().build().getTypeProvider(); + CelEnvironment.TypeDecl typeDecl = + CelEnvironment.TypeDecl.newBuilder() + .setName("type") + .addParams(CelEnvironment.TypeDecl.create("int")) + .build(); + + CelType celType = typeDecl.toCelType(typeProvider); + + assertThat(celType).isEqualTo(TypeType.create(SimpleType.INT)); + } + + @Test + public void typeDecl_toCelType_type_wrongParamCount_throws() { + CelTypeProvider typeProvider = + CelCompilerFactory.standardCelCompilerBuilder().build().getTypeProvider(); + CelEnvironment.TypeDecl typeDecl = CelEnvironment.TypeDecl.newBuilder().setName("type").build(); + + IllegalStateException e = + assertThrows(IllegalStateException.class, () -> typeDecl.toCelType(typeProvider)); + assertThat(e).hasMessageThat().contains("Expected 1 parameter for type, got 0"); + } } diff --git a/bundle/src/test/java/dev/cel/bundle/CelEnvironmentYamlParserTest.java b/bundle/src/test/java/dev/cel/bundle/CelEnvironmentYamlParserTest.java index e98f6110e..043664e8e 100644 --- a/bundle/src/test/java/dev/cel/bundle/CelEnvironmentYamlParserTest.java +++ b/bundle/src/test/java/dev/cel/bundle/CelEnvironmentYamlParserTest.java @@ -675,9 +675,7 @@ private enum EnvironmentParseErrorTestcase { + " | - version: 0\n" + " | ..^"), ILLEGAL_LIBRARY_SUBSET_TAG( - "name: 'test_suite_name'\n" - + "stdlib:\n" - + " unknown_tag: 'test_value'\n", + "name: 'test_suite_name'\n" + "stdlib:\n" + " unknown_tag: 'test_value'\n", "ERROR: :3:3: Unsupported library subset tag: unknown_tag\n" + " | unknown_tag: 'test_value'\n" + " | ..^"), @@ -859,30 +857,40 @@ private enum EnvironmentYamlResourceTestCase { .setVariables( VariableDecl.newBuilder() .setName("msg") + .setDescription( + "msg represents all possible type permutation which CEL understands from a" + + " proto perspective") .setType(TypeDecl.create("cel.expr.conformance.proto3.TestAllTypes")) .build()) .setFunctions( - FunctionDecl.create( - "isEmpty", - ImmutableSet.of( - OverloadDecl.newBuilder() - .setId("wrapper_string_isEmpty") - .setTarget(TypeDecl.create("google.protobuf.StringValue")) - .setReturnType(TypeDecl.create("bool")) - .build(), - OverloadDecl.newBuilder() - .setId("list_isEmpty") - .setTarget( - TypeDecl.newBuilder() - .setName("list") - .addParams( - TypeDecl.newBuilder() - .setName("T") - .setIsTypeParam(true) - .build()) - .build()) - .setReturnType(TypeDecl.create("bool")) - .build()))) + FunctionDecl.newBuilder() + .setName("isEmpty") + .setDescription( + "determines whether a list is empty,\nor a string has no characters") + .setOverloads( + ImmutableSet.of( + OverloadDecl.newBuilder() + .setId("wrapper_string_isEmpty") + .setTarget(TypeDecl.create("google.protobuf.StringValue")) + .addExamples("''.isEmpty() // true") + .setReturnType(TypeDecl.create("bool")) + .build(), + OverloadDecl.newBuilder() + .setId("list_isEmpty") + .addExamples("[].isEmpty() // true") + .addExamples("[1].isEmpty() // false") + .setTarget( + TypeDecl.newBuilder() + .setName("list") + .addParams( + TypeDecl.newBuilder() + .setName("T") + .setIsTypeParam(true) + .build()) + .build()) + .setReturnType(TypeDecl.create("bool")) + .build())) + .build()) .setFeatures(CelEnvironment.FeatureFlag.create("cel.feature.macro_call_tracking", true)) .setLimits( ImmutableSet.of( diff --git a/common/src/main/java/dev/cel/common/formats/ParserContext.java b/common/src/main/java/dev/cel/common/formats/ParserContext.java index 0bdfdb299..17eff473f 100644 --- a/common/src/main/java/dev/cel/common/formats/ParserContext.java +++ b/common/src/main/java/dev/cel/common/formats/ParserContext.java @@ -42,6 +42,32 @@ public interface ParserContext { Map getIdToOffsetMap(); - /** NewString creates a new ValueString from the YAML node. */ - ValueString newValueString(T node); + /** + * @deprecated Use {@link #newSourceString} instead. + */ + @Deprecated + default ValueString newValueString(T node) { + return newSourceString(node); + } + + /** + * NewYamlString creates a new ValueString from the YAML node, evaluated according to standard + * YAML parsing rules. + * + *

This respects the whitespace folding semantics defined by the node's scalar style (e.g., + * folded string {@code >} versus literal string {@code |}). Use this method for general string + * fields such as {@code description}, {@code name}, or {@code id}. + */ + ValueString newYamlString(T node); + + /** + * NewRawString creates a new ValueString from the YAML node, preserving formatting for accurate + * source mapping. + * + *

This extracts the verbatim text directly from the source file, preserving raw block + * indentation and unmodified newlines. Use this method when the string represents code or a CEL + * expression where precise character-level offsets must be maintained for accurate diagnostic + * error reporting. + */ + ValueString newSourceString(T node); } diff --git a/common/src/main/java/dev/cel/common/formats/YamlHelper.java b/common/src/main/java/dev/cel/common/formats/YamlHelper.java index e0780b01f..c16126f95 100644 --- a/common/src/main/java/dev/cel/common/formats/YamlHelper.java +++ b/common/src/main/java/dev/cel/common/formats/YamlHelper.java @@ -136,7 +136,7 @@ public static boolean newBoolean(ParserContext ctx, Node node) { } public static String newString(ParserContext ctx, Node node) { - return ctx.newValueString(node).value(); + return ctx.newYamlString(node).value(); } private YamlHelper() {} diff --git a/common/src/main/java/dev/cel/common/formats/YamlParserContextImpl.java b/common/src/main/java/dev/cel/common/formats/YamlParserContextImpl.java index 456872803..9f6077562 100644 --- a/common/src/main/java/dev/cel/common/formats/YamlParserContextImpl.java +++ b/common/src/main/java/dev/cel/common/formats/YamlParserContextImpl.java @@ -62,7 +62,18 @@ public Map getIdToOffsetMap() { } @Override - public ValueString newValueString(Node node) { + public ValueString newYamlString(Node node) { + long id = collectMetadata(node); + if (!assertYamlType(this, id, node, YamlNodeType.STRING, YamlNodeType.TEXT)) { + return ValueString.of(id, ERROR); + } + + ScalarNode scalarNode = (ScalarNode) node; + return ValueString.of(id, scalarNode.getValue()); + } + + @Override + public ValueString newSourceString(Node node) { long id = collectMetadata(node); if (!assertYamlType(this, id, node, YamlNodeType.STRING, YamlNodeType.TEXT)) { return ValueString.of(id, ERROR); diff --git a/policy/src/main/java/dev/cel/policy/CelPolicyYamlParser.java b/policy/src/main/java/dev/cel/policy/CelPolicyYamlParser.java index 43595c4ab..18b406af0 100644 --- a/policy/src/main/java/dev/cel/policy/CelPolicyYamlParser.java +++ b/policy/src/main/java/dev/cel/policy/CelPolicyYamlParser.java @@ -126,13 +126,13 @@ public CelPolicy parsePolicy(PolicyParserContext ctx, Node node) { parseImports(policyBuilder, ctx, valueNode); break; case "name": - policyBuilder.setName(ctx.newValueString(valueNode)); + policyBuilder.setName(ctx.newYamlString(valueNode)); break; case "description": - policyBuilder.setDescription(ctx.newValueString(valueNode)); + policyBuilder.setDescription(ctx.newYamlString(valueNode)); break; case "display_name": - policyBuilder.setDisplayName(ctx.newValueString(valueNode)); + policyBuilder.setDisplayName(ctx.newYamlString(valueNode)); break; case "rule": policyBuilder.setRule(parseRule(ctx, policyBuilder, valueNode)); @@ -189,7 +189,7 @@ private void parseImport( continue; } - policyBuilder.addImport(Import.create(valueId, ctx.newValueString(value))); + policyBuilder.addImport(Import.create(valueId, ctx.newYamlString(value))); } } @@ -212,10 +212,10 @@ public CelPolicy.Rule parseRule( Node value = nodeTuple.getValueNode(); switch (fieldName) { case "id": - ruleBuilder.setRuleId(ctx.newValueString(value)); + ruleBuilder.setRuleId(ctx.newYamlString(value)); break; case "description": - ruleBuilder.setDescription(ctx.newValueString(value)); + ruleBuilder.setDescription(ctx.newYamlString(value)); break; case "variables": ruleBuilder.addVariables(parseVariables(ctx, policyBuilder, value)); @@ -267,7 +267,7 @@ public CelPolicy.Match parseMatch( Node value = nodeTuple.getValueNode(); switch (fieldName) { case "condition": - matchBuilder.setCondition(ctx.newValueString(value)); + matchBuilder.setCondition(ctx.newSourceString(value)); break; case "output": matchBuilder @@ -275,7 +275,7 @@ public CelPolicy.Match parseMatch( .filter(result -> result.kind().equals(Match.Result.Kind.RULE)) .ifPresent( result -> ctx.reportError(tagId, "Only the rule or the output may be set")); - matchBuilder.setResult(Match.Result.ofOutput(ctx.newValueString(value))); + matchBuilder.setResult(Match.Result.ofOutput(ctx.newSourceString(value))); break; case "explanation": matchBuilder @@ -286,7 +286,7 @@ public CelPolicy.Match parseMatch( ctx.reportError( tagId, "Explanation can only be set on output match cases, not nested rules")); - matchBuilder.setExplanation(ctx.newValueString(value)); + matchBuilder.setExplanation(ctx.newYamlString(value)); break; case "rule": matchBuilder @@ -356,8 +356,8 @@ private Variable parseVariableInline( Node keyNode = nodeTuple.getKeyNode(); long keyId = ctx.collectMetadata(keyNode); builder - .setName(ctx.newValueString(keyNode)) - .setExpression(ctx.newValueString(nodeTuple.getValueNode())); + .setName(ctx.newYamlString(keyNode)) + .setExpression(ctx.newSourceString(nodeTuple.getValueNode())); iterations++; if (iterations > 1) { @@ -385,16 +385,16 @@ private Variable parseVariableObject( String keyName = ((ScalarNode) keyNode).getValue(); switch (keyName) { case "name": - builder.setName(ctx.newValueString(valueNode)); + builder.setName(ctx.newYamlString(valueNode)); break; case "expression": - builder.setExpression(ctx.newValueString(valueNode)); + builder.setExpression(ctx.newSourceString(valueNode)); break; case "description": - builder.setDescription(ctx.newValueString(valueNode)); + builder.setDescription(ctx.newYamlString(valueNode)); break; case "display_name": - builder.setDisplayName(ctx.newValueString(valueNode)); + builder.setDisplayName(ctx.newYamlString(valueNode)); break; default: tagVisitor.visitVariableTag(ctx, keyId, keyName, valueNode, policyBuilder, builder); @@ -449,8 +449,13 @@ public Map getIdToOffsetMap() { } @Override - public ValueString newValueString(Node node) { - return ctx.newValueString(node); + public ValueString newYamlString(Node node) { + return ctx.newYamlString(node); + } + + @Override + public ValueString newSourceString(Node node) { + return ctx.newSourceString(node); } } diff --git a/policy/src/test/java/dev/cel/policy/CelPolicyYamlParserTest.java b/policy/src/test/java/dev/cel/policy/CelPolicyYamlParserTest.java index f8327c255..22aec6746 100644 --- a/policy/src/test/java/dev/cel/policy/CelPolicyYamlParserTest.java +++ b/policy/src/test/java/dev/cel/policy/CelPolicyYamlParserTest.java @@ -99,6 +99,20 @@ public void parseYamlPolicy_withDescription() throws Exception { .hasValue(ValueString.of(10, "this is a description of the variable")); } + @Test + public void parseYamlPolicy_withDescription_foldedStyle() throws Exception { + String policySource = + "name: 'policy_name'\n" + + "description: >-\n" + + " this is a multiline string\n" + + " that gets folded into a single line"; + + CelPolicy policy = POLICY_PARSER.parse(policySource); + + assertThat(policy.description().map(ValueString::value)) + .hasValue("this is a multiline string that gets folded into a single line"); + } + @Test public void parseYamlPolicy_withDisplayName() throws Exception { String policySource = @@ -144,7 +158,7 @@ public void parseYamlPolicy_withImports() throws Exception { assertThat(policy.imports()) .containsExactly( Import.create(8L, ValueString.of(9L, "foo")), - Import.create(12L, ValueString.of(13L, " bar"))) + Import.create(12L, ValueString.of(13L, "bar"))) .inOrder(); } diff --git a/policy/src/test/java/dev/cel/policy/PolicyTestHelper.java b/policy/src/test/java/dev/cel/policy/PolicyTestHelper.java index 8d9e0084b..18d5ffc69 100644 --- a/policy/src/test/java/dev/cel/policy/PolicyTestHelper.java +++ b/policy/src/test/java/dev/cel/policy/PolicyTestHelper.java @@ -268,7 +268,7 @@ public void visitPolicyTag( CelPolicy.Builder policyBuilder) { switch (tagName) { case "kind": - policyBuilder.putMetadata("kind", ctx.newValueString(node)); + policyBuilder.putMetadata("kind", ctx.newYamlString(node)); break; case "metadata": long metadataId = ctx.collectMetadata(node); @@ -299,7 +299,7 @@ public void visitRuleTag( Rule.Builder ruleBuilder) { switch (tagName) { case "failurePolicy": - policyBuilder.putMetadata(tagName, ctx.newValueString(node)); + policyBuilder.putMetadata(tagName, ctx.newYamlString(node)); break; case "matchConstraints": long matchConstraintsId = ctx.collectMetadata(node); @@ -343,13 +343,13 @@ public void visitMatchTag( case "expression": // The K8s expression to validate must return false in order to generate a violation // message. - ValueString conditionValue = ctx.newValueString(node); + ValueString conditionValue = ctx.newYamlString(node); conditionValue = conditionValue.toBuilder().setValue("!(" + conditionValue.value() + ")").build(); matchBuilder.setCondition(conditionValue); break; case "messageExpression": - matchBuilder.setResult(Result.ofOutput(ctx.newValueString(node))); + matchBuilder.setResult(Result.ofOutput(ctx.newYamlString(node))); break; default: TagVisitor.super.visitMatchTag(ctx, id, tagName, node, policyBuilder, matchBuilder); diff --git a/testing/src/test/resources/environment/extended_env.yaml b/testing/src/test/resources/environment/extended_env.yaml index 4763c868f..9fc2d511d 100644 --- a/testing/src/test/resources/environment/extended_env.yaml +++ b/testing/src/test/resources/environment/extended_env.yaml @@ -15,32 +15,43 @@ name: "extended-env" container: "cel.expr" extensions: - - name: "optional" - version: "2" - - name: "math" - version: "latest" +- name: "optional" + version: "2" +- name: "math" + version: "latest" variables: - - name: "msg" - type_name: "cel.expr.conformance.proto3.TestAllTypes" +- name: "msg" + type_name: "cel.expr.conformance.proto3.TestAllTypes" + description: >- + msg represents all possible type permutation which + CEL understands from a proto perspective functions: - - name: "isEmpty" - overloads: - - id: "wrapper_string_isEmpty" - target: - type_name: "google.protobuf.StringValue" - return: - type_name: "bool" - - id: "list_isEmpty" - target: - type_name: "list" - params: - - type_name: "T" - is_type_param: true - return: - type_name: "bool" +- name: "isEmpty" + description: |- + determines whether a list is empty, + or a string has no characters + overloads: + - id: "wrapper_string_isEmpty" + examples: + - "''.isEmpty() // true" + target: + type_name: "google.protobuf.StringValue" + return: + type_name: "bool" + - id: "list_isEmpty" + examples: + - "[].isEmpty() // true" + - "[1].isEmpty() // false" + target: + type_name: "list" + params: + - type_name: "T" + is_type_param: true + return: + type_name: "bool" features: - - name: cel.feature.macro_call_tracking - enabled: true +- name: cel.feature.macro_call_tracking + enabled: true limits: - name: cel.limit.expression_code_points value: 1000 From 690d0829812ad66d4dfeff25e07204683d628524 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Fri, 20 Mar 2026 10:38:12 -0700 Subject: [PATCH 09/66] Deprecate enableTimestampEpoch option PiperOrigin-RevId: 886867078 --- .../java/dev/cel/bundle/CelExperimentalFactory.java | 6 +----- bundle/src/test/java/dev/cel/bundle/CelImplTest.java | 1 - .../dev/cel/checker/CelCheckerLegacyImplTest.java | 2 +- common/src/main/java/dev/cel/common/CelOptions.java | 12 +++++++++--- .../java/dev/cel/conformance/ConformanceTest.java | 1 - .../dev/cel/extensions/CelOptionalLibraryTest.java | 6 +----- .../optimizers/ConstantFoldingOptimizerTest.java | 2 +- .../optimizer/optimizers/InliningOptimizerTest.java | 3 +-- .../SubexpressionOptimizerBaselineTest.java | 3 +-- .../optimizers/SubexpressionOptimizerTest.java | 3 +-- .../cel/runtime/CelRuntimeExperimentalFactory.java | 6 +----- .../test/java/dev/cel/runtime/ActivationTest.java | 3 +-- .../dev/cel/runtime/CelStandardFunctionsTest.java | 2 +- .../dev/cel/runtime/DescriptorTypeResolverTest.java | 2 -- .../java/dev/cel/runtime/PlannerInterpreterTest.java | 1 - .../java/dev/cel/testing/BaseInterpreterTest.java | 1 - .../java/dev/cel/testing/CelBaselineTestCase.java | 1 - .../validators/RegexLiteralValidatorTest.java | 3 +-- .../validators/TimestampLiteralValidatorTest.java | 5 ++--- 19 files changed, 22 insertions(+), 41 deletions(-) diff --git a/bundle/src/main/java/dev/cel/bundle/CelExperimentalFactory.java b/bundle/src/main/java/dev/cel/bundle/CelExperimentalFactory.java index 2275d1c56..9a3e95dd8 100644 --- a/bundle/src/main/java/dev/cel/bundle/CelExperimentalFactory.java +++ b/bundle/src/main/java/dev/cel/bundle/CelExperimentalFactory.java @@ -50,11 +50,7 @@ public static CelBuilder plannerCelBuilder() { CelCheckerLegacyImpl.newBuilder().setStandardEnvironmentEnabled(true)), CelRuntimeImpl.newBuilder()) // CEL-Internal-2 - .setOptions( - CelOptions.current() - .enableHeterogeneousNumericComparisons(true) - .enableTimestampEpoch(true) - .build()); + .setOptions(CelOptions.current().enableHeterogeneousNumericComparisons(true).build()); } private CelExperimentalFactory() {} diff --git a/bundle/src/test/java/dev/cel/bundle/CelImplTest.java b/bundle/src/test/java/dev/cel/bundle/CelImplTest.java index 9f7083c92..ae37fca50 100644 --- a/bundle/src/test/java/dev/cel/bundle/CelImplTest.java +++ b/bundle/src/test/java/dev/cel/bundle/CelImplTest.java @@ -2109,7 +2109,6 @@ public void program_fdsContainsWktDependency_descriptorInstancesMatch() throws E standardCelBuilderWithMacros() .addMessageTypes(descriptors) // CEL-Internal-2 - .setOptions(CelOptions.current().enableTimestampEpoch(true).build()) .setContainer(CelContainer.ofName("cel.expr.conformance.proto3")) .build(); CelAbstractSyntaxTree ast = diff --git a/checker/src/test/java/dev/cel/checker/CelCheckerLegacyImplTest.java b/checker/src/test/java/dev/cel/checker/CelCheckerLegacyImplTest.java index c0c54381d..92a70c2d6 100644 --- a/checker/src/test/java/dev/cel/checker/CelCheckerLegacyImplTest.java +++ b/checker/src/test/java/dev/cel/checker/CelCheckerLegacyImplTest.java @@ -63,7 +63,7 @@ public void toCheckerBuilder_isImmutable() { public void toCheckerBuilder_singularFields_copied() { CelStandardDeclarations subsetDecls = CelStandardDeclarations.newBuilder().includeFunctions(StandardFunction.BOOL).build(); - CelOptions celOptions = CelOptions.current().enableTimestampEpoch(true).build(); + CelOptions celOptions = CelOptions.current().build(); CelContainer celContainer = CelContainer.ofName("foo"); CelType expectedResultType = SimpleType.BOOL; CelTypeProvider customTypeProvider = diff --git a/common/src/main/java/dev/cel/common/CelOptions.java b/common/src/main/java/dev/cel/common/CelOptions.java index e3bb8776e..0c348c8a8 100644 --- a/common/src/main/java/dev/cel/common/CelOptions.java +++ b/common/src/main/java/dev/cel/common/CelOptions.java @@ -293,14 +293,20 @@ public abstract static class Builder { public abstract Builder enableHomogeneousLiterals(boolean value); /** - * Enable the {@code int64_to_timestamp} overload which creates a timestamp from Uxix epoch + * Enable the {@code int64_to_timestamp} overload which creates a timestamp from Unix epoch * seconds. * - *

This option will be automatically enabled after a sufficient period of time has elapsed to - * ensure that all runtimes support the implementation. + *

Historically used to opt-in to this feature, this option is now enabled by default across + * all runtimes. * *

TODO: Remove this feature once it has been auto-enabled. + * + * @deprecated This option is now enabled by default. If you are passing {@code true}, simply + * remove this method call. If you are passing {@code false} to disable this feature, subset + * the environment instead using {@code dev.cel.checker.CelStandardDeclarations} and {@code + * dev.cel.runtime.CelStandardFunctions}. */ + @Deprecated public abstract Builder enableTimestampEpoch(boolean value); /** diff --git a/conformance/src/test/java/dev/cel/conformance/ConformanceTest.java b/conformance/src/test/java/dev/cel/conformance/ConformanceTest.java index 5a25fb9d9..86f9b8f29 100644 --- a/conformance/src/test/java/dev/cel/conformance/ConformanceTest.java +++ b/conformance/src/test/java/dev/cel/conformance/ConformanceTest.java @@ -58,7 +58,6 @@ public final class ConformanceTest extends Statement { private static final CelOptions OPTIONS = CelOptions.current() - .enableTimestampEpoch(true) .enableHeterogeneousNumericComparisons(true) .enableProtoDifferencerEquality(true) .enableOptionalSyntax(true) diff --git a/extensions/src/test/java/dev/cel/extensions/CelOptionalLibraryTest.java b/extensions/src/test/java/dev/cel/extensions/CelOptionalLibraryTest.java index ab412fb39..4f348c12a 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelOptionalLibraryTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelOptionalLibraryTest.java @@ -123,11 +123,7 @@ private CelBuilder newCelBuilder(int version) { } return celBuilder - .setOptions( - CelOptions.current() - .enableTimestampEpoch(true) - .enableHeterogeneousNumericComparisons(true) - .build()) + .setOptions(CelOptions.current().enableHeterogeneousNumericComparisons(true).build()) .setStandardMacros(CelStandardMacro.STANDARD_MACROS) .setContainer(CelContainer.ofName("cel.expr.conformance.proto3")) .addMessageTypes(TestAllTypes.getDescriptor()) diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java index e259a7a35..a8cadf83a 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java @@ -47,7 +47,7 @@ @RunWith(TestParameterInjector.class) public class ConstantFoldingOptimizerTest { private static final CelOptions CEL_OPTIONS = - CelOptions.current().populateMacroCalls(true).enableTimestampEpoch(true).build(); + CelOptions.current().populateMacroCalls(true).build(); private static final Cel CEL = CelFactory.standardCelBuilder() .addVar("x", SimpleType.DYN) diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/InliningOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/InliningOptimizerTest.java index 7930a03d8..da2e9b745 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/InliningOptimizerTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/InliningOptimizerTest.java @@ -58,8 +58,7 @@ public class InliningOptimizerTest { "child", StructTypeReference.create(TestAllTypes.NestedMessage.getDescriptor().getFullName())) .addVar("shadowed_ident", SimpleType.INT) - .setOptions( - CelOptions.current().populateMacroCalls(true).enableTimestampEpoch(true).build()) + .setOptions(CelOptions.current().populateMacroCalls(true).build()) .build(); @Test diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerBaselineTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerBaselineTest.java index 07573f428..802ef3037 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerBaselineTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerBaselineTest.java @@ -265,8 +265,7 @@ private static CelBuilder newCelBuilder() { .addMessageTypes(TestAllTypes.getDescriptor()) .setContainer(CelContainer.ofName("cel.expr.conformance.proto3")) .setStandardMacros(CelStandardMacro.STANDARD_MACROS) - .setOptions( - CelOptions.current().enableTimestampEpoch(true).populateMacroCalls(true).build()) + .setOptions(CelOptions.current().populateMacroCalls(true).build()) .addCompilerLibraries( CelExtensions.optional(), CelExtensions.bindings(), CelExtensions.comprehensions()) .addRuntimeLibraries(CelExtensions.optional(), CelExtensions.comprehensions()) diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java index 735cd24f0..2289a7d4a 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java @@ -105,8 +105,7 @@ private static CelBuilder newCelBuilder() { return CelFactory.standardCelBuilder() .addMessageTypes(TestAllTypes.getDescriptor()) .setStandardMacros(CelStandardMacro.STANDARD_MACROS) - .setOptions( - CelOptions.current().enableTimestampEpoch(true).populateMacroCalls(true).build()) + .setOptions(CelOptions.current().populateMacroCalls(true).build()) .addCompilerLibraries(CelExtensions.bindings()) .addFunctionDeclarations( CelFunctionDecl.newFunctionDeclaration( diff --git a/runtime/src/main/java/dev/cel/runtime/CelRuntimeExperimentalFactory.java b/runtime/src/main/java/dev/cel/runtime/CelRuntimeExperimentalFactory.java index d0089e48d..743f90669 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelRuntimeExperimentalFactory.java +++ b/runtime/src/main/java/dev/cel/runtime/CelRuntimeExperimentalFactory.java @@ -41,11 +41,7 @@ public final class CelRuntimeExperimentalFactory { public static CelRuntimeBuilder plannerRuntimeBuilder() { return CelRuntimeImpl.newBuilder() // CEL-Internal-2 - .setOptions( - CelOptions.current() - .enableTimestampEpoch(true) - .enableHeterogeneousNumericComparisons(true) - .build()); + .setOptions(CelOptions.current().enableHeterogeneousNumericComparisons(true).build()); } private CelRuntimeExperimentalFactory() {} diff --git a/runtime/src/test/java/dev/cel/runtime/ActivationTest.java b/runtime/src/test/java/dev/cel/runtime/ActivationTest.java index 5e3f3f1fe..fc435c848 100644 --- a/runtime/src/test/java/dev/cel/runtime/ActivationTest.java +++ b/runtime/src/test/java/dev/cel/runtime/ActivationTest.java @@ -33,11 +33,10 @@ public final class ActivationTest { private static final CelOptions TEST_OPTIONS = - CelOptions.current().enableTimestampEpoch(true).enableUnsignedLongs(true).build(); + CelOptions.current().enableUnsignedLongs(true).build(); private static final CelOptions TEST_OPTIONS_SKIP_UNSET_FIELDS = CelOptions.current() - .enableTimestampEpoch(true) .enableUnsignedLongs(true) .fromProtoUnsetFieldOption(CelOptions.ProtoUnsetFieldOptions.SKIP) .build(); diff --git a/runtime/src/test/java/dev/cel/runtime/CelStandardFunctionsTest.java b/runtime/src/test/java/dev/cel/runtime/CelStandardFunctionsTest.java index d85ef7424..c5f5572a7 100644 --- a/runtime/src/test/java/dev/cel/runtime/CelStandardFunctionsTest.java +++ b/runtime/src/test/java/dev/cel/runtime/CelStandardFunctionsTest.java @@ -224,7 +224,7 @@ public void unsignedLongsDisabled_int64Identity_throws() { public void timestampEpochDisabled_int64Identity_throws() { CelCompiler celCompiler = CelCompilerFactory.standardCelCompilerBuilder() - .setOptions(CelOptions.current().enableTimestampEpoch(true).build()) + .setOptions(CelOptions.current().build()) .build(); CelRuntime celRuntime = CelRuntimeFactory.standardCelRuntimeBuilder() diff --git a/runtime/src/test/java/dev/cel/runtime/DescriptorTypeResolverTest.java b/runtime/src/test/java/dev/cel/runtime/DescriptorTypeResolverTest.java index 878576f94..bdd865601 100644 --- a/runtime/src/test/java/dev/cel/runtime/DescriptorTypeResolverTest.java +++ b/runtime/src/test/java/dev/cel/runtime/DescriptorTypeResolverTest.java @@ -24,7 +24,6 @@ import dev.cel.bundle.CelFactory; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelContainer; -import dev.cel.common.CelOptions; import dev.cel.common.types.OpaqueType; import dev.cel.common.types.OptionalType; import dev.cel.common.types.ProtoMessageTypeProvider; @@ -44,7 +43,6 @@ public class DescriptorTypeResolverTest { private static final Cel CEL = CelFactory.standardCelBuilder() - .setOptions(CelOptions.current().enableTimestampEpoch(true).build()) .setTypeProvider(PROTO_MESSAGE_TYPE_PROVIDER) .addCompilerLibraries(CelOptionalLibrary.INSTANCE) .addRuntimeLibraries(CelOptionalLibrary.INSTANCE) diff --git a/runtime/src/test/java/dev/cel/runtime/PlannerInterpreterTest.java b/runtime/src/test/java/dev/cel/runtime/PlannerInterpreterTest.java index 2c0bec739..181842ab4 100644 --- a/runtime/src/test/java/dev/cel/runtime/PlannerInterpreterTest.java +++ b/runtime/src/test/java/dev/cel/runtime/PlannerInterpreterTest.java @@ -198,7 +198,6 @@ public void planner_unknownResultSet_success() { celRuntime = newBaseRuntimeBuilder( CelOptions.current() - .enableTimestampEpoch(true) .enableHeterogeneousNumericComparisons(true) .enableOptionalSyntax(true) .comprehensionMaxIterations(1_000) diff --git a/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java b/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java index 69db9c9db..42cb5e41b 100644 --- a/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java +++ b/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java @@ -114,7 +114,6 @@ public abstract class BaseInterpreterTest extends CelBaselineTestCase { private static final CelOptions BASE_CEL_OPTIONS = CelOptions.current() - .enableTimestampEpoch(true) .enableHeterogeneousNumericComparisons(true) .enableOptionalSyntax(true) .comprehensionMaxIterations(1_000) diff --git a/testing/src/main/java/dev/cel/testing/CelBaselineTestCase.java b/testing/src/main/java/dev/cel/testing/CelBaselineTestCase.java index 79ae88f47..8c79e5931 100644 --- a/testing/src/main/java/dev/cel/testing/CelBaselineTestCase.java +++ b/testing/src/main/java/dev/cel/testing/CelBaselineTestCase.java @@ -56,7 +56,6 @@ public abstract class CelBaselineTestCase extends BaselineTestCase { protected static final int COMPREHENSION_MAX_ITERATIONS = 1_000; protected static final CelOptions TEST_OPTIONS = CelOptions.current() - .enableTimestampEpoch(true) .enableHeterogeneousNumericComparisons(true) .enableHiddenAccumulatorVar(true) .enableOptionalSyntax(true) diff --git a/validator/src/test/java/dev/cel/validator/validators/RegexLiteralValidatorTest.java b/validator/src/test/java/dev/cel/validator/validators/RegexLiteralValidatorTest.java index 35a9ffd4f..a41317371 100644 --- a/validator/src/test/java/dev/cel/validator/validators/RegexLiteralValidatorTest.java +++ b/validator/src/test/java/dev/cel/validator/validators/RegexLiteralValidatorTest.java @@ -39,8 +39,7 @@ @RunWith(TestParameterInjector.class) public class RegexLiteralValidatorTest { - private static final CelOptions CEL_OPTIONS = - CelOptions.current().enableTimestampEpoch(true).build(); + private static final CelOptions CEL_OPTIONS = CelOptions.current().build(); private static final Cel CEL = CelFactory.standardCelBuilder().setOptions(CEL_OPTIONS).build(); diff --git a/validator/src/test/java/dev/cel/validator/validators/TimestampLiteralValidatorTest.java b/validator/src/test/java/dev/cel/validator/validators/TimestampLiteralValidatorTest.java index 7770df54c..404ed7f7e 100644 --- a/validator/src/test/java/dev/cel/validator/validators/TimestampLiteralValidatorTest.java +++ b/validator/src/test/java/dev/cel/validator/validators/TimestampLiteralValidatorTest.java @@ -41,8 +41,7 @@ @RunWith(TestParameterInjector.class) public class TimestampLiteralValidatorTest { - private static final CelOptions CEL_OPTIONS = - CelOptions.current().enableTimestampEpoch(true).build(); + private static final CelOptions CEL_OPTIONS = CelOptions.current().build(); private static final Cel CEL = CelFactory.standardCelBuilder().setOptions(CEL_OPTIONS).build(); @@ -205,7 +204,7 @@ public void parentIsNotCallExpr_doesNotThrow(String source) throws Exception { public void env_withSetResultType_success() throws Exception { Cel cel = CelFactory.standardCelBuilder() - .setOptions(CelOptions.current().enableTimestampEpoch(true).build()) + .setOptions(CelOptions.current().build()) .setResultType(SimpleType.BOOL) .build(); CelValidator validator = From b92f1f6a4ddaefe3f9aeca6a284b1cf097f63ef2 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Thu, 26 Mar 2026 16:28:49 -0700 Subject: [PATCH 10/66] JSON field name resolution fix for shadowed cases PiperOrigin-RevId: 890106179 --- .../src/test/java/dev/cel/bundle/BUILD.bazel | 2 + .../test/java/dev/cel/bundle/CelImplTest.java | 135 +++++++++++++++--- .../cel/common/values/ProtoMessageValue.java | 10 +- .../values/ProtoMessageValueProvider.java | 10 +- .../cel/common/internal/DynamicProtoTest.java | 2 +- .../types/ProtoMessageTypeProviderTest.java | 6 +- .../cel/policy/CelPolicyCompilerImplTest.java | 2 +- .../runtime/DescriptorMessageProvider.java | 32 ++--- .../dev/cel/runtime/CelLiteRuntimeTest.java | 2 +- testing/protos/BUILD.bazel | 5 + testing/src/test/resources/protos/BUILD.bazel | 13 ++ .../test/resources/protos/single_file.proto | 14 +- .../protos/single_file_extensions.proto | 27 ++++ 13 files changed, 208 insertions(+), 52 deletions(-) create mode 100644 testing/src/test/resources/protos/single_file_extensions.proto diff --git a/bundle/src/test/java/dev/cel/bundle/BUILD.bazel b/bundle/src/test/java/dev/cel/bundle/BUILD.bazel index ffa3322fe..2901e1ff9 100644 --- a/bundle/src/test/java/dev/cel/bundle/BUILD.bazel +++ b/bundle/src/test/java/dev/cel/bundle/BUILD.bazel @@ -17,6 +17,7 @@ java_library( deps = [ "//:java_truth", "//bundle:cel", + "//bundle:cel_experimental_factory", "//bundle:cel_impl", "//bundle:environment", "//bundle:environment_exception", @@ -55,6 +56,7 @@ java_library( "//runtime:evaluation_listener", "//runtime:function_binding", "//runtime:unknown_attributes", + "//testing/protos:single_file_extension_java_proto", "//testing/protos:single_file_java_proto", "@cel_spec//proto/cel/expr:checked_java_proto", "@cel_spec//proto/cel/expr:syntax_java_proto", diff --git a/bundle/src/test/java/dev/cel/bundle/CelImplTest.java b/bundle/src/test/java/dev/cel/bundle/CelImplTest.java index ae37fca50..22ef7e2f4 100644 --- a/bundle/src/test/java/dev/cel/bundle/CelImplTest.java +++ b/bundle/src/test/java/dev/cel/bundle/CelImplTest.java @@ -98,6 +98,7 @@ import dev.cel.expr.conformance.proto2.Proto2ExtensionScopedMessage; import dev.cel.expr.conformance.proto2.TestAllTypesExtensions; import dev.cel.expr.conformance.proto3.TestAllTypes; +import dev.cel.extensions.CelExtensions; import dev.cel.parser.CelParserImpl; import dev.cel.parser.CelStandardMacro; import dev.cel.runtime.CelAttribute; @@ -113,7 +114,8 @@ import dev.cel.runtime.CelUnknownSet; import dev.cel.runtime.CelVariableResolver; import dev.cel.runtime.UnknownContext; -import dev.cel.testing.testdata.SingleFileProto.SingleFile; +import dev.cel.testing.testdata.SingleFile; +import dev.cel.testing.testdata.SingleFileExtensionsProto; import dev.cel.testing.testdata.proto3.StandaloneGlobalEnum; import java.time.Instant; import java.util.ArrayList; @@ -2142,20 +2144,90 @@ public void toBuilder_isImmutable() { } @Test - public void eval_withJsonFieldName() throws Exception { - Cel cel = - standardCelBuilderWithMacros() - .addVar("file", StructTypeReference.create(SingleFile.getDescriptor().getFullName())) - .addMessageTypes(SingleFile.getDescriptor()) - .setOptions(CelOptions.current().enableJsonFieldNames(true).build()) - .build(); - CelAbstractSyntaxTree ast = cel.compile("file.camelCased").getAst(); + public void eval_withJsonFieldName(@TestParameter RuntimeEnv runtimeEnv) throws Exception { + Cel cel = runtimeEnv.cel; + CelAbstractSyntaxTree ast = + cel.compile( + "file.int32_snake_case_json_name == 1 && " + + "file.int64CamelCaseJsonName == 2 && " + + "file.uint32DefaultJsonName == 3u && " + + "file.`uint64-custom-json-name` == 4u && " + + "file.single_string == 'shadows' && " + + "file.singleString == 'shadowed'") + .getAst(); + + boolean result = + (boolean) + cel.createProgram(ast) + .eval( + ImmutableMap.of( + "file", + SingleFile.newBuilder() + .setInt32SnakeCaseJsonName(1) + .setInt64CamelCaseJsonName(2L) + .setUint32DefaultJsonName(3) + .setUint64CustomJsonName(4) + .setStringJsonNameShadows("shadows") + .setSingleString("shadowed") + .setExtension(SingleFileExtensionsProto.int64CamelCaseJsonName, 5L) + .build())); - Object result = - cel.createProgram(ast) - .eval(ImmutableMap.of("file", SingleFile.newBuilder().setSnakeCased("foo").build())); + assertThat(result).isTrue(); + } + + @Test + public void eval_withJsonFieldName_fieldsFallBack(@TestParameter RuntimeEnv runtimeEnv) throws Exception { + Cel cel = runtimeEnv.cel; + CelAbstractSyntaxTree ast = + cel.compile( + "dyn(file).int32_snake_case_json_name == 1 && " + + "dyn(file).`uint64-custom-json-name` == 4u && " + + "dyn(file).single_string == 'shadows' && " + + "dyn(file).string_json_name_shadows == 'shadows' && " + + "dyn(file).singleString == 'shadowed'") + .getAst(); + + boolean result = + (boolean) + cel.createProgram(ast) + .eval( + ImmutableMap.of( + "file", + SingleFile.newBuilder() + .setInt32SnakeCaseJsonName(1) + .setInt64CamelCaseJsonName(2L) + .setUint32DefaultJsonName(3) + .setUint64CustomJsonName(4) + .setStringJsonNameShadows("shadows") + .setSingleString("shadowed") + .build())); - assertThat(result).isEqualTo("foo"); + assertThat(result).isTrue(); + } + + @Test + public void eval_withJsonFieldName_extensionFields(@TestParameter RuntimeEnv runtimeEnv) throws Exception { + Cel cel = runtimeEnv.cel; + CelAbstractSyntaxTree ast = + cel.compile( + "proto.getExt(file, dev.cel.testing.testdata.int64CamelCaseJsonName) == 5 &&" + + " proto.getExt(file, dev.cel.testing.testdata.single_string) == 'foo'") + .getAst(); + + boolean result = + (boolean) + cel.createProgram(ast) + .eval( + ImmutableMap.of( + "file", + SingleFile.newBuilder() + .setInt64CamelCaseJsonName(2L) + .setExtension(SingleFileExtensionsProto.int64CamelCaseJsonName, 5L) + .setSingleString("This should not be used") + .setExtension(SingleFileExtensionsProto.singleString, "foo") + .build())); + + assertThat(result).isTrue(); } @Test @@ -2171,7 +2243,7 @@ public void eval_withJsonFieldName_runtimeOptionDisabled_throws() throws Excepti .addMessageTypes(SingleFile.getDescriptor()) .setOptions(CelOptions.current().enableJsonFieldNames(false).build()) .build(); - CelAbstractSyntaxTree ast = celCompiler.compile("file.camelCased").getAst(); + CelAbstractSyntaxTree ast = celCompiler.compile("file.int64CamelCaseJsonName").getAst(); CelEvaluationException e = assertThrows( @@ -2183,7 +2255,8 @@ public void eval_withJsonFieldName_runtimeOptionDisabled_throws() throws Excepti assertThat(e) .hasMessageThat() .contains( - "field 'camelCased' is not declared in message 'dev.cel.testing.testdata.SingleFile"); + "field 'int64CamelCaseJsonName' is not declared in message" + + " 'dev.cel.testing.testdata.SingleFile"); } @Test @@ -2194,7 +2267,7 @@ public void compile_withJsonFieldName_astTagged() throws Exception { .addMessageTypes(SingleFile.getDescriptor()) .setOptions(CelOptions.current().enableJsonFieldNames(true).build()) .build(); - CelAbstractSyntaxTree ast = cel.compile("file.camelCased").getAst(); + CelAbstractSyntaxTree ast = cel.compile("file.int64CamelCaseJsonName").getAst(); assertThat(ast.getSource().getExtensions()) .contains( @@ -2243,4 +2316,34 @@ private static TypeProvider aliasingProvider(ImmutableMap typeAlia } }; } + + private enum RuntimeEnv { + LEGACY(setupEnv(CelFactory.standardCelBuilder())), + PLANNER(setupEnv(CelExperimentalFactory.plannerCelBuilder())) + ; + + private final Cel cel; + + private static Cel setupEnv(CelBuilder celBuilder) { + ExtensionRegistry extensionRegistry = ExtensionRegistry.newInstance(); + SingleFileExtensionsProto.registerAllExtensions(extensionRegistry); + return celBuilder + .addVar("file", StructTypeReference.create(SingleFile.getDescriptor().getFullName())) + .addMessageTypes(SingleFile.getDescriptor()) + .addFileTypes(SingleFileExtensionsProto.getDescriptor()) + .addCompilerLibraries(CelExtensions.protos()) + .setExtensionRegistry(extensionRegistry) + .setOptions( + CelOptions.current() + .enableJsonFieldNames(true) + .enableHeterogeneousNumericComparisons(true) + .enableQuotedIdentifierSyntax(true) + .build()) + .build(); + } + + RuntimeEnv(Cel cel) { + this.cel = cel; + } + } } diff --git a/common/src/main/java/dev/cel/common/values/ProtoMessageValue.java b/common/src/main/java/dev/cel/common/values/ProtoMessageValue.java index e402bb429..12d47c253 100644 --- a/common/src/main/java/dev/cel/common/values/ProtoMessageValue.java +++ b/common/src/main/java/dev/cel/common/values/ProtoMessageValue.java @@ -92,11 +92,6 @@ public static ProtoMessageValue create( private FieldDescriptor findField( CelDescriptorPool celDescriptorPool, Descriptor descriptor, String fieldName) { - FieldDescriptor fieldDescriptor = descriptor.findFieldByName(fieldName); - if (fieldDescriptor != null) { - return fieldDescriptor; - } - if (enableJsonFieldNames()) { for (FieldDescriptor fd : descriptor.getFields()) { if (fd.getJsonName().equals(fieldName)) { @@ -105,6 +100,11 @@ private FieldDescriptor findField( } } + FieldDescriptor fieldDescriptor = descriptor.findFieldByName(fieldName); + if (fieldDescriptor != null) { + return fieldDescriptor; + } + return celDescriptorPool .findExtensionDescriptor(descriptor, fieldName) .orElseThrow( diff --git a/common/src/main/java/dev/cel/common/values/ProtoMessageValueProvider.java b/common/src/main/java/dev/cel/common/values/ProtoMessageValueProvider.java index b7895d845..7beb40c61 100644 --- a/common/src/main/java/dev/cel/common/values/ProtoMessageValueProvider.java +++ b/common/src/main/java/dev/cel/common/values/ProtoMessageValueProvider.java @@ -68,11 +68,6 @@ public Optional newValue(String structType, Map fields) } private FieldDescriptor findField(Descriptor descriptor, String fieldName) { - FieldDescriptor fieldDescriptor = descriptor.findFieldByName(fieldName); - if (fieldDescriptor != null) { - return fieldDescriptor; - } - if (celOptions.enableJsonFieldNames()) { for (FieldDescriptor fd : descriptor.getFields()) { if (fd.getJsonName().equals(fieldName)) { @@ -81,6 +76,11 @@ private FieldDescriptor findField(Descriptor descriptor, String fieldName) { } } + FieldDescriptor fieldDescriptor = descriptor.findFieldByName(fieldName); + if (fieldDescriptor != null) { + return fieldDescriptor; + } + return protoMessageFactory .getDescriptorPool() .findExtensionDescriptor(descriptor, fieldName) diff --git a/common/src/test/java/dev/cel/common/internal/DynamicProtoTest.java b/common/src/test/java/dev/cel/common/internal/DynamicProtoTest.java index cc5ba5632..7be994391 100644 --- a/common/src/test/java/dev/cel/common/internal/DynamicProtoTest.java +++ b/common/src/test/java/dev/cel/common/internal/DynamicProtoTest.java @@ -37,7 +37,7 @@ import dev.cel.common.CelDescriptorUtil; import dev.cel.common.CelDescriptors; import dev.cel.testing.testdata.MultiFile; -import dev.cel.testing.testdata.SingleFileProto.SingleFile; +import dev.cel.testing.testdata.SingleFile; import java.io.IOException; import org.junit.Test; import org.junit.runner.RunWith; diff --git a/common/src/test/java/dev/cel/common/types/ProtoMessageTypeProviderTest.java b/common/src/test/java/dev/cel/common/types/ProtoMessageTypeProviderTest.java index 16797b714..c9f9d9e21 100644 --- a/common/src/test/java/dev/cel/common/types/ProtoMessageTypeProviderTest.java +++ b/common/src/test/java/dev/cel/common/types/ProtoMessageTypeProviderTest.java @@ -23,7 +23,7 @@ import dev.cel.common.types.StructType.Field; import dev.cel.expr.conformance.proto2.TestAllTypes; import dev.cel.expr.conformance.proto2.TestAllTypesExtensions; -import dev.cel.testing.testdata.SingleFileProto.SingleFile; +import dev.cel.testing.testdata.SingleFile; import java.util.Optional; import org.junit.Test; import org.junit.runner.RunWith; @@ -269,8 +269,8 @@ public void findField_withJsonNameOption() { (ProtoMessageType) typeProvider.findType(SingleFile.getDescriptor().getFullName()).get(); // Note that these are the same fields, with json_name option set - Optional snakeCasedField = msgType.findField("snake_cased"); - Optional jsonNameField = msgType.findField("camelCased"); + Optional snakeCasedField = msgType.findField("int64_camel_case_json_name"); + Optional jsonNameField = msgType.findField("int64CamelCaseJsonName"); assertThat(snakeCasedField).isEmpty(); assertThat(jsonNameField).isPresent(); diff --git a/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java b/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java index 336a392ff..fec5f9b94 100644 --- a/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java +++ b/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java @@ -45,7 +45,7 @@ import dev.cel.policy.PolicyTestHelper.TestYamlPolicy; import dev.cel.runtime.CelFunctionBinding; import dev.cel.runtime.CelLateFunctionBindings; -import dev.cel.testing.testdata.SingleFileProto.SingleFile; +import dev.cel.testing.testdata.SingleFile; import dev.cel.testing.testdata.proto3.StandaloneGlobalEnum; import java.io.IOException; import java.util.Map; diff --git a/runtime/src/main/java/dev/cel/runtime/DescriptorMessageProvider.java b/runtime/src/main/java/dev/cel/runtime/DescriptorMessageProvider.java index ba0e442ec..ecbba5e7e 100644 --- a/runtime/src/main/java/dev/cel/runtime/DescriptorMessageProvider.java +++ b/runtime/src/main/java/dev/cel/runtime/DescriptorMessageProvider.java @@ -173,30 +173,28 @@ public Object hasField(Object message, String fieldName) { } private FieldDescriptor findField(Descriptor descriptor, String fieldName) { - FieldDescriptor fieldDescriptor = descriptor.findFieldByName(fieldName); - if (fieldDescriptor == null) { - Optional maybeFieldDescriptor = - protoMessageFactory.getDescriptorPool().findExtensionDescriptor(descriptor, fieldName); - if (maybeFieldDescriptor.isPresent()) { - fieldDescriptor = maybeFieldDescriptor.get(); - } - } - - if (fieldDescriptor == null && celOptions.enableJsonFieldNames()) { + if (celOptions.enableJsonFieldNames()) { for (FieldDescriptor fd : descriptor.getFields()) { if (fd.getJsonName().equals(fieldName)) { - fieldDescriptor = fd; - break; + return fd; } } } - if (fieldDescriptor == null) { - throw new IllegalArgumentException( - String.format( - "field '%s' is not declared in message '%s'", fieldName, descriptor.getFullName())); + FieldDescriptor fieldDescriptor = descriptor.findFieldByName(fieldName); + if (fieldDescriptor != null) { + return fieldDescriptor; + } + fieldDescriptor = + protoMessageFactory.getDescriptorPool().findExtensionDescriptor(descriptor, fieldName).orElse(null); + if (fieldDescriptor != null) { + return fieldDescriptor; } - return fieldDescriptor; + + + throw new IllegalArgumentException( + String.format( + "field '%s' is not declared in message '%s'", fieldName, descriptor.getFullName())); } private static MessageOrBuilder assertFullProtoMessage(Object candidate, String fieldName) { diff --git a/runtime/src/test/java/dev/cel/runtime/CelLiteRuntimeTest.java b/runtime/src/test/java/dev/cel/runtime/CelLiteRuntimeTest.java index 4ffe0941c..0ce7bd184 100644 --- a/runtime/src/test/java/dev/cel/runtime/CelLiteRuntimeTest.java +++ b/runtime/src/test/java/dev/cel/runtime/CelLiteRuntimeTest.java @@ -59,8 +59,8 @@ import dev.cel.testing.testdata.MultiFile; import dev.cel.testing.testdata.MultiFileCelDescriptor; import dev.cel.testing.testdata.SimpleEnum; +import dev.cel.testing.testdata.SingleFile; import dev.cel.testing.testdata.SingleFileCelDescriptor; -import dev.cel.testing.testdata.SingleFileProto.SingleFile; import java.time.Duration; import java.time.Instant; import java.util.ArrayList; diff --git a/testing/protos/BUILD.bazel b/testing/protos/BUILD.bazel index c51fbba85..17706ca45 100644 --- a/testing/protos/BUILD.bazel +++ b/testing/protos/BUILD.bazel @@ -9,6 +9,11 @@ alias( actual = "//testing/src/test/resources/protos:single_file_java_proto", ) +alias( + name = "single_file_extension_java_proto", + actual = "//testing/src/test/resources/protos:single_file_extension_java_proto", +) + alias( name = "multi_file_java_proto", actual = "//testing/src/test/resources/protos:multi_file_java_proto", diff --git a/testing/src/test/resources/protos/BUILD.bazel b/testing/src/test/resources/protos/BUILD.bazel index af361b174..1fac2e1f0 100644 --- a/testing/src/test/resources/protos/BUILD.bazel +++ b/testing/src/test/resources/protos/BUILD.bazel @@ -25,6 +25,19 @@ java_proto_library( deps = [":single_file_proto"], ) +proto_library( + name = "single_file_extension_proto", + srcs = ["single_file_extensions.proto"], + deps = [":single_file_proto"], +) + +java_proto_library( + name = "single_file_extension_java_proto", + tags = [ + ], + deps = [":single_file_extension_proto"], +) + proto_library( name = "multi_file_proto", srcs = [ diff --git a/testing/src/test/resources/protos/single_file.proto b/testing/src/test/resources/protos/single_file.proto index b5ce518e0..8306cc16c 100644 --- a/testing/src/test/resources/protos/single_file.proto +++ b/testing/src/test/resources/protos/single_file.proto @@ -12,12 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -syntax = "proto3"; +edition = "2024"; package dev.cel.testing.testdata; option java_package = "dev.cel.testing.testdata"; -option java_outer_classname = "SingleFileProto"; message SingleFile { message Path { @@ -26,5 +25,14 @@ message SingleFile { string name = 1; Path path = 2; - string snake_cased = 3 [json_name = "camelCased"]; + int32 int32_snake_case_json_name = 4 [json_name = "int32_snake_case_json_name"]; + int64 int64_camel_case_json_name = 5 [json_name = "int64CamelCaseJsonName"]; + uint32 uint32_default_json_name = 6; + uint64 uint64_custom_json_name = 7 [json_name = "uint64-custom-json-name"]; + + // Collides with normal field name. + string string_json_name_shadows = 8 [json_name = "single_string"]; + string single_string = 9; + + extensions 1000 to max; } diff --git a/testing/src/test/resources/protos/single_file_extensions.proto b/testing/src/test/resources/protos/single_file_extensions.proto new file mode 100644 index 000000000..9d18d38df --- /dev/null +++ b/testing/src/test/resources/protos/single_file_extensions.proto @@ -0,0 +1,27 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +edition = "2024"; + +package dev.cel.testing.testdata; + +import "testing/src/test/resources/protos/single_file.proto"; + +option java_package = "dev.cel.testing.testdata"; +option features.enforce_naming_style = STYLE_LEGACY; + +extend SingleFile { + int64 int64CamelCaseJsonName = 1000; + string single_string = 1001; +} From 28d2b3cd2203bac039fd6402ebd328a02d6367c5 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Fri, 27 Mar 2026 12:55:08 -0700 Subject: [PATCH 11/66] Refactor test runner to accept required descriptors at the callsite, introduce BindingTransformer PiperOrigin-RevId: 890595249 --- .../dev/cel/conformance/ConformanceTest.java | 30 +++++-- .../conformance/ConformanceTestRunner.java | 9 +- .../dev/cel/testing/testrunner/BUILD.bazel | 7 ++ .../testing/testrunner/CelTestContext.java | 82 +++++++++++++++++++ .../CelTestSuiteTextProtoParser.java | 32 ++++++-- .../testrunner/CelTestSuiteYamlParser.java | 11 ++- .../testing/testrunner/TestRunnerLibrary.java | 42 ++++++++-- .../dev/cel/testing/utils/ExprValueUtils.java | 75 ++++++++--------- .../testrunner/TestRunnerLibraryTest.java | 71 ++++++++++++++++ 9 files changed, 290 insertions(+), 69 deletions(-) diff --git a/conformance/src/test/java/dev/cel/conformance/ConformanceTest.java b/conformance/src/test/java/dev/cel/conformance/ConformanceTest.java index 86f9b8f29..db57ccb79 100644 --- a/conformance/src/test/java/dev/cel/conformance/ConformanceTest.java +++ b/conformance/src/test/java/dev/cel/conformance/ConformanceTest.java @@ -16,8 +16,6 @@ import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.extensions.proto.ProtoTruth.assertThat; -import static dev.cel.testing.utils.ExprValueUtils.DEFAULT_EXTENSION_REGISTRY; -import static dev.cel.testing.utils.ExprValueUtils.DEFAULT_TYPE_REGISTRY; import static dev.cel.testing.utils.ExprValueUtils.fromValue; import static dev.cel.testing.utils.ExprValueUtils.toExprValue; @@ -29,6 +27,8 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.protobuf.ExtensionRegistry; +import com.google.protobuf.TypeRegistry; import dev.cel.checker.CelChecker; import dev.cel.common.CelContainer; import dev.cel.common.CelOptions; @@ -84,6 +84,21 @@ public final class ConformanceTest extends Statement { CelExtensions.strings(), CelOptionalLibrary.INSTANCE); + static final TypeRegistry CONFORMANCE_TYPE_REGISTRY = + TypeRegistry.newBuilder() + .add(dev.cel.expr.conformance.proto2.TestAllTypes.getDescriptor()) + .add(dev.cel.expr.conformance.proto3.TestAllTypes.getDescriptor()) + .build(); + + static final ExtensionRegistry CONFORMANCE_EXTENSION_REGISTRY = + createConformanceExtensionRegistry(); + + private static ExtensionRegistry createConformanceExtensionRegistry() { + ExtensionRegistry extensionRegistry = ExtensionRegistry.newInstance(); + dev.cel.expr.conformance.proto2.TestAllTypesExtensions.registerAllExtensions(extensionRegistry); + return extensionRegistry; + } + private static final CelParser PARSER_WITH_MACROS = CelParserFactory.standardCelParserBuilder() .setOptions(OPTIONS) @@ -106,7 +121,7 @@ private static CelChecker getChecker(SimpleTest test) throws Exception { ImmutableList.Builder decls = ImmutableList.builderWithExpectedSize(test.getTypeEnvCount()); for (dev.cel.expr.Decl decl : test.getTypeEnvList()) { - decls.add(Decl.parseFrom(decl.toByteArray(), DEFAULT_EXTENSION_REGISTRY)); + decls.add(Decl.parseFrom(decl.toByteArray(), CONFORMANCE_EXTENSION_REGISTRY)); } return CelCompilerFactory.standardCelCheckerBuilder() .setOptions(OPTIONS) @@ -127,7 +142,7 @@ private static CelRuntime getRuntime(SimpleTest test, boolean usePlanner) { // CEL-Internal-2 .setOptions(OPTIONS) .addLibraries(CANONICAL_RUNTIME_EXTENSIONS) - .setExtensionRegistry(DEFAULT_EXTENSION_REGISTRY) + .setExtensionRegistry(CONFORMANCE_EXTENSION_REGISTRY) .addMessageTypes(dev.cel.expr.conformance.proto2.TestAllTypes.getDescriptor()) .addMessageTypes(dev.cel.expr.conformance.proto3.TestAllTypes.getDescriptor()) .addFileTypes(dev.cel.expr.conformance.proto2.TestAllTypesExtensions.getDescriptor()); @@ -151,7 +166,8 @@ private static ImmutableMap getBindings(SimpleTest test) throws private static Object fromExprValue(ExprValue value) throws Exception { switch (value.getKindCase()) { case VALUE: - return fromValue(value.getValue()); + return fromValue( + value.getValue(), CONFORMANCE_TYPE_REGISTRY, CONFORMANCE_EXTENSION_REGISTRY); default: throw new IllegalArgumentException( String.format("Unexpected binding value kind: %s", value.getKindCase())); @@ -224,7 +240,7 @@ public void evaluate() throws Throwable { assertThat(result) .ignoringRepeatedFieldOrderOfFieldDescriptors( MapValue.getDescriptor().findFieldByName("entries")) - .unpackingAnyUsing(DEFAULT_TYPE_REGISTRY, DEFAULT_EXTENSION_REGISTRY) + .unpackingAnyUsing(CONFORMANCE_TYPE_REGISTRY, CONFORMANCE_EXTENSION_REGISTRY) .isEqualTo(ExprValue.newBuilder().setValue(test.getValue()).build()); break; case EVAL_ERROR: @@ -237,7 +253,7 @@ public void evaluate() throws Throwable { assertThat(result) .ignoringRepeatedFieldOrderOfFieldDescriptors( MapValue.getDescriptor().findFieldByName("entries")) - .unpackingAnyUsing(DEFAULT_TYPE_REGISTRY, DEFAULT_EXTENSION_REGISTRY) + .unpackingAnyUsing(CONFORMANCE_TYPE_REGISTRY, CONFORMANCE_EXTENSION_REGISTRY) .isEqualTo(ExprValue.newBuilder().setValue(test.getTypedResult().getResult()).build()); assertThat(resultType).isEqualTo(test.getTypedResult().getDeducedType()); break; diff --git a/conformance/src/test/java/dev/cel/conformance/ConformanceTestRunner.java b/conformance/src/test/java/dev/cel/conformance/ConformanceTestRunner.java index dc3d5021e..4c3631d31 100644 --- a/conformance/src/test/java/dev/cel/conformance/ConformanceTestRunner.java +++ b/conformance/src/test/java/dev/cel/conformance/ConformanceTestRunner.java @@ -14,8 +14,7 @@ package dev.cel.conformance; -import static dev.cel.testing.utils.ExprValueUtils.DEFAULT_EXTENSION_REGISTRY; -import static dev.cel.testing.utils.ExprValueUtils.DEFAULT_TYPE_REGISTRY; +import static dev.cel.conformance.ConformanceTest.CONFORMANCE_EXTENSION_REGISTRY; import com.google.common.base.Preconditions; import com.google.common.base.Splitter; @@ -50,14 +49,16 @@ private static ImmutableSortedMap loadTestFiles() { SPLITTER.splitToList(System.getProperty("dev.cel.conformance.ConformanceTests.tests")); try { TextFormat.Parser parser = - TextFormat.Parser.newBuilder().setTypeRegistry(DEFAULT_TYPE_REGISTRY).build(); + TextFormat.Parser.newBuilder() + .setTypeRegistry(ConformanceTest.CONFORMANCE_TYPE_REGISTRY) + .build(); ImmutableSortedMap.Builder testFiles = ImmutableSortedMap.naturalOrder(); for (String testPath : testPaths) { SimpleTestFile.Builder fileBuilder = SimpleTestFile.newBuilder(); try (BufferedReader input = Files.newBufferedReader(Paths.get(testPath), StandardCharsets.UTF_8)) { - parser.merge(input, DEFAULT_EXTENSION_REGISTRY, fileBuilder); + parser.merge(input, CONFORMANCE_EXTENSION_REGISTRY, fileBuilder); } SimpleTestFile testFile = fileBuilder.build(); testFiles.put(testFile.getName(), testFile); diff --git a/testing/src/main/java/dev/cel/testing/testrunner/BUILD.bazel b/testing/src/main/java/dev/cel/testing/testrunner/BUILD.bazel index 6924f753f..5af0665f9 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/BUILD.bazel +++ b/testing/src/main/java/dev/cel/testing/testrunner/BUILD.bazel @@ -92,6 +92,7 @@ java_library( "//bundle:environment", "//bundle:environment_yaml_parser", "//common:cel_ast", + "//common:cel_descriptor_util", "//common:compiler_common", "//common:options", "//common:proto_ast", @@ -134,6 +135,7 @@ java_library( ":cel_test_suite", ":cel_test_suite_exception", "//common:compiler_common", + "//common/annotations", "//common/formats:file_source", "//common/formats:parser_context", "//common/formats:yaml_helper", @@ -163,10 +165,14 @@ java_library( ":result_matcher", "//:auto_value", "//bundle:cel", + "//common:cel_descriptor_util", "//common:options", "//policy:parser", "//runtime", + "//testing/testrunner:proto_descriptor_utils", + "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", + "@maven//:com_google_protobuf_protobuf_java", ], ) @@ -223,6 +229,7 @@ java_library( ":cel_test_suite", ":cel_test_suite_exception", ":registry_utils", + "//common/annotations", "@cel_spec//proto/cel/expr:expr_java_proto", "@cel_spec//proto/cel/expr/conformance/test:suite_java_proto", "@maven//:com_google_guava_guava", diff --git a/testing/src/main/java/dev/cel/testing/testrunner/CelTestContext.java b/testing/src/main/java/dev/cel/testing/testrunner/CelTestContext.java index aa0d4b34f..5635b6152 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/CelTestContext.java +++ b/testing/src/main/java/dev/cel/testing/testrunner/CelTestContext.java @@ -14,12 +14,23 @@ package dev.cel.testing.testrunner; import com.google.auto.value.AutoValue; +import com.google.auto.value.extension.memoized.Memoized; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.Descriptors.FileDescriptor; +import com.google.protobuf.ExtensionRegistry; +import com.google.protobuf.TypeRegistry; import dev.cel.bundle.Cel; import dev.cel.bundle.CelFactory; +import dev.cel.common.CelDescriptorUtil; import dev.cel.common.CelOptions; import dev.cel.policy.CelPolicyParser; import dev.cel.runtime.CelLateFunctionBindings; +import dev.cel.testing.utils.ProtoDescriptorUtils; +import java.io.IOException; +import java.util.Arrays; import java.util.Map; import java.util.Optional; @@ -63,6 +74,19 @@ public abstract class CelTestContext { */ public abstract Optional celLateFunctionBindings(); + /** Interface for transforming bindings before evaluation. */ + @FunctionalInterface + public interface BindingTransformer { + ImmutableMap transform(ImmutableMap bindings) throws Exception; + } + + /** + * The binding transformer for the CEL test. + * + *

This transformer is used to transform the bindings before evaluation. + */ + public abstract Optional bindingTransformer(); + /** * The variable bindings for the CEL test. * @@ -99,6 +123,34 @@ public abstract class CelTestContext { */ public abstract Optional fileDescriptorSetPath(); + abstract ImmutableSet fileTypes(); + + @Memoized + public Optional typeRegistry() { + if (fileTypes().isEmpty() && !fileDescriptorSetPath().isPresent()) { + return Optional.empty(); + } + TypeRegistry.Builder builder = TypeRegistry.newBuilder(); + if (!fileTypes().isEmpty()) { + builder.add( + CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(fileTypes()) + .messageTypeDescriptors()); + } + if (fileDescriptorSetPath().isPresent()) { + try { + builder.add( + ProtoDescriptorUtils.getAllDescriptorsFromJvm(fileDescriptorSetPath().get()) + .messageTypeDescriptors()); + } catch (IOException e) { + throw new IllegalStateException( + "Failed to load descriptors from path: " + fileDescriptorSetPath().get(), e); + } + } + return Optional.of(builder.build()); + } + + public abstract Optional extensionRegistry(); + /** Returns a builder for {@link CelTestContext} with the current instance's values. */ public abstract Builder toBuilder(); @@ -123,6 +175,8 @@ public abstract static class Builder { public abstract Builder setCelLateFunctionBindings( CelLateFunctionBindings celLateFunctionBindings); + public abstract Builder setBindingTransformer(BindingTransformer bindingTransformer); + public abstract Builder setVariableBindings(Map variableBindings); public abstract Builder setResultMatcher(ResultMatcher resultMatcher); @@ -133,6 +187,34 @@ public abstract Builder setCelLateFunctionBindings( public abstract Builder setFileDescriptorSetPath(String fileDescriptorSetPath); + abstract ImmutableSet.Builder fileTypesBuilder(); + + @CanIgnoreReturnValue + public Builder addMessageTypes(Descriptor... descriptors) { + return addMessageTypes(Arrays.asList(descriptors)); + } + + @CanIgnoreReturnValue + public Builder addMessageTypes(Iterable descriptors) { + for (Descriptor descriptor : descriptors) { + addFileTypes(descriptor.getFile()); + } + return this; + } + + @CanIgnoreReturnValue + public Builder addFileTypes(FileDescriptor... fileDescriptors) { + return addFileTypes(Arrays.asList(fileDescriptors)); + } + + @CanIgnoreReturnValue + public Builder addFileTypes(Iterable fileDescriptors) { + fileTypesBuilder().addAll(fileDescriptors); + return this; + } + + public abstract Builder setExtensionRegistry(ExtensionRegistry extensionRegistry); + public abstract CelTestContext build(); } } diff --git a/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuiteTextProtoParser.java b/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuiteTextProtoParser.java index 3819e38d2..5e7e62498 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuiteTextProtoParser.java +++ b/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuiteTextProtoParser.java @@ -22,6 +22,7 @@ import com.google.protobuf.TextFormat; import com.google.protobuf.TextFormat.ParseException; import com.google.protobuf.TypeRegistry; +import dev.cel.common.annotations.Internal; import dev.cel.expr.conformance.test.InputValue; import dev.cel.expr.conformance.test.TestCase; import dev.cel.expr.conformance.test.TestSection; @@ -35,23 +36,40 @@ /** * CelTestSuiteTextProtoParser intakes a textproto document that describes the structure of a CEL * test suite, parses it then creates a {@link CelTestSuite}. + * + *

CEL Library Internals. Do Not Use. */ -final class CelTestSuiteTextProtoParser { +@Internal +public final class CelTestSuiteTextProtoParser { /** Creates a new instance of {@link CelTestSuiteTextProtoParser}. */ - static CelTestSuiteTextProtoParser newInstance() { + public static CelTestSuiteTextProtoParser newInstance() { return new CelTestSuiteTextProtoParser(); } - CelTestSuite parse(String textProto) throws IOException, CelTestSuiteException { - TestSuite testSuite = parseTestSuite(textProto); + public CelTestSuite parse(String textProto) throws IOException, CelTestSuiteException { + return parse( + textProto, TypeRegistry.getEmptyTypeRegistry(), ExtensionRegistry.getEmptyRegistry()); + } + + public CelTestSuite parse(String textProto, TypeRegistry customTypeRegistry) + throws IOException, CelTestSuiteException { + return parse(textProto, customTypeRegistry, ExtensionRegistry.getEmptyRegistry()); + } + + public CelTestSuite parse( + String textProto, TypeRegistry customTypeRegistry, ExtensionRegistry customExtensionRegistry) + throws IOException, CelTestSuiteException { + TestSuite testSuite = parseTestSuite(textProto, customTypeRegistry, customExtensionRegistry); return parseCelTestSuite(testSuite); } - private TestSuite parseTestSuite(String textProto) throws IOException { + private TestSuite parseTestSuite( + String textProto, TypeRegistry customTypeRegistry, ExtensionRegistry customExtensionRegistry) + throws IOException { String fileDescriptorSetPath = System.getProperty("file_descriptor_set_path"); - TypeRegistry typeRegistry = TypeRegistry.getEmptyTypeRegistry(); - ExtensionRegistry extensionRegistry = ExtensionRegistry.getEmptyRegistry(); + TypeRegistry typeRegistry = customTypeRegistry; + ExtensionRegistry extensionRegistry = customExtensionRegistry; if (fileDescriptorSetPath != null) { extensionRegistry = RegistryUtils.getExtensionRegistry(fileDescriptorSetPath); typeRegistry = RegistryUtils.getTypeRegistry(fileDescriptorSetPath); diff --git a/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuiteYamlParser.java b/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuiteYamlParser.java index d1a3d6615..71c4b9231 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuiteYamlParser.java +++ b/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuiteYamlParser.java @@ -25,6 +25,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import dev.cel.common.CelIssue; +import dev.cel.common.annotations.Internal; import dev.cel.common.formats.CelFileSource; import dev.cel.common.formats.ParserContext; import dev.cel.common.formats.YamlHelper.YamlNodeType; @@ -43,15 +44,18 @@ /** * CelTestSuiteYamlParser intakes a YAML document that describes the structure of a CEL test suite, * parses it then creates a {@link CelTestSuite}. + * + *

CEL Library Internals. Do Not Use. */ -final class CelTestSuiteYamlParser { +@Internal +public final class CelTestSuiteYamlParser { /** Creates a new instance of {@link CelTestSuiteYamlParser}. */ - static CelTestSuiteYamlParser newInstance() { + public static CelTestSuiteYamlParser newInstance() { return new CelTestSuiteYamlParser(); } - CelTestSuite parse(String celTestSuiteYamlContent) throws CelTestSuiteException { + public CelTestSuite parse(String celTestSuiteYamlContent) throws CelTestSuiteException { return parseYaml(celTestSuiteYamlContent, ""); } @@ -110,6 +114,7 @@ private CelTestSuite.Builder parseTestSuite(ParserContext ctx, Node node) case "description": builder.setDescription(newString(ctx, valueNode)); break; + case "section": case "sections": builder.setSections(parseSections(ctx, valueNode)); break; diff --git a/testing/src/main/java/dev/cel/testing/testrunner/TestRunnerLibrary.java b/testing/src/main/java/dev/cel/testing/testrunner/TestRunnerLibrary.java index a5e912ccb..742b0cfb5 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/TestRunnerLibrary.java +++ b/testing/src/main/java/dev/cel/testing/testrunner/TestRunnerLibrary.java @@ -31,11 +31,13 @@ import com.google.protobuf.ExtensionRegistry; import com.google.protobuf.Message; import com.google.protobuf.TextFormat; +import com.google.protobuf.TypeRegistry; import dev.cel.bundle.Cel; import dev.cel.bundle.CelEnvironment; import dev.cel.bundle.CelEnvironment.ExtensionConfig; import dev.cel.bundle.CelEnvironmentYamlParser; import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.CelDescriptorUtil; import dev.cel.common.CelOptions; import dev.cel.common.CelProtoAbstractSyntaxTree; import dev.cel.common.CelValidationException; @@ -205,6 +207,16 @@ private static Cel extendCel(CelTestContext celTestContext, CelOptions celOption .build(); } + if (!celTestContext.fileTypes().isEmpty()) { + extendedCel = + extendedCel + .toCelBuilder() + .addMessageTypes( + CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(celTestContext.fileTypes()) + .messageTypeDescriptors()) + .build(); + } + CelEnvironment environment = CelEnvironment.newBuilder().build(); // Extend the cel object with the config file if provided. @@ -302,8 +314,15 @@ private static Object getEvaluationResult( return getEvaluationResultWithMessage( getEvaluatedContextExpr(testCase, celTestContext), program, celCoverageIndex); case BINDINGS: - return getEvaluationResultWithBindings( - getBindings(testCase, celTestContext), program, celCoverageIndex); + ImmutableMap bindings = getBindings(testCase, celTestContext); + if (celTestContext.bindingTransformer().isPresent()) { + try { + bindings = celTestContext.bindingTransformer().get().transform(bindings); + } catch (Exception e) { + throw new CelEvaluationException("Binding transformation failed: " + e.getMessage(), e); + } + } + return getEvaluationResultWithBindings(bindings, program, celCoverageIndex); case NO_INPUT: ImmutableMap.Builder newBindings = ImmutableMap.builder(); for (Map.Entry entry : celTestContext.variableBindings().entrySet()) { @@ -396,10 +415,23 @@ private static Object evaluateInput(Cel cel, String expr) private static Object getValueFromBinding(Object value, CelTestContext celTestContext) throws IOException { if (value instanceof Value) { - if (celTestContext.fileDescriptorSetPath().isPresent()) { - return fromValue((Value) value, celTestContext.fileDescriptorSetPath().get()); + if (celTestContext.typeRegistry().isPresent() + || celTestContext.extensionRegistry().isPresent()) { + if (celTestContext.typeRegistry().isPresent()) { + ExtensionRegistry extensionRegistry = + celTestContext.extensionRegistry().orElse(ExtensionRegistry.getEmptyRegistry()); + return fromValue((Value) value, celTestContext.typeRegistry().get(), extensionRegistry); + } else if (celTestContext.extensionRegistry().isPresent()) { + return fromValue( + (Value) value, + TypeRegistry.newBuilder().build(), + celTestContext.extensionRegistry().get()); + } else if (celTestContext.fileDescriptorSetPath().isPresent()) { + return fromValue((Value) value, celTestContext.fileDescriptorSetPath().get()); + } } - return fromValue((Value) value); + return fromValue( + (Value) value, TypeRegistry.newBuilder().build(), ExtensionRegistry.getEmptyRegistry()); } return value; } diff --git a/testing/src/main/java/dev/cel/testing/utils/ExprValueUtils.java b/testing/src/main/java/dev/cel/testing/utils/ExprValueUtils.java index 041c0f52d..9bccecc95 100644 --- a/testing/src/main/java/dev/cel/testing/utils/ExprValueUtils.java +++ b/testing/src/main/java/dev/cel/testing/utils/ExprValueUtils.java @@ -28,8 +28,6 @@ import com.google.protobuf.Message; import com.google.protobuf.NullValue; import com.google.protobuf.TypeRegistry; -import dev.cel.common.CelDescriptorUtil; -import dev.cel.common.CelDescriptors; import dev.cel.common.internal.DefaultInstanceMessageFactory; import dev.cel.common.internal.ProtoTimeUtils; import dev.cel.common.types.CelType; @@ -55,8 +53,6 @@ public final class ExprValueUtils { private ExprValueUtils() {} - public static final TypeRegistry DEFAULT_TYPE_REGISTRY = newDefaultTypeRegistry(); - public static final ExtensionRegistry DEFAULT_EXTENSION_REGISTRY = newDefaultExtensionRegistry(); /** * Converts a {@link Value} to a Java native object using the given file descriptor set to parse @@ -68,10 +64,9 @@ private ExprValueUtils() {} * @throws IOException If there's an error during conversion. */ public static Object fromValue(Value value, String fileDescriptorSetPath) throws IOException { - if (value.getKindCase().equals(Value.KindCase.OBJECT_VALUE)) { - return parseAny(value.getObjectValue(), fileDescriptorSetPath); - } - return toNativeObject(value); + TypeRegistry typeRegistry = RegistryUtils.getTypeRegistry(fileDescriptorSetPath); + ExtensionRegistry extensionRegistry = RegistryUtils.getExtensionRegistry(fileDescriptorSetPath); + return fromValue(value, typeRegistry, extensionRegistry); } /** @@ -81,19 +76,38 @@ public static Object fromValue(Value value, String fileDescriptorSetPath) throws * @return The converted Java object. * @throws IOException If there's an error during conversion. */ - public static Object fromValue(Value value) throws IOException { + + /** + * Converts a {@link Value} to a Java native object using custom registries. + * + * @param value The {@link Value} to convert. + * @param typeRegistry The type registry to use for object resolution. + * @param extensionRegistry The extension registry to use for object resolution. + * @return The converted Java object. + * @throws IOException If there's an error during conversion. + */ + public static Object fromValue( + Value value, TypeRegistry typeRegistry, ExtensionRegistry extensionRegistry) + throws IOException { if (value.getKindCase().equals(Value.KindCase.OBJECT_VALUE)) { Descriptor descriptor = - DEFAULT_TYPE_REGISTRY.getDescriptorForTypeUrl(value.getObjectValue().getTypeUrl()); + typeRegistry.getDescriptorForTypeUrl(value.getObjectValue().getTypeUrl()); + if (descriptor == null) { + throw new IOException( + "Unknown type, descriptor was not found in registry: " + + value.getObjectValue().getTypeUrl()); + } Message prototype = getDefaultInstance(descriptor); return prototype .getParserForType() - .parseFrom(value.getObjectValue().getValue(), DEFAULT_EXTENSION_REGISTRY); + .parseFrom(value.getObjectValue().getValue(), extensionRegistry); } - return toNativeObject(value); + return toNativeObject(value, typeRegistry, extensionRegistry); } - private static Object toNativeObject(Value value) throws IOException { + private static Object toNativeObject( + Value value, TypeRegistry typeRegistry, ExtensionRegistry extensionRegistry) + throws IOException { switch (value.getKindCase()) { case NULL_VALUE: return dev.cel.common.values.NullValue.NULL_VALUE; @@ -118,7 +132,9 @@ private static Object toNativeObject(Value value) throws IOException { ImmutableMap.Builder builder = ImmutableMap.builderWithExpectedSize(map.getEntriesCount()); for (MapValue.Entry entry : map.getEntriesList()) { - builder.put(fromValue(entry.getKey()), fromValue(entry.getValue())); + builder.put( + fromValue(entry.getKey(), typeRegistry, extensionRegistry), + fromValue(entry.getValue(), typeRegistry, extensionRegistry)); } return builder.buildOrThrow(); } @@ -128,7 +144,7 @@ private static Object toNativeObject(Value value) throws IOException { ImmutableList.Builder builder = ImmutableList.builderWithExpectedSize(list.getValuesCount()); for (Value element : list.getValuesList()) { - builder.add(fromValue(element)); + builder.add(fromValue(element, typeRegistry, extensionRegistry)); } return builder.build(); } @@ -181,7 +197,7 @@ public static Value toValue(Object object, CelType type) throws Exception { if (object instanceof dev.cel.expr.Value) { object = Value.parseFrom( - ((dev.cel.expr.Value) object).toByteArray(), DEFAULT_EXTENSION_REGISTRY); + ((dev.cel.expr.Value) object).toByteArray(), ExtensionRegistry.getEmptyRegistry()); } if (object instanceof Value) { return (Value) object; @@ -287,19 +303,6 @@ public static Value toValue(Object object, CelType type) throws Exception { String.format("Unexpected result type: %s", object.getClass())); } - private static Message parseAny(Any value, String fileDescriptorSetPath) throws IOException { - TypeRegistry typeRegistry = RegistryUtils.getTypeRegistry(fileDescriptorSetPath); - ExtensionRegistry extensionRegistry = RegistryUtils.getExtensionRegistry(fileDescriptorSetPath); - Descriptor descriptor = typeRegistry.getDescriptorForTypeUrl(value.getTypeUrl()); - return unpackAny(value, descriptor, extensionRegistry); - } - - private static Message unpackAny( - Any value, Descriptor descriptor, ExtensionRegistry extensionRegistry) throws IOException { - Message defaultInstance = getDefaultInstance(descriptor); - return defaultInstance.getParserForType().parseFrom(value.getValue(), extensionRegistry); - } - private static Message getDefaultInstance(Descriptor descriptor) { return DefaultInstanceMessageFactory.getInstance() .getPrototype(descriptor) @@ -309,20 +312,6 @@ private static Message getDefaultInstance(Descriptor descriptor) { "Could not find a default message for: " + descriptor.getFullName())); } - private static ExtensionRegistry newDefaultExtensionRegistry() { - ExtensionRegistry extensionRegistry = ExtensionRegistry.newInstance(); - dev.cel.expr.conformance.proto2.TestAllTypesExtensions.registerAllExtensions(extensionRegistry); - return extensionRegistry; - } - private static TypeRegistry newDefaultTypeRegistry() { - CelDescriptors allDescriptors = - CelDescriptorUtil.getAllDescriptorsFromFileDescriptor( - ImmutableList.of( - dev.cel.expr.conformance.proto2.TestAllTypes.getDescriptor().getFile(), - dev.cel.expr.conformance.proto3.TestAllTypes.getDescriptor().getFile())); - - return TypeRegistry.newBuilder().add(allDescriptors.messageTypeDescriptors()).build(); - } } diff --git a/testing/src/test/java/dev/cel/testing/testrunner/TestRunnerLibraryTest.java b/testing/src/test/java/dev/cel/testing/testrunner/TestRunnerLibraryTest.java index d5a5248a8..b83375b35 100644 --- a/testing/src/test/java/dev/cel/testing/testrunner/TestRunnerLibraryTest.java +++ b/testing/src/test/java/dev/cel/testing/testrunner/TestRunnerLibraryTest.java @@ -26,6 +26,7 @@ import dev.cel.common.types.SimpleType; import dev.cel.expr.conformance.proto3.TestAllTypes; import dev.cel.testing.testrunner.CelTestSuite.CelTestSection.CelTestCase; +import java.util.Map; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -281,4 +282,74 @@ public void triggerRunTest_evaluateRawExpr_withCoverage() throws Exception { .build(), celCoverageIndex); } + + @Test + public void runTest_withBindingTransformer() throws Exception { + CelTestCase testCase = + CelTestCase.newBuilder() + .setName("binding_transformer_test") + .setDescription("Test binding transformer") + .setInput( + CelTestCase.Input.ofBindings( + ImmutableMap.of("x", CelTestCase.Input.Binding.ofValue(1L)))) + .setOutput(CelTestCase.Output.ofResultValue(3L)) // 1 + 1 (transformed) + 1 (expr) = 3 + .build(); + + TestRunnerLibrary.evaluateTestCase( + testCase, + CelTestContext.newBuilder() + .setCelExpression(CelExpressionSource.fromRawExpr("x + 1")) + .setCel(CelFactory.standardCelBuilder().addVar("x", SimpleType.INT).build()) + .setBindingTransformer( + bindings -> { + ImmutableMap.Builder transformed = ImmutableMap.builder(); + for (Map.Entry entry : bindings.entrySet()) { + if (entry.getKey().equals("x")) { + transformed.put("x", (Long) entry.getValue() + 1L); + } else { + transformed.put(entry); + } + } + return transformed.buildOrThrow(); + }) + .build()); + } + + @Test + public void runTest_withMessageTypes() throws Exception { + CelTestCase testCase = + CelTestCase.newBuilder() + .setName("message_types_consolidation_test") + .setDescription("Test message types consolidation") + .setOutput(CelTestCase.Output.ofResultValue(true)) + .build(); + + TestRunnerLibrary.evaluateTestCase( + testCase, + CelTestContext.newBuilder() + .setCelExpression( + CelExpressionSource.fromRawExpr( + "cel.expr.conformance.proto3.TestAllTypes{single_int64: 1} ==" + + " cel.expr.conformance.proto3.TestAllTypes{single_int64: 1}")) + .addMessageTypes(TestAllTypes.getDescriptor()) + .build()); + } + + @Test + public void typeRegistry_withFileTypes() throws Exception { + CelTestContext celTestContext = + CelTestContext.newBuilder() + .setCelExpression(CelExpressionSource.fromRawExpr("true")) + .setCel(CelFactory.standardCelBuilder().build()) + .addMessageTypes(TestAllTypes.getDescriptor()) + .build(); + + assertThat( + celTestContext + .typeRegistry() + .get() + .find("cel.expr.conformance.proto3.TestAllTypes") + .getFullName()) + .isEqualTo("cel.expr.conformance.proto3.TestAllTypes"); + } } From 13568dfba2d0dc424b56183579bf0fe114024397 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Fri, 27 Mar 2026 14:10:01 -0700 Subject: [PATCH 12/66] Internal Changes PiperOrigin-RevId: 890628970 --- policy/src/main/java/dev/cel/policy/CelPolicy.java | 13 +++++++++---- .../dev/cel/testing/testrunner/CelTestSuite.java | 5 +++-- .../cel/testing/testrunner/TestRunnerLibrary.java | 7 +++++++ 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/policy/src/main/java/dev/cel/policy/CelPolicy.java b/policy/src/main/java/dev/cel/policy/CelPolicy.java index 9980d0cad..9e442a2e7 100644 --- a/policy/src/main/java/dev/cel/policy/CelPolicy.java +++ b/policy/src/main/java/dev/cel/policy/CelPolicy.java @@ -27,6 +27,7 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -77,8 +78,7 @@ public abstract static class Builder { public abstract Builder setPolicySource(CelPolicySource policySource); - // This should stay package-private to encourage add/set methods to be used instead. - abstract ImmutableMap.Builder metadataBuilder(); + private final HashMap metadata = new HashMap<>(); public abstract Builder setMetadata(ImmutableMap value); @@ -90,6 +90,10 @@ public List imports() { return Collections.unmodifiableList(importList); } + public Map metadata() { + return Collections.unmodifiableMap(metadata); + } + @CanIgnoreReturnValue public Builder addImport(Import value) { importList.add(value); @@ -104,13 +108,13 @@ public Builder addImports(Collection values) { @CanIgnoreReturnValue public Builder putMetadata(String key, Object value) { - metadataBuilder().put(key, value); + metadata.put(key, value); return this; } @CanIgnoreReturnValue public Builder putMetadata(Map map) { - metadataBuilder().putAll(map); + metadata.putAll(map); return this; } @@ -118,6 +122,7 @@ public Builder putMetadata(Map map) { public CelPolicy build() { setImports(ImmutableList.copyOf(importList)); + setMetadata(ImmutableMap.copyOf(metadata)); return autoBuild(); } } diff --git a/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuite.java b/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuite.java index e6086f128..a8869a8fb 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuite.java +++ b/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuite.java @@ -93,7 +93,7 @@ public abstract static class Builder { public abstract Builder toBuilder(); public static Builder newBuilder() { - return new AutoValue_CelTestSuite_CelTestSection.Builder(); + return new AutoValue_CelTestSuite_CelTestSection.Builder().setDescription(""); } /** Class representing a CEL test case within a test section. */ @@ -237,7 +237,8 @@ public abstract static class Builder { public static Builder newBuilder() { return new AutoValue_CelTestSuite_CelTestSection_CelTestCase.Builder() - .setInput(Input.ofNoInput()); // Default input to no input. + .setInput(Input.ofNoInput()) // Default input to no input. + .setDescription(""); } } } diff --git a/testing/src/main/java/dev/cel/testing/testrunner/TestRunnerLibrary.java b/testing/src/main/java/dev/cel/testing/testrunner/TestRunnerLibrary.java index 742b0cfb5..2465d330e 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/TestRunnerLibrary.java +++ b/testing/src/main/java/dev/cel/testing/testrunner/TestRunnerLibrary.java @@ -106,6 +106,13 @@ public static void runTest( } } + /** Runs the test with the provided AST. */ + public static void runTest( + CelAbstractSyntaxTree ast, CelTestCase testCase, CelTestContext celTestContext) + throws Exception { + evaluate(ast, testCase, celTestContext, /* celCoverageIndex= */ null); + } + @VisibleForTesting static void evaluateTestCase(CelTestCase testCase, CelTestContext celTestContext) throws Exception { From 75fafc91fee010ca1a7c719d07df408bad265b88 Mon Sep 17 00:00:00 2001 From: "Philip K. Warren" Date: Fri, 27 Mar 2026 16:20:13 -0500 Subject: [PATCH 13/66] Add string extensions quote and reverse Add `strings.quote` and `reverse` extensions to match Go implementations. --- .../cel/extensions/CelStringExtensions.java | 76 +++++++++++++++++++ .../extensions/CelStringExtensionsTest.java | 54 +++++++++++++ 2 files changed, 130 insertions(+) diff --git a/extensions/src/main/java/dev/cel/extensions/CelStringExtensions.java b/extensions/src/main/java/dev/cel/extensions/CelStringExtensions.java index 10caa7db8..faf30c2b2 100644 --- a/extensions/src/main/java/dev/cel/extensions/CelStringExtensions.java +++ b/extensions/src/main/java/dev/cel/extensions/CelStringExtensions.java @@ -137,6 +137,17 @@ public enum Function { SimpleType.STRING, SimpleType.STRING)), CelFunctionBinding.from("string_lower_ascii", String.class, Ascii::toLowerCase)), + QUOTE( + CelFunctionDecl.newFunctionDeclaration( + "strings.quote", + CelOverloadDecl.newGlobalOverload( + "strings_quote", + "Takes the given string and makes it safe to print (without any formatting" + + " due to escape sequences). If any invalid UTF-8 characters are" + + " encountered, they are replaced with \\uFFFD.", + SimpleType.STRING, + ImmutableList.of(SimpleType.STRING))), + CelFunctionBinding.from("strings_quote", String.class, CelStringExtensions::quote)), REPLACE( CelFunctionDecl.newFunctionDeclaration( "replace", @@ -164,6 +175,16 @@ public enum Function { "string_replace_string_string_int", ImmutableList.of(String.class, String.class, String.class, Long.class), CelStringExtensions::replace)), + REVERSE( + CelFunctionDecl.newFunctionDeclaration( + "reverse", + CelOverloadDecl.newMemberOverload( + "string_reverse", + "Returns a new string whose characters are the same as the target string," + + " only formatted in reverse order.", + SimpleType.STRING, + SimpleType.STRING)), + CelFunctionBinding.from("string_reverse", String.class, CelStringExtensions::reverse)), SPLIT( CelFunctionDecl.newFunctionDeclaration( "split", @@ -449,6 +470,57 @@ private static Long lastIndexOf(CelCodePointArray str, CelCodePointArray substr, return -1L; } + private static String quote(String s) { + StringBuilder sb = new StringBuilder(s.length() + 2); + sb.append('"'); + for (int i = 0; i < s.length(); ) { + int codePoint = s.codePointAt(i); + if (!Character.isValidCodePoint(codePoint) + || Character.isLowSurrogate(s.charAt(i)) + || (Character.isHighSurrogate(s.charAt(i)) + && (i + 1 >= s.length() || !Character.isLowSurrogate(s.charAt(i + 1))))) { + sb.append('\uFFFD'); + i++; + continue; + } + switch (codePoint) { + case '\u0007': + sb.append("\\a"); + break; + case '\b': + sb.append("\\b"); + break; + case '\f': + sb.append("\\f"); + break; + case '\n': + sb.append("\\n"); + break; + case '\r': + sb.append("\\r"); + break; + case '\t': + sb.append("\\t"); + break; + case '\u000B': + sb.append("\\v"); + break; + case '\\': + sb.append("\\\\"); + break; + case '"': + sb.append("\\\""); + break; + default: + sb.appendCodePoint(codePoint); + break; + } + i += Character.charCount(codePoint); + } + sb.append('"'); + return sb.toString(); + } + private static String replaceAll(Object[] objects) { return replace((String) objects[0], (String) objects[1], (String) objects[2], -1); } @@ -504,6 +576,10 @@ private static String replace(String text, String searchString, String replaceme return sb.append(textCpa.slice(start, textCpa.length())).toString(); } + private static String reverse(String s) { + return new StringBuilder(s).reverse().toString(); + } + private static List split(String str, String separator) { return split(str, separator, Integer.MAX_VALUE); } diff --git a/extensions/src/test/java/dev/cel/extensions/CelStringExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelStringExtensionsTest.java index 6ea9b702c..ad0d6d679 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelStringExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelStringExtensionsTest.java @@ -70,7 +70,9 @@ public void library() { "lastIndexOf", "lowerAscii", "replace", + "reverse", "split", + "strings.quote", "substring", "trim", "upperAscii"); @@ -1467,6 +1469,58 @@ public void stringExtension_functionSubset_success() throws Exception { assertThat(evaluatedResult).isEqualTo(true); } + @Test + @TestParameters("{string: 'abcd', expectedResult: 'dcba'}") + @TestParameters("{string: '', expectedResult: ''}") + @TestParameters("{string: 'a', expectedResult: 'a'}") + @TestParameters("{string: 'hello world', expectedResult: 'dlrow olleh'}") + @TestParameters("{string: 'ab가cd', expectedResult: 'dc가ba'}") + public void reverse_success(String string, String expectedResult) throws Exception { + CelAbstractSyntaxTree ast = COMPILER.compile("s.reverse()").getAst(); + CelRuntime.Program program = RUNTIME.createProgram(ast); + + Object evaluatedResult = program.eval(ImmutableMap.of("s", string)); + + assertThat(evaluatedResult).isEqualTo(expectedResult); + } + + @Test + public void reverse_unicode() throws Exception { + CelAbstractSyntaxTree ast = COMPILER.compile("s.reverse()").getAst(); + CelRuntime.Program program = RUNTIME.createProgram(ast); + + Object evaluatedResult = program.eval(ImmutableMap.of("s", "😁😑😦")); + + assertThat(evaluatedResult).isEqualTo("😦😑😁"); + } + + @Test + @TestParameters("{string: 'hello', expectedResult: '\"hello\"'}") + @TestParameters("{string: '', expectedResult: '\"\"'}") + @TestParameters("{string: 'contains \\\"quotes\\\"', expectedResult: '\"contains \\\\\\\"quotes\\\\\\\"\"'}") + public void quote_success(String string, String expectedResult) throws Exception { + CelAbstractSyntaxTree ast = COMPILER.compile("strings.quote(s)").getAst(); + CelRuntime.Program program = RUNTIME.createProgram(ast); + + Object evaluatedResult = program.eval(ImmutableMap.of("s", string)); + + assertThat(evaluatedResult).isEqualTo(expectedResult); + } + + @Test + public void quote_escapesSpecialCharacters() throws Exception { + CelAbstractSyntaxTree ast = COMPILER.compile("strings.quote(s)").getAst(); + CelRuntime.Program program = RUNTIME.createProgram(ast); + + Object evaluatedResult = + program.eval( + ImmutableMap.of( + "s", "\u0007bell\u000Bvtab\bback\ffeed\rret\nline\ttab\\slash 가 😁")); + + assertThat(evaluatedResult) + .isEqualTo("\"\\abell\\vvtab\\bback\\ffeed\\rret\\nline\\ttab\\\\slash 가 😁\""); + } + @Test public void stringExtension_compileUnallowedFunction_throws() { CelCompiler celCompiler = From 6207811efff2d3b1d3dd5f0528e67906a90a955c Mon Sep 17 00:00:00 2001 From: "Philip K. Warren" Date: Fri, 27 Mar 2026 17:48:43 -0500 Subject: [PATCH 14/66] Enable strings.quote conformance tests --- conformance/src/test/java/dev/cel/conformance/BUILD.bazel | 2 -- 1 file changed, 2 deletions(-) diff --git a/conformance/src/test/java/dev/cel/conformance/BUILD.bazel b/conformance/src/test/java/dev/cel/conformance/BUILD.bazel index fb2b1a159..ea9041433 100644 --- a/conformance/src/test/java/dev/cel/conformance/BUILD.bazel +++ b/conformance/src/test/java/dev/cel/conformance/BUILD.bazel @@ -120,7 +120,6 @@ _TESTS_TO_SKIP_LEGACY = [ # Skip until fixed. "fields/qualified_identifier_resolution/map_value_repeat_key_heterogeneous", # TODO: Add strings.format and strings.quote. - "string_ext/quote", "string_ext/format", "string_ext/format_errors", @@ -149,7 +148,6 @@ _TESTS_TO_SKIP_LEGACY = [ _TESTS_TO_SKIP_PLANNER = [ # TODO: Add strings.format and strings.quote. - "string_ext/quote", "string_ext/format", "string_ext/format_errors", From 0e1a30f5eb31d0f1a313a755d253ca0f1a3389d8 Mon Sep 17 00:00:00 2001 From: "Philip K. Warren" Date: Fri, 27 Mar 2026 17:55:53 -0500 Subject: [PATCH 15/66] Fix comments in bazel TESTS_TO_SKIP --- conformance/src/test/java/dev/cel/conformance/BUILD.bazel | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/conformance/src/test/java/dev/cel/conformance/BUILD.bazel b/conformance/src/test/java/dev/cel/conformance/BUILD.bazel index ea9041433..c0e7ad2bc 100644 --- a/conformance/src/test/java/dev/cel/conformance/BUILD.bazel +++ b/conformance/src/test/java/dev/cel/conformance/BUILD.bazel @@ -119,7 +119,7 @@ _TESTS_TO_SKIP_LEGACY = [ # Skip until fixed. "fields/qualified_identifier_resolution/map_value_repeat_key_heterogeneous", - # TODO: Add strings.format and strings.quote. + # TODO: Add strings.format. "string_ext/format", "string_ext/format_errors", @@ -147,7 +147,7 @@ _TESTS_TO_SKIP_LEGACY = [ ] _TESTS_TO_SKIP_PLANNER = [ - # TODO: Add strings.format and strings.quote. + # TODO: Add strings.format. "string_ext/format", "string_ext/format_errors", From 57e795ad58d77dd462864a8e322fecd4a56514be Mon Sep 17 00:00:00 2001 From: "Philip K. Warren" Date: Fri, 27 Mar 2026 18:13:28 -0500 Subject: [PATCH 16/66] Attempt to fix failing test --- .../src/test/java/dev/cel/extensions/CelExtensionsTest.java | 1 + 1 file changed, 1 insertion(+) diff --git a/extensions/src/test/java/dev/cel/extensions/CelExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelExtensionsTest.java index 61922f70f..192630ea3 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelExtensionsTest.java @@ -168,6 +168,7 @@ public void getAllFunctionNames() { "join", "lastIndexOf", "lowerAscii", + "strings.quote", "replace", "split", "substring", From 943ae333723ed1ffe05975cdd4437d1b35a546fe Mon Sep 17 00:00:00 2001 From: "Philip K. Warren" Date: Sun, 29 Mar 2026 10:43:53 -0500 Subject: [PATCH 17/66] Add additional tests and address review feedback --- .../cel/extensions/CelStringExtensions.java | 18 +++++++-- .../extensions/CelStringExtensionsTest.java | 37 +++++++++++++++++-- 2 files changed, 48 insertions(+), 7 deletions(-) diff --git a/extensions/src/main/java/dev/cel/extensions/CelStringExtensions.java b/extensions/src/main/java/dev/cel/extensions/CelStringExtensions.java index faf30c2b2..e89b81071 100644 --- a/extensions/src/main/java/dev/cel/extensions/CelStringExtensions.java +++ b/extensions/src/main/java/dev/cel/extensions/CelStringExtensions.java @@ -475,10 +475,7 @@ private static String quote(String s) { sb.append('"'); for (int i = 0; i < s.length(); ) { int codePoint = s.codePointAt(i); - if (!Character.isValidCodePoint(codePoint) - || Character.isLowSurrogate(s.charAt(i)) - || (Character.isHighSurrogate(s.charAt(i)) - && (i + 1 >= s.length() || !Character.isLowSurrogate(s.charAt(i + 1))))) { + if (isMalformedUtf16(s, i, codePoint)) { sb.append('\uFFFD'); i++; continue; @@ -521,6 +518,19 @@ private static String quote(String s) { return sb.toString(); } + private static boolean isMalformedUtf16(String s, int index, int codePoint) { + char currentChar = s.charAt(index); + if (!Character.isValidCodePoint(codePoint)) { + return true; + } + if (Character.isLowSurrogate(currentChar)) { + return true; + } + // Check for unpaired high surrogate + return Character.isHighSurrogate(currentChar) + && (index + 1 >= s.length() || !Character.isLowSurrogate(s.charAt(index + 1))); + } + private static String replaceAll(Object[] objects) { return replace((String) objects[0], (String) objects[1], (String) objects[2], -1); } diff --git a/extensions/src/test/java/dev/cel/extensions/CelStringExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelStringExtensionsTest.java index ad0d6d679..58a1bff99 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelStringExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelStringExtensionsTest.java @@ -33,6 +33,8 @@ import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.CelRuntime; import dev.cel.runtime.CelRuntimeFactory; + +import java.nio.charset.StandardCharsets; import java.util.List; import org.junit.Test; import org.junit.runner.RunWith; @@ -1485,19 +1487,23 @@ public void reverse_success(String string, String expectedResult) throws Excepti } @Test - public void reverse_unicode() throws Exception { + @TestParameters("{string: '😁😑😦', expectedResult: '😦😑😁'}") + @TestParameters("{string: '\u180e\u200b\u200c\u200d\u2060\ufeff', expectedResult: '\ufeff\u2060\u200d\u200c\u200b\u180e'}") + public void reverse_unicode(String string, String expectedResult) throws Exception { CelAbstractSyntaxTree ast = COMPILER.compile("s.reverse()").getAst(); CelRuntime.Program program = RUNTIME.createProgram(ast); - Object evaluatedResult = program.eval(ImmutableMap.of("s", "😁😑😦")); + Object evaluatedResult = program.eval(ImmutableMap.of("s", string)); - assertThat(evaluatedResult).isEqualTo("😦😑😁"); + assertThat(evaluatedResult).isEqualTo(expectedResult); } @Test @TestParameters("{string: 'hello', expectedResult: '\"hello\"'}") @TestParameters("{string: '', expectedResult: '\"\"'}") @TestParameters("{string: 'contains \\\"quotes\\\"', expectedResult: '\"contains \\\\\\\"quotes\\\\\\\"\"'}") + @TestParameters("{string: 'ends with \\\\', expectedResult: '\"ends with \\\\\\\\\"'}") + @TestParameters("{string: '\\\\ starts with', expectedResult: '\"\\\\\\\\ starts with\"'}") public void quote_success(String string, String expectedResult) throws Exception { CelAbstractSyntaxTree ast = COMPILER.compile("strings.quote(s)").getAst(); CelRuntime.Program program = RUNTIME.createProgram(ast); @@ -1507,6 +1513,18 @@ public void quote_success(String string, String expectedResult) throws Exception assertThat(evaluatedResult).isEqualTo(expectedResult); } + @Test + public void quote_singleWithDoubleQuotes() throws Exception { + CelAbstractSyntaxTree ast = COMPILER.compile( + "strings.quote('single-quote with \"double quote\"') == \"\\\"single-quote with \\\\\\\"double quote\\\\\\\"\\\"\"" + ).getAst(); + CelRuntime.Program program = RUNTIME.createProgram(ast); + + Object evaluatedResult = program.eval(); + + assertThat(evaluatedResult).isEqualTo(true); + } + @Test public void quote_escapesSpecialCharacters() throws Exception { CelAbstractSyntaxTree ast = COMPILER.compile("strings.quote(s)").getAst(); @@ -1521,6 +1539,19 @@ public void quote_escapesSpecialCharacters() throws Exception { .isEqualTo("\"\\abell\\vvtab\\bback\\ffeed\\rret\\nline\\ttab\\\\slash 가 😁\""); } + @Test + public void quote_escapesMalformed() throws Exception { + CelAbstractSyntaxTree ast = COMPILER.compile("strings.quote(s)").getAst(); + CelRuntime.Program program = RUNTIME.createProgram(ast); + + Object evaluatedResult = + program.eval( + ImmutableMap.of( + "s", new String(new byte[]{'f','i','l','l','e','r',' ',(byte)0x9f}, StandardCharsets.UTF_8))); + + assertThat(evaluatedResult).isEqualTo("\"filler \uFFFD\""); + } + @Test public void stringExtension_compileUnallowedFunction_throws() { CelCompiler celCompiler = From 03d0a3621f721e05bbd5eecf6234a588246a397f Mon Sep 17 00:00:00 2001 From: "Philip K. Warren" Date: Sun, 29 Mar 2026 11:57:43 -0500 Subject: [PATCH 18/66] Additional malformed unicode tests --- .../extensions/CelStringExtensionsTest.java | 24 ++++++++++++++----- 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/extensions/src/test/java/dev/cel/extensions/CelStringExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelStringExtensionsTest.java index 58a1bff99..0a3c595be 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelStringExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelStringExtensionsTest.java @@ -1540,16 +1540,28 @@ public void quote_escapesSpecialCharacters() throws Exception { } @Test - public void quote_escapesMalformed() throws Exception { + @TestParameters({"{rawString: !!binary 'ZmlsbGVyIJ8=', expectedResult: '\"filler \uFFFD\"'}"}) // "filler \x9f" + public void quote_escapesMalformed(byte[] rawString, String expectedResult) throws Exception { CelAbstractSyntaxTree ast = COMPILER.compile("strings.quote(s)").getAst(); CelRuntime.Program program = RUNTIME.createProgram(ast); - Object evaluatedResult = - program.eval( - ImmutableMap.of( - "s", new String(new byte[]{'f','i','l','l','e','r',' ',(byte)0x9f}, StandardCharsets.UTF_8))); + Object evaluatedResult = program.eval(ImmutableMap.of("s", new String(rawString, StandardCharsets.UTF_8))); + + assertThat(evaluatedResult).isEqualTo(expectedResult); + } - assertThat(evaluatedResult).isEqualTo("\"filler \uFFFD\""); + @Test + public void quote_escapesMalformed_endWithHighSurrogate() throws Exception { + CelRuntime.Program program = RUNTIME.createProgram(COMPILER.compile("strings.quote(s)").getAst()); + assertThat(program.eval(ImmutableMap.of("s", "end with high surrogate \uD83D"))) + .isEqualTo("\"end with high surrogate \uFFFD\""); + } + + @Test + public void quote_escapesMalformed_unpairedHighSurrogate() throws Exception { + CelRuntime.Program program = RUNTIME.createProgram(COMPILER.compile("strings.quote(s)").getAst()); + assertThat(program.eval(ImmutableMap.of("s", "bad pair \uD83DA"))) + .isEqualTo("\"bad pair \uFFFDA\""); } @Test From 79ff206b89f6b54767e4dd84f648b049d560ea89 Mon Sep 17 00:00:00 2001 From: "Philip K. Warren" Date: Sun, 29 Mar 2026 13:53:13 -0500 Subject: [PATCH 19/66] Add quote and reverse functions to docs --- .../main/java/dev/cel/extensions/README.md | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/extensions/src/main/java/dev/cel/extensions/README.md b/extensions/src/main/java/dev/cel/extensions/README.md index 10c5217e8..c3fbf8c54 100644 --- a/extensions/src/main/java/dev/cel/extensions/README.md +++ b/extensions/src/main/java/dev/cel/extensions/README.md @@ -474,6 +474,19 @@ Examples: 'TacoCat'.lowerAscii() // returns 'tacocat' 'TacoCÆt Xii'.lowerAscii() // returns 'tacocÆt xii' +### Quote + +Takes the given string and makes it safe to print (without any formatting due +to escape sequences). +If any invalid UTF-8 characters are encountered, they are replaced with \uFFFD. + + strings.quote() + +Examples: + + strings.quote('single-quote with "double quote"') // returns '"single-quote with \"double quote\""' + strings.quote("two escape sequences \a\n") // returns '"two escape sequences \\a\\n"' + ### Replace Returns a new string based on the target, which replaces the occurrences of a @@ -493,6 +506,20 @@ Examples: 'hello hello'.replace('he', 'we', 1) // returns 'wello hello' 'hello hello'.replace('he', 'we', 0) // returns 'hello hello' +### Reverse + +Returns a new string whose characters are the same as the target string, only +formatted in reverse order. +This function relies on converting strings to Unicode code point arrays in +order to reverse. + + .reverse() -> + +Examples: + + 'gums'.reverse() // returns 'smug' + 'John Smith'.reverse() // returns 'htimS nhoJ' + ### Split Returns a mutable list of strings split from the input by the given separator. The From 576064d30bf241a96fdacda3660d5067d6da8f92 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Mon, 30 Mar 2026 10:53:04 -0700 Subject: [PATCH 20/66] Enable quoted identifiers by default PiperOrigin-RevId: 891799347 --- common/src/main/java/dev/cel/common/CelOptions.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/src/main/java/dev/cel/common/CelOptions.java b/common/src/main/java/dev/cel/common/CelOptions.java index 0c348c8a8..c39e0fea8 100644 --- a/common/src/main/java/dev/cel/common/CelOptions.java +++ b/common/src/main/java/dev/cel/common/CelOptions.java @@ -139,7 +139,7 @@ public static Builder newBuilder() { .retainRepeatedUnaryOperators(false) .retainUnbalancedLogicalExpressions(false) .enableHiddenAccumulatorVar(true) - .enableQuotedIdentifierSyntax(false) + .enableQuotedIdentifierSyntax(true) // Type-Checker options .enableCompileTimeOverloadResolution(false) .enableHomogeneousLiterals(false) From aaec509f18bcf65193188a0791d6c8980ef53316 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Mon, 30 Mar 2026 15:13:03 -0700 Subject: [PATCH 21/66] Switch from enhanced for-loop to indexed one to improve comprehension performance PiperOrigin-RevId: 891934916 --- .../dev/cel/runtime/planner/NamespacedAttribute.java | 10 ++++++---- .../dev/cel/runtime/planner/RelativeAttribute.java | 5 +++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/runtime/src/main/java/dev/cel/runtime/planner/NamespacedAttribute.java b/runtime/src/main/java/dev/cel/runtime/planner/NamespacedAttribute.java index ed37eada1..d51336d80 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/NamespacedAttribute.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/NamespacedAttribute.java @@ -77,8 +77,9 @@ public Object resolve(long exprId, GlobalResolver ctx, ExecutionFrame frame) { if (partialVars != null) { ImmutableList patterns = partialVars.unknowns(); - for (Qualifier qualifier : qualifiers) { - attr = attr.qualify(CelAttribute.Qualifier.fromGeneric(qualifier.value())); + // Avoid enhanced for loop to prevent UnmodifiableIterator from being allocated + for (int i = 0; i < qualifiers.size(); i++) { + attr = attr.qualify(CelAttribute.Qualifier.fromGeneric(qualifiers.get(i).value())); } CelAttributePattern partialMatch = findPartialMatchingPattern(attr, patterns).orElse(null); @@ -178,8 +179,9 @@ private static Object applyQualifiers( Object value, CelValueConverter celValueConverter, ImmutableList qualifiers) { Object obj = celValueConverter.toRuntimeValue(value); - for (Qualifier qualifier : qualifiers) { - obj = qualifier.qualify(obj); + // Avoid enhanced for loop to prevent UnmodifiableIterator from being allocated + for (int i = 0; i < qualifiers.size(); i++) { + obj = qualifiers.get(i).qualify(obj); } return celValueConverter.maybeUnwrap(obj); diff --git a/runtime/src/main/java/dev/cel/runtime/planner/RelativeAttribute.java b/runtime/src/main/java/dev/cel/runtime/planner/RelativeAttribute.java index 1ab2fa3e7..addbeb4d0 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/RelativeAttribute.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/RelativeAttribute.java @@ -40,8 +40,9 @@ public Object resolve(long exprId, GlobalResolver ctx, ExecutionFrame frame) { obj = celValueConverter.toRuntimeValue(obj); - for (Qualifier qualifier : qualifiers) { - obj = qualifier.qualify(obj); + // Avoid enhanced for loop to prevent UnmodifiableIterator from being allocated + for (int i = 0; i < qualifiers.size(); i++) { + obj = qualifiers.get(i).qualify(obj); } return celValueConverter.maybeUnwrap(obj); From 6ca2c4d48013bbe7d96fe5e72657f0e0330b7179 Mon Sep 17 00:00:00 2001 From: "Philip K. Warren" Date: Mon, 30 Mar 2026 21:30:22 -0500 Subject: [PATCH 22/66] Remove test and fix line length --- .../cel/extensions/CelStringExtensionsTest.java | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/extensions/src/test/java/dev/cel/extensions/CelStringExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelStringExtensionsTest.java index 0a3c595be..27152191b 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelStringExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelStringExtensionsTest.java @@ -1515,9 +1515,9 @@ public void quote_success(String string, String expectedResult) throws Exception @Test public void quote_singleWithDoubleQuotes() throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile( - "strings.quote('single-quote with \"double quote\"') == \"\\\"single-quote with \\\\\\\"double quote\\\\\\\"\\\"\"" - ).getAst(); + String expr = "strings.quote('single-quote with \"double quote\"')"; + String expected = "\"\\\"single-quote with \\\\\\\"double quote\\\\\\\"\\\"\""; + CelAbstractSyntaxTree ast = COMPILER.compile(expr + " == " + expected).getAst(); CelRuntime.Program program = RUNTIME.createProgram(ast); Object evaluatedResult = program.eval(); @@ -1539,17 +1539,6 @@ public void quote_escapesSpecialCharacters() throws Exception { .isEqualTo("\"\\abell\\vvtab\\bback\\ffeed\\rret\\nline\\ttab\\\\slash 가 😁\""); } - @Test - @TestParameters({"{rawString: !!binary 'ZmlsbGVyIJ8=', expectedResult: '\"filler \uFFFD\"'}"}) // "filler \x9f" - public void quote_escapesMalformed(byte[] rawString, String expectedResult) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("strings.quote(s)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object evaluatedResult = program.eval(ImmutableMap.of("s", new String(rawString, StandardCharsets.UTF_8))); - - assertThat(evaluatedResult).isEqualTo(expectedResult); - } - @Test public void quote_escapesMalformed_endWithHighSurrogate() throws Exception { CelRuntime.Program program = RUNTIME.createProgram(COMPILER.compile("strings.quote(s)").getAst()); From cbe0104bac6fd115bdc46f06faca6fdf5462d225 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Tue, 31 Mar 2026 13:26:56 -0700 Subject: [PATCH 23/66] Optimize unary and binary function calls to avoid array allocation PiperOrigin-RevId: 892511211 --- .../dev/cel/runtime/CelFunctionBinding.java | 29 ++++++- .../dev/cel/runtime/CelFunctionOverload.java | 61 +++++++++++---- .../dev/cel/runtime/DefaultDispatcher.java | 66 +++++++++------- .../dev/cel/runtime/FunctionBindingImpl.java | 29 +++++++ .../java/dev/cel/runtime/planner/BUILD.bazel | 16 ++++ .../dev/cel/runtime/planner/EvalBinary.java | 75 +++++++++++++++++++ .../dev/cel/runtime/planner/EvalHelpers.java | 45 ++++++++--- .../dev/cel/runtime/planner/EvalUnary.java | 4 +- .../cel/runtime/planner/ProgramPlanner.java | 3 + 9 files changed, 271 insertions(+), 57 deletions(-) create mode 100644 runtime/src/main/java/dev/cel/runtime/planner/EvalBinary.java diff --git a/runtime/src/main/java/dev/cel/runtime/CelFunctionBinding.java b/runtime/src/main/java/dev/cel/runtime/CelFunctionBinding.java index c7b63926b..06e5facdf 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelFunctionBinding.java +++ b/runtime/src/main/java/dev/cel/runtime/CelFunctionBinding.java @@ -54,7 +54,20 @@ public interface CelFunctionBinding { @SuppressWarnings("unchecked") static CelFunctionBinding from( String overloadId, Class arg, CelFunctionOverload.Unary impl) { - return from(overloadId, ImmutableList.of(arg), (args) -> impl.apply((T) args[0])); + return from( + overloadId, + ImmutableList.of(arg), + new CelFunctionOverload() { + @Override + public Object apply(Object[] args) throws CelEvaluationException { + return impl.apply((T) args[0]); + } + + @Override + public Object apply(Object arg1) throws CelEvaluationException { + return impl.apply((T) arg1); + } + }); } /** @@ -65,7 +78,19 @@ static CelFunctionBinding from( static CelFunctionBinding from( String overloadId, Class arg1, Class arg2, CelFunctionOverload.Binary impl) { return from( - overloadId, ImmutableList.of(arg1, arg2), (args) -> impl.apply((T1) args[0], (T2) args[1])); + overloadId, + ImmutableList.of(arg1, arg2), + new CelFunctionOverload() { + @Override + public Object apply(Object[] args) throws CelEvaluationException { + return impl.apply((T1) args[0], (T2) args[1]); + } + + @Override + public Object apply(Object arg1, Object arg2) throws CelEvaluationException { + return impl.apply((T1) arg1, (T2) arg2); + } + }); } /** Create a function binding from the {@code overloadId}, {@code argTypes}, and {@code impl}. */ diff --git a/runtime/src/main/java/dev/cel/runtime/CelFunctionOverload.java b/runtime/src/main/java/dev/cel/runtime/CelFunctionOverload.java index 3e30a2146..e1bdbf886 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelFunctionOverload.java +++ b/runtime/src/main/java/dev/cel/runtime/CelFunctionOverload.java @@ -26,6 +26,16 @@ public interface CelFunctionOverload { /** Evaluate a set of arguments throwing a {@code CelException} on error. */ Object apply(Object[] args) throws CelEvaluationException; + /** Fast-path for unary function execution to avoid Object[] allocation. */ + default Object apply(Object arg) throws CelEvaluationException { + return apply(new Object[] {arg}); + } + + /** Fast-path for binary function execution to avoid Object[] allocation. */ + default Object apply(Object arg1, Object arg2) throws CelEvaluationException { + return apply(new Object[] {arg1, arg2}); + } + /** * Helper interface for describing unary functions where the type-parameter is used to improve * compile-time correctness of function bindings. @@ -57,27 +67,46 @@ static boolean canHandle( for (int i = 0; i < parameterTypes.size(); i++) { Class paramType = parameterTypes.get(i); Object arg = arguments[i]; - if (arg == null) { - // null can be assigned to messages, maps, and to objects. - // TODO: Remove null special casing - if (paramType != Object.class && !Map.class.isAssignableFrom(paramType)) { - return false; - } - continue; + boolean result = canHandleArg(arg, paramType, isStrict); + if (!result) { + return false; } + } + return true; + } - if (arg instanceof Exception || arg instanceof CelUnknownSet) { - // Only non-strict functions can accept errors/unknowns as arguments to a function - if (!isStrict) { - // Skip assignability check below, but continue to validate remaining args - continue; - } - } + static boolean canHandle(Object arg, ImmutableList> parameterTypes, boolean isStrict) { + if (parameterTypes.size() != 1) { + return false; + } + return canHandleArg(arg, parameterTypes.get(0), isStrict); + } + + static boolean canHandle( + Object arg1, Object arg2, ImmutableList> parameterTypes, boolean isStrict) { + if (parameterTypes.size() != 2) { + return false; + } + return canHandleArg(arg1, parameterTypes.get(0), isStrict) + && canHandleArg(arg2, parameterTypes.get(1), isStrict); + } - if (!paramType.isAssignableFrom(arg.getClass())) { + static boolean canHandleArg(Object arg, Class paramType, boolean isStrict) { + // null can be assigned to messages, maps, and to objects. + // TODO: Remove null special casing + if (arg == null) { + if (paramType != Object.class && !Map.class.isAssignableFrom(paramType)) { return false; } + return true; } - return true; + + if (arg instanceof Exception || arg instanceof CelUnknownSet) { + if (!isStrict) { + return true; + } + } + + return paramType.isAssignableFrom(arg.getClass()); } } diff --git a/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java b/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java index 35e3b76a3..87cb07945 100644 --- a/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java +++ b/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java @@ -200,19 +200,48 @@ public DefaultDispatcher build() { for (Map.Entry entry : overloads.entrySet()) { String overloadId = entry.getKey(); OverloadEntry overloadEntry = entry.getValue(); + CelFunctionOverload overloadImpl = overloadEntry.overload(); + + CelFunctionOverload guardedApply; + if (overloadImpl instanceof DynamicDispatchOverload) { + // Dynamic dispatcher already does its own internal canHandle checks + guardedApply = overloadImpl; + } else { + boolean isStrict = overloadEntry.isStrict(); + ImmutableList> argTypes = overloadEntry.argTypes(); + + guardedApply = + new CelFunctionOverload() { + @Override + public Object apply(Object[] args) throws CelEvaluationException { + if (CelFunctionOverload.canHandle(args, argTypes, isStrict)) { + return overloadImpl.apply(args); + } + throw new CelOverloadNotFoundException(overloadId); + } + + @Override + public Object apply(Object arg) throws CelEvaluationException { + if (CelFunctionOverload.canHandle(arg, argTypes, isStrict)) { + return overloadImpl.apply(arg); + } + throw new CelOverloadNotFoundException(overloadId); + } + + @Override + public Object apply(Object arg1, Object arg2) throws CelEvaluationException { + if (CelFunctionOverload.canHandle(arg1, arg2, argTypes, isStrict)) { + return overloadImpl.apply(arg1, arg2); + } + throw new CelOverloadNotFoundException(overloadId); + } + }; + } + resolvedOverloads.put( overloadId, CelResolvedOverload.of( - overloadId, - args -> - guardedOp( - overloadId, - args, - overloadEntry.argTypes(), - overloadEntry.isStrict(), - overloadEntry.overload()), - overloadEntry.isStrict(), - overloadEntry.argTypes())); + overloadId, guardedApply, overloadEntry.isStrict(), overloadEntry.argTypes())); } return new DefaultDispatcher(resolvedOverloads.buildOrThrow()); @@ -223,23 +252,6 @@ private Builder() { } } - /** Creates an invocation guard around the overload definition. */ - private static Object guardedOp( - String functionName, - Object[] args, - ImmutableList> argTypes, - boolean isStrict, - CelFunctionOverload overload) - throws CelEvaluationException { - // Argument checking for DynamicDispatch is handled inside the overload's apply method itself. - if (overload instanceof DynamicDispatchOverload - || CelFunctionOverload.canHandle(args, argTypes, isStrict)) { - return overload.apply(args); - } - - throw new CelOverloadNotFoundException(functionName); - } - DefaultDispatcher(ImmutableMap overloads) { this.overloads = overloads; } diff --git a/runtime/src/main/java/dev/cel/runtime/FunctionBindingImpl.java b/runtime/src/main/java/dev/cel/runtime/FunctionBindingImpl.java index faea853f8..1f47f1dfd 100644 --- a/runtime/src/main/java/dev/cel/runtime/FunctionBindingImpl.java +++ b/runtime/src/main/java/dev/cel/runtime/FunctionBindingImpl.java @@ -145,6 +145,35 @@ public Object apply(Object[] args) throws CelEvaluationException { .collect(toImmutableList())); } + @Override + public Object apply(Object arg) throws CelEvaluationException { + for (CelFunctionBinding overload : overloadBindings) { + if (CelFunctionOverload.canHandle(arg, overload.getArgTypes(), overload.isStrict())) { + return overload.getDefinition().apply(arg); + } + } + throw new CelOverloadNotFoundException( + functionName, + overloadBindings.stream() + .map(CelFunctionBinding::getOverloadId) + .collect(toImmutableList())); + } + + @Override + public Object apply(Object arg1, Object arg2) throws CelEvaluationException { + for (CelFunctionBinding overload : overloadBindings) { + if (CelFunctionOverload.canHandle( + arg1, arg2, overload.getArgTypes(), overload.isStrict())) { + return overload.getDefinition().apply(arg1, arg2); + } + } + throw new CelOverloadNotFoundException( + functionName, + overloadBindings.stream() + .map(CelFunctionBinding::getOverloadId) + .collect(toImmutableList())); + } + ImmutableSet getOverloadBindings() { return overloadBindings; } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel b/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel index 3c18b192f..fc70118e4 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel +++ b/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel @@ -17,6 +17,7 @@ java_library( ":error_metadata", ":eval_and", ":eval_attribute", + ":eval_binary", ":eval_conditional", ":eval_const", ":eval_create_list", @@ -232,6 +233,21 @@ java_library( ], ) +java_library( + name = "eval_binary", + srcs = ["EvalBinary.java"], + deps = [ + ":eval_helpers", + ":execution_frame", + ":planned_interpretable", + "//common/values", + "//runtime:accumulated_unknowns", + "//runtime:evaluation_exception", + "//runtime:interpretable", + "//runtime:resolved_overload", + ], +) + java_library( name = "eval_var_args_call", srcs = ["EvalVarArgsCall.java"], diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalBinary.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalBinary.java new file mode 100644 index 000000000..7771da3e6 --- /dev/null +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalBinary.java @@ -0,0 +1,75 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.runtime.planner; + +import static dev.cel.runtime.planner.EvalHelpers.evalNonstrictly; +import static dev.cel.runtime.planner.EvalHelpers.evalStrictly; + +import dev.cel.common.values.CelValueConverter; +import dev.cel.runtime.AccumulatedUnknowns; +import dev.cel.runtime.CelEvaluationException; +import dev.cel.runtime.CelResolvedOverload; +import dev.cel.runtime.GlobalResolver; + +final class EvalBinary extends PlannedInterpretable { + + private final CelResolvedOverload resolvedOverload; + private final PlannedInterpretable arg1; + private final PlannedInterpretable arg2; + private final CelValueConverter celValueConverter; + + @Override + public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { + Object argVal1 = + resolvedOverload.isStrict() + ? evalStrictly(arg1, resolver, frame) + : evalNonstrictly(arg1, resolver, frame); + Object argVal2 = + resolvedOverload.isStrict() + ? evalStrictly(arg2, resolver, frame) + : evalNonstrictly(arg2, resolver, frame); + + AccumulatedUnknowns unknowns = AccumulatedUnknowns.maybeMerge(null, argVal1); + unknowns = AccumulatedUnknowns.maybeMerge(unknowns, argVal2); + + if (unknowns != null) { + return unknowns; + } + + return EvalHelpers.dispatch(resolvedOverload, celValueConverter, argVal1, argVal2); + } + + static EvalBinary create( + long exprId, + CelResolvedOverload resolvedOverload, + PlannedInterpretable arg1, + PlannedInterpretable arg2, + CelValueConverter celValueConverter) { + return new EvalBinary(exprId, resolvedOverload, arg1, arg2, celValueConverter); + } + + private EvalBinary( + long exprId, + CelResolvedOverload resolvedOverload, + PlannedInterpretable arg1, + PlannedInterpretable arg2, + CelValueConverter celValueConverter) { + super(exprId); + this.resolvedOverload = resolvedOverload; + this.arg1 = arg1; + this.arg2 = arg2; + this.celValueConverter = celValueConverter; + } +} diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalHelpers.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalHelpers.java index 38b060b92..5c1dd80b3 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalHelpers.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalHelpers.java @@ -61,17 +61,44 @@ static Object dispatch( try { Object result = overload.getDefinition().apply(args); return valueConverter.maybeUnwrap(valueConverter.toRuntimeValue(result)); - } catch (CelRuntimeException e) { - // Function dispatch failure that's already been handled -- just propagate. - throw e; } catch (RuntimeException e) { - // Unexpected function dispatch failure. - throw new IllegalArgumentException( - String.format( - "Function '%s' failed with arg(s) '%s'", - overload.getOverloadId(), Joiner.on(", ").join(args)), - e); + throw handleDispatchException(e, overload, args); + } + } + + static Object dispatch(CelResolvedOverload overload, CelValueConverter valueConverter, Object arg) + throws CelEvaluationException { + try { + Object result = overload.getDefinition().apply(arg); + return valueConverter.maybeUnwrap(valueConverter.toRuntimeValue(result)); + } catch (RuntimeException e) { + throw handleDispatchException(e, overload, arg); + } + } + + static Object dispatch( + CelResolvedOverload overload, CelValueConverter valueConverter, Object arg1, Object arg2) + throws CelEvaluationException { + try { + Object result = overload.getDefinition().apply(arg1, arg2); + return valueConverter.maybeUnwrap(valueConverter.toRuntimeValue(result)); + } catch (RuntimeException e) { + throw handleDispatchException(e, overload, arg1, arg2); + } + } + + private static RuntimeException handleDispatchException( + RuntimeException e, CelResolvedOverload overload, Object... args) { + if (e instanceof CelRuntimeException) { + // Function dispatch failure that's already been handled -- just propagate. + return e; } + // Unexpected function dispatch failure. + return new IllegalArgumentException( + String.format( + "Function '%s' failed with arg(s) '%s'", + overload.getOverloadId(), Joiner.on(", ").join(args)), + e); } private EvalHelpers() {} diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalUnary.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalUnary.java index c715ff032..322648ee3 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalUnary.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalUnary.java @@ -34,9 +34,7 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEval resolvedOverload.isStrict() ? evalStrictly(arg, resolver, frame) : evalNonstrictly(arg, resolver, frame); - Object[] arguments = new Object[] {argVal}; - - return EvalHelpers.dispatch(resolvedOverload, celValueConverter, arguments); + return EvalHelpers.dispatch(resolvedOverload, celValueConverter, argVal); } static EvalUnary create( diff --git a/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java b/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java index 7935e4838..b144c4ec9 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java @@ -276,6 +276,9 @@ private PlannedInterpretable planCall(CelExpr expr, PlannerContext ctx) { return EvalZeroArity.create(expr.id(), resolvedOverload, celValueConverter); case 1: return EvalUnary.create(expr.id(), resolvedOverload, evaluatedArgs[0], celValueConverter); + case 2: + return EvalBinary.create( + expr.id(), resolvedOverload, evaluatedArgs[0], evaluatedArgs[1], celValueConverter); default: return EvalVarArgsCall.create( expr.id(), resolvedOverload, evaluatedArgs, celValueConverter); From 05cb0a114333b3317378a58a98be5480995fd0d3 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Tue, 31 Mar 2026 15:33:35 -0700 Subject: [PATCH 24/66] Remove createStruct planinng overhead for type-checked ASTs PiperOrigin-RevId: 892572698 --- .../cel/runtime/planner/ProgramPlanner.java | 24 +++++++++++++++---- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java b/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java index b144c4ec9..add918f64 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java @@ -343,7 +343,7 @@ private Optional maybeInterceptOptionalCalls( private PlannedInterpretable planCreateStruct(CelExpr celExpr, PlannerContext ctx) { CelStruct struct = celExpr.struct(); - CelType structType = resolveStructType(struct); + CelType structType = resolveStructType(celExpr, ctx); ImmutableList entries = struct.entries(); String[] keys = new String[entries.size()]; @@ -489,7 +489,17 @@ private ResolvedFunction resolveFunction( return ResolvedFunction.newBuilder().setFunctionName(functionName).setTarget(target).build(); } - private CelType resolveStructType(CelStruct struct) { + private CelType resolveStructType(CelExpr expr, PlannerContext ctx) { + CelType checkedType = ctx.typeMap().get(expr.id()); + if (checkedType != null) { + CelKind kind = checkedType.kind(); + // Type-checked ASTs do not need a type-provider lookup as long as it's of expected kind. + if (isValidStructKind(kind)) { + return checkedType; + } + } + + CelStruct struct = expr.struct(); String messageName = struct.messageName(); for (String typeName : container.resolveCandidateNames(messageName)) { CelType structType = typeProvider.findType(typeName).orElse(null); @@ -499,9 +509,7 @@ private CelType resolveStructType(CelStruct struct) { CelKind kind = structType.kind(); - if (!kind.equals(CelKind.STRUCT) - && !kind.equals(CelKind.TIMESTAMP) - && !kind.equals(CelKind.DURATION)) { + if (!isValidStructKind(kind)) { throw new IllegalArgumentException( String.format( "Expected struct type for %s, got %s", structType.name(), structType.kind())); @@ -513,6 +521,12 @@ private CelType resolveStructType(CelStruct struct) { throw new IllegalArgumentException("Undefined type name: " + messageName); } + private static boolean isValidStructKind(CelKind kind) { + return kind.equals(CelKind.STRUCT) + || kind.equals(CelKind.TIMESTAMP) + || kind.equals(CelKind.DURATION); + } + /** Converts a given expression into a qualified name, if possible. */ private Optional toQualifiedName(CelExpr operand) { switch (operand.getKind()) { From 8e9f1acaa882a858167285093f43722c8f021a7b Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Tue, 31 Mar 2026 17:25:32 -0700 Subject: [PATCH 25/66] Optimize attribute qualification process and empty message creation PiperOrigin-RevId: 892624189 --- .../cel/common/internal/DefaultMessageFactory.java | 2 +- .../src/main/java/dev/cel/runtime/CelAttribute.java | 13 ++++++++----- .../java/dev/cel/runtime/CelAttributePattern.java | 13 ++++++++----- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/common/src/main/java/dev/cel/common/internal/DefaultMessageFactory.java b/common/src/main/java/dev/cel/common/internal/DefaultMessageFactory.java index 4a021cd90..68d05e127 100644 --- a/common/src/main/java/dev/cel/common/internal/DefaultMessageFactory.java +++ b/common/src/main/java/dev/cel/common/internal/DefaultMessageFactory.java @@ -52,7 +52,7 @@ public Optional newBuilder(String messageName) { DefaultInstanceMessageFactory.getInstance().getPrototype(descriptor.get()); if (message.isPresent()) { - return message.map(Message::toBuilder); + return message.map(Message::newBuilderForType); } return Optional.of(DynamicMessage.newBuilder(descriptor.get())); diff --git a/runtime/src/main/java/dev/cel/runtime/CelAttribute.java b/runtime/src/main/java/dev/cel/runtime/CelAttribute.java index 6080dbaa1..f04418e0c 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelAttribute.java +++ b/runtime/src/main/java/dev/cel/runtime/CelAttribute.java @@ -17,7 +17,6 @@ import com.google.auto.value.AutoOneOf; import com.google.auto.value.AutoValue; import com.google.common.base.Preconditions; -import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; import com.google.common.primitives.UnsignedLong; import com.google.errorprone.annotations.Immutable; @@ -184,9 +183,13 @@ public static CelAttribute create(String rootIdentifier) { */ public static CelAttribute fromQualifiedIdentifier(String qualifiedIdentifier) { ImmutableList.Builder qualifiers = ImmutableList.builder(); - Splitter.on(".") - .split(qualifiedIdentifier) - .forEach((element) -> qualifiers.add(Qualifier.ofString(element))); + int start = 0; + int next; + while ((next = qualifiedIdentifier.indexOf('.', start)) != -1) { + qualifiers.add(Qualifier.ofString(qualifiedIdentifier.substring(start, next))); + start = next + 1; + } + qualifiers.add(Qualifier.ofString(qualifiedIdentifier.substring(start))); return new AutoValue_CelAttribute(qualifiers.build()); } @@ -206,7 +209,7 @@ public CelAttribute qualify(Qualifier qualifier) { return EMPTY; } return new AutoValue_CelAttribute( - ImmutableList.builder().addAll(qualifiers()).add(qualifier).build()); + ImmutableList.builderWithExpectedSize(qualifiers().size() + 1).addAll(qualifiers()).add(qualifier).build()); } @Override diff --git a/runtime/src/main/java/dev/cel/runtime/CelAttributePattern.java b/runtime/src/main/java/dev/cel/runtime/CelAttributePattern.java index 9075cd7a8..ff5f3f5bf 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelAttributePattern.java +++ b/runtime/src/main/java/dev/cel/runtime/CelAttributePattern.java @@ -18,7 +18,6 @@ import com.google.auto.value.AutoValue; import com.google.common.base.Preconditions; -import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.Immutable; @@ -62,9 +61,13 @@ public static CelAttributePattern create(String rootIdentifier) { */ public static CelAttributePattern fromQualifiedIdentifier(String qualifiedIdentifier) { ImmutableList.Builder qualifiers = ImmutableList.builder(); - Splitter.on(".") - .split(qualifiedIdentifier) - .forEach((String element) -> qualifiers.add(CelAttribute.Qualifier.ofString(element))); + int start = 0; + int next; + while ((next = qualifiedIdentifier.indexOf('.', start)) != -1) { + qualifiers.add(CelAttribute.Qualifier.ofString(qualifiedIdentifier.substring(start, next))); + start = next + 1; + } + qualifiers.add(CelAttribute.Qualifier.ofString(qualifiedIdentifier.substring(start))); return new AutoValue_CelAttributePattern(qualifiers.build()); } @@ -74,7 +77,7 @@ public static CelAttributePattern fromQualifiedIdentifier(String qualifiedIdenti /** Create a new attribute pattern that specifies a subfield of this pattern. */ public CelAttributePattern qualify(CelAttribute.Qualifier qualifier) { return new AutoValue_CelAttributePattern( - ImmutableList.builder() + ImmutableList.builderWithExpectedSize(qualifiers().size() + 1) .addAll(qualifiers()) .add(qualifier) .build()); From ce4d2179fc69458234367d1f61ea3fe1f80a178d Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Wed, 1 Apr 2026 15:03:56 -0700 Subject: [PATCH 26/66] Move fast-path unary/binary apply methods into an internal interface PiperOrigin-RevId: 893131554 --- .../src/main/java/dev/cel/runtime/BUILD.bazel | 10 ++++- .../dev/cel/runtime/CelFunctionBinding.java | 8 ++-- .../dev/cel/runtime/CelFunctionOverload.java | 9 ----- .../cel/runtime/CelLateFunctionBindings.java | 2 +- .../dev/cel/runtime/CelResolvedOverload.java | 38 +++++++++++++++++- .../dev/cel/runtime/DefaultDispatcher.java | 39 +------------------ .../dev/cel/runtime/FunctionBindingImpl.java | 8 ++-- .../runtime/OptimizedFunctionOverload.java | 35 +++++++++++++++++ .../dev/cel/runtime/planner/EvalHelpers.java | 6 +-- 9 files changed, 94 insertions(+), 61 deletions(-) create mode 100644 runtime/src/main/java/dev/cel/runtime/OptimizedFunctionOverload.java diff --git a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel index 2681c17de..6f0607de4 100644 --- a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel +++ b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel @@ -129,7 +129,6 @@ java_library( "//:auto_value", "//common:error_codes", "//common/annotations", - "//common/exceptions:overload_not_found", "@maven//:com_google_code_findbugs_annotations", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", @@ -151,7 +150,6 @@ cel_android_library( "//:auto_value", "//common:error_codes", "//common/annotations", - "//common/exceptions:overload_not_found", "@maven//:com_google_code_findbugs_annotations", "@maven//:com_google_errorprone_error_prone_annotations", "@maven_android//:com_google_guava_guava", @@ -790,6 +788,7 @@ java_library( name = "function_overload", srcs = [ "CelFunctionOverload.java", + "OptimizedFunctionOverload.java", ], tags = [ ], @@ -805,6 +804,7 @@ cel_android_library( name = "function_overload_android", srcs = [ "CelFunctionOverload.java", + "OptimizedFunctionOverload.java", ], deps = [ ":evaluation_exception", @@ -1306,9 +1306,12 @@ java_library( tags = [ ], deps = [ + ":evaluation_exception", + ":function_binding", ":function_overload", "//:auto_value", "//common/annotations", + "//common/exceptions:overload_not_found", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", ], @@ -1320,9 +1323,12 @@ cel_android_library( tags = [ ], deps = [ + ":evaluation_exception", + ":function_binding_android", ":function_overload_android", "//:auto_value", "//common/annotations", + "//common/exceptions:overload_not_found", "@maven//:com_google_errorprone_error_prone_annotations", "@maven_android//:com_google_guava_guava", ], diff --git a/runtime/src/main/java/dev/cel/runtime/CelFunctionBinding.java b/runtime/src/main/java/dev/cel/runtime/CelFunctionBinding.java index 06e5facdf..88be0d3c3 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelFunctionBinding.java +++ b/runtime/src/main/java/dev/cel/runtime/CelFunctionBinding.java @@ -51,13 +51,13 @@ public interface CelFunctionBinding { boolean isStrict(); /** Create a unary function binding from the {@code overloadId}, {@code arg}, and {@code impl}. */ - @SuppressWarnings("unchecked") + @SuppressWarnings("unchecked") // Safe from CelFunctionOverload.canHandle check before invocation static CelFunctionBinding from( String overloadId, Class arg, CelFunctionOverload.Unary impl) { return from( overloadId, ImmutableList.of(arg), - new CelFunctionOverload() { + new OptimizedFunctionOverload() { @Override public Object apply(Object[] args) throws CelEvaluationException { return impl.apply((T) args[0]); @@ -74,13 +74,13 @@ public Object apply(Object arg1) throws CelEvaluationException { * Create a binary function binding from the {@code overloadId}, {@code arg1}, {@code arg2}, and * {@code impl}. */ - @SuppressWarnings("unchecked") + @SuppressWarnings("unchecked") // Safe from CelFunctionOverload.canHandle check before invocation static CelFunctionBinding from( String overloadId, Class arg1, Class arg2, CelFunctionOverload.Binary impl) { return from( overloadId, ImmutableList.of(arg1, arg2), - new CelFunctionOverload() { + new OptimizedFunctionOverload() { @Override public Object apply(Object[] args) throws CelEvaluationException { return impl.apply((T1) args[0], (T2) args[1]); diff --git a/runtime/src/main/java/dev/cel/runtime/CelFunctionOverload.java b/runtime/src/main/java/dev/cel/runtime/CelFunctionOverload.java index e1bdbf886..c5f75096d 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelFunctionOverload.java +++ b/runtime/src/main/java/dev/cel/runtime/CelFunctionOverload.java @@ -26,15 +26,6 @@ public interface CelFunctionOverload { /** Evaluate a set of arguments throwing a {@code CelException} on error. */ Object apply(Object[] args) throws CelEvaluationException; - /** Fast-path for unary function execution to avoid Object[] allocation. */ - default Object apply(Object arg) throws CelEvaluationException { - return apply(new Object[] {arg}); - } - - /** Fast-path for binary function execution to avoid Object[] allocation. */ - default Object apply(Object arg1, Object arg2) throws CelEvaluationException { - return apply(new Object[] {arg1, arg2}); - } /** * Helper interface for describing unary functions where the type-parameter is used to improve diff --git a/runtime/src/main/java/dev/cel/runtime/CelLateFunctionBindings.java b/runtime/src/main/java/dev/cel/runtime/CelLateFunctionBindings.java index c1f4b236f..3d75845cf 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelLateFunctionBindings.java +++ b/runtime/src/main/java/dev/cel/runtime/CelLateFunctionBindings.java @@ -65,7 +65,7 @@ public static CelLateFunctionBindings from(Collection functi private static CelResolvedOverload createResolvedOverload(CelFunctionBinding binding) { return CelResolvedOverload.of( binding.getOverloadId(), - (args) -> binding.getDefinition().apply(args), + binding.getDefinition(), binding.isStrict(), binding.getArgTypes()); } diff --git a/runtime/src/main/java/dev/cel/runtime/CelResolvedOverload.java b/runtime/src/main/java/dev/cel/runtime/CelResolvedOverload.java index 2bcdf3a2d..7063720a1 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelResolvedOverload.java +++ b/runtime/src/main/java/dev/cel/runtime/CelResolvedOverload.java @@ -18,6 +18,7 @@ import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.Immutable; import dev.cel.common.annotations.Internal; +import dev.cel.common.exceptions.CelOverloadNotFoundException; import java.util.List; /** @@ -52,6 +53,33 @@ public abstract class CelResolvedOverload { /** The function definition. */ public abstract CelFunctionOverload getDefinition(); + abstract OptimizedFunctionOverload getOptimizedDefinition(); + + public Object invoke(Object[] args) throws CelEvaluationException { + // Note: canHandle check is handled separately in DynamicDispatchOverload + if (isDynamicDispatch() + || CelFunctionOverload.canHandle(args, getParameterTypes(), isStrict())) { + return getDefinition().apply(args); + } + throw new CelOverloadNotFoundException(getOverloadId()); + } + + public Object invoke(Object arg) throws CelEvaluationException { + if (isDynamicDispatch() + || CelFunctionOverload.canHandle(arg, getParameterTypes(), isStrict())) { + return getOptimizedDefinition().apply(arg); + } + throw new CelOverloadNotFoundException(getOverloadId()); + } + + public Object invoke(Object arg1, Object arg2) throws CelEvaluationException { + if (isDynamicDispatch() + || CelFunctionOverload.canHandle(arg1, arg2, getParameterTypes(), isStrict())) { + return getOptimizedDefinition().apply(arg1, arg2); + } + throw new CelOverloadNotFoundException(getOverloadId()); + } + /** * Creates a new resolved overload from the given overload id, parameter types, and definition. */ @@ -71,8 +99,12 @@ public static CelResolvedOverload of( CelFunctionOverload definition, boolean isStrict, List> parameterTypes) { + OptimizedFunctionOverload optimizedDef = + (definition instanceof OptimizedFunctionOverload) + ? (OptimizedFunctionOverload) definition + : definition::apply; return new AutoValue_CelResolvedOverload( - overloadId, ImmutableList.copyOf(parameterTypes), isStrict, definition); + overloadId, ImmutableList.copyOf(parameterTypes), isStrict, definition, optimizedDef); } /** @@ -81,4 +113,8 @@ public static CelResolvedOverload of( boolean canHandle(Object[] arguments) { return CelFunctionOverload.canHandle(arguments, getParameterTypes(), isStrict()); } + + private boolean isDynamicDispatch() { + return getDefinition() instanceof FunctionBindingImpl.DynamicDispatchOverload; + } } diff --git a/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java b/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java index 87cb07945..d6ddf3965 100644 --- a/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java +++ b/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java @@ -26,7 +26,6 @@ import com.google.errorprone.annotations.Immutable; import dev.cel.common.CelErrorCode; import dev.cel.common.annotations.Internal; -import dev.cel.common.exceptions.CelOverloadNotFoundException; import dev.cel.runtime.FunctionBindingImpl.DynamicDispatchOverload; import java.util.ArrayList; import java.util.Collection; @@ -202,46 +201,10 @@ public DefaultDispatcher build() { OverloadEntry overloadEntry = entry.getValue(); CelFunctionOverload overloadImpl = overloadEntry.overload(); - CelFunctionOverload guardedApply; - if (overloadImpl instanceof DynamicDispatchOverload) { - // Dynamic dispatcher already does its own internal canHandle checks - guardedApply = overloadImpl; - } else { - boolean isStrict = overloadEntry.isStrict(); - ImmutableList> argTypes = overloadEntry.argTypes(); - - guardedApply = - new CelFunctionOverload() { - @Override - public Object apply(Object[] args) throws CelEvaluationException { - if (CelFunctionOverload.canHandle(args, argTypes, isStrict)) { - return overloadImpl.apply(args); - } - throw new CelOverloadNotFoundException(overloadId); - } - - @Override - public Object apply(Object arg) throws CelEvaluationException { - if (CelFunctionOverload.canHandle(arg, argTypes, isStrict)) { - return overloadImpl.apply(arg); - } - throw new CelOverloadNotFoundException(overloadId); - } - - @Override - public Object apply(Object arg1, Object arg2) throws CelEvaluationException { - if (CelFunctionOverload.canHandle(arg1, arg2, argTypes, isStrict)) { - return overloadImpl.apply(arg1, arg2); - } - throw new CelOverloadNotFoundException(overloadId); - } - }; - } - resolvedOverloads.put( overloadId, CelResolvedOverload.of( - overloadId, guardedApply, overloadEntry.isStrict(), overloadEntry.argTypes())); + overloadId, overloadImpl, overloadEntry.isStrict(), overloadEntry.argTypes())); } return new DefaultDispatcher(resolvedOverloads.buildOrThrow()); diff --git a/runtime/src/main/java/dev/cel/runtime/FunctionBindingImpl.java b/runtime/src/main/java/dev/cel/runtime/FunctionBindingImpl.java index 1f47f1dfd..c1306ce19 100644 --- a/runtime/src/main/java/dev/cel/runtime/FunctionBindingImpl.java +++ b/runtime/src/main/java/dev/cel/runtime/FunctionBindingImpl.java @@ -126,7 +126,7 @@ private DynamicDispatchBinding( } @Immutable - static final class DynamicDispatchOverload implements CelFunctionOverload { + static final class DynamicDispatchOverload implements OptimizedFunctionOverload { private final String functionName; private final ImmutableSet overloadBindings; @@ -149,7 +149,8 @@ public Object apply(Object[] args) throws CelEvaluationException { public Object apply(Object arg) throws CelEvaluationException { for (CelFunctionBinding overload : overloadBindings) { if (CelFunctionOverload.canHandle(arg, overload.getArgTypes(), overload.isStrict())) { - return overload.getDefinition().apply(arg); + OptimizedFunctionOverload def = (OptimizedFunctionOverload) overload.getDefinition(); + return def.apply(arg); } } throw new CelOverloadNotFoundException( @@ -164,7 +165,8 @@ public Object apply(Object arg1, Object arg2) throws CelEvaluationException { for (CelFunctionBinding overload : overloadBindings) { if (CelFunctionOverload.canHandle( arg1, arg2, overload.getArgTypes(), overload.isStrict())) { - return overload.getDefinition().apply(arg1, arg2); + OptimizedFunctionOverload def = (OptimizedFunctionOverload) overload.getDefinition(); + return def.apply(arg1, arg2); } } throw new CelOverloadNotFoundException( diff --git a/runtime/src/main/java/dev/cel/runtime/OptimizedFunctionOverload.java b/runtime/src/main/java/dev/cel/runtime/OptimizedFunctionOverload.java new file mode 100644 index 000000000..fde8bcc15 --- /dev/null +++ b/runtime/src/main/java/dev/cel/runtime/OptimizedFunctionOverload.java @@ -0,0 +1,35 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.runtime; + +import com.google.errorprone.annotations.Immutable; + +/** + * Internal interface to support fast-path Unary and Binary evaluations, avoiding Object[] + * allocation. + */ +@Immutable +interface OptimizedFunctionOverload extends CelFunctionOverload { + + /** Fast-path for unary function execution to avoid Object[] allocation. */ + default Object apply(Object arg) throws CelEvaluationException { + return apply(new Object[] {arg}); + } + + /** Fast-path for binary function execution to avoid Object[] allocation. */ + default Object apply(Object arg1, Object arg2) throws CelEvaluationException { + return apply(new Object[] {arg1, arg2}); + } +} diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalHelpers.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalHelpers.java index 5c1dd80b3..a30f91880 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalHelpers.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalHelpers.java @@ -59,7 +59,7 @@ static Object dispatch( CelResolvedOverload overload, CelValueConverter valueConverter, Object[] args) throws CelEvaluationException { try { - Object result = overload.getDefinition().apply(args); + Object result = overload.invoke(args); return valueConverter.maybeUnwrap(valueConverter.toRuntimeValue(result)); } catch (RuntimeException e) { throw handleDispatchException(e, overload, args); @@ -69,7 +69,7 @@ static Object dispatch( static Object dispatch(CelResolvedOverload overload, CelValueConverter valueConverter, Object arg) throws CelEvaluationException { try { - Object result = overload.getDefinition().apply(arg); + Object result = overload.invoke(arg); return valueConverter.maybeUnwrap(valueConverter.toRuntimeValue(result)); } catch (RuntimeException e) { throw handleDispatchException(e, overload, arg); @@ -80,7 +80,7 @@ static Object dispatch( CelResolvedOverload overload, CelValueConverter valueConverter, Object arg1, Object arg2) throws CelEvaluationException { try { - Object result = overload.getDefinition().apply(arg1, arg2); + Object result = overload.invoke(arg1, arg2); return valueConverter.maybeUnwrap(valueConverter.toRuntimeValue(result)); } catch (RuntimeException e) { throw handleDispatchException(e, overload, arg1, arg2); From 4d00593223f7373eeda605353180fed78d2067f9 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Wed, 1 Apr 2026 15:28:25 -0700 Subject: [PATCH 27/66] Fix partial evaluation to properly check for comprehension bound variables for planner PiperOrigin-RevId: 893142702 --- .../dev/cel/runtime/planner/ActivationWrapper.java | 3 +++ .../java/dev/cel/runtime/planner/EvalFold.java | 5 +++++ .../cel/runtime/planner/NamespacedAttribute.java | 13 ++++++++++++- .../dev/cel/runtime/PlannerInterpreterTest.java | 5 +++++ .../planner_unknownResultSet_success.baseline | 14 +++++++++++++- testing/src/main/java/dev/cel/testing/BUILD.bazel | 1 + .../java/dev/cel/testing/BaseInterpreterTest.java | 3 ++- 7 files changed, 41 insertions(+), 3 deletions(-) diff --git a/runtime/src/main/java/dev/cel/runtime/planner/ActivationWrapper.java b/runtime/src/main/java/dev/cel/runtime/planner/ActivationWrapper.java index f844ab232..8883ac12c 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/ActivationWrapper.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/ActivationWrapper.java @@ -19,4 +19,7 @@ /** Identifies a resolver that can be unwrapped to bypass local variable state. */ public interface ActivationWrapper extends GlobalResolver { GlobalResolver unwrap(); + + /** Returns true if the given name is bound by this local activation wrapper. */ + boolean isLocallyBound(String name); } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalFold.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalFold.java index 197db42ad..2631bf0b9 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalFold.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalFold.java @@ -175,6 +175,11 @@ public GlobalResolver unwrap() { return resolver; } + @Override + public boolean isLocallyBound(String name) { + return name.equals(accuVar) || name.equals(iterVar) || name.equals(iterVar2); + } + @Override public @Nullable Object resolve(String name) { if (name.equals(accuVar)) { diff --git a/runtime/src/main/java/dev/cel/runtime/planner/NamespacedAttribute.java b/runtime/src/main/java/dev/cel/runtime/planner/NamespacedAttribute.java index d51336d80..0000ad764 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/NamespacedAttribute.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/NamespacedAttribute.java @@ -75,7 +75,7 @@ public Object resolve(long exprId, GlobalResolver ctx, ExecutionFrame frame) { PartialVars partialVars = frame.partialVars().orElse(null); - if (partialVars != null) { + if (partialVars != null && !isLocallyBound(resolver, name)) { ImmutableList patterns = partialVars.unknowns(); // Avoid enhanced for loop to prevent UnmodifiableIterator from being allocated for (int i = 0; i < qualifiers.size(); i++) { @@ -151,6 +151,17 @@ private static Long getEnumValue(EnumType enumType, String field) { String.format("Field %s was not found on enum %s", enumType.name(), field))); } + private boolean isLocallyBound(GlobalResolver resolver, String name) { + while (resolver instanceof ActivationWrapper) { + ActivationWrapper wrapper = (ActivationWrapper) resolver; + if (wrapper.isLocallyBound(name)) { + return true; + } + resolver = wrapper.unwrap(); + } + return false; + } + private GlobalResolver unwrapToNonLocal(GlobalResolver resolver) { while (resolver instanceof ActivationWrapper) { resolver = ((ActivationWrapper) resolver).unwrap(); diff --git a/runtime/src/test/java/dev/cel/runtime/PlannerInterpreterTest.java b/runtime/src/test/java/dev/cel/runtime/PlannerInterpreterTest.java index 181842ab4..2b0e53298 100644 --- a/runtime/src/test/java/dev/cel/runtime/PlannerInterpreterTest.java +++ b/runtime/src/test/java/dev/cel/runtime/PlannerInterpreterTest.java @@ -283,6 +283,11 @@ public void planner_unknownResultSet_success() { declareVariable("unknown_list", ListType.create(SimpleType.INT)); source = "unknown_list.map(x, x)"; runTest(variables, CelAttributePattern.fromQualifiedIdentifier("unknown_list")); + + clearAllDeclarations(); + declareVariable("x", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())); + source = "cel.bind(x, [1, 2, 3], 1 in x)"; + runTest(variables, CelAttributePattern.fromQualifiedIdentifier("x")); } @Test diff --git a/runtime/src/test/resources/planner_unknownResultSet_success.baseline b/runtime/src/test/resources/planner_unknownResultSet_success.baseline index 2f2c218d0..c5e8867db 100644 --- a/runtime/src/test/resources/planner_unknownResultSet_success.baseline +++ b/runtime/src/test/resources/planner_unknownResultSet_success.baseline @@ -458,4 +458,16 @@ single_timestamp { seconds: 15 } , unknown_attributes=[unknown_list]} -result: CelUnknownSet{attributes=[unknown_list], unknownExprIds=[1]} \ No newline at end of file +result: CelUnknownSet{attributes=[unknown_list], unknownExprIds=[1]} + +Source: cel.bind(x, [1, 2, 3], 1 in x) +declare x { + value cel.expr.conformance.proto3.TestAllTypes +} +=====> +bindings: {x=single_string: "test" +single_timestamp { + seconds: 15 +} +, unknown_attributes=[x]} +result: true \ No newline at end of file diff --git a/testing/src/main/java/dev/cel/testing/BUILD.bazel b/testing/src/main/java/dev/cel/testing/BUILD.bazel index 2ecabdf05..5ee142200 100644 --- a/testing/src/main/java/dev/cel/testing/BUILD.bazel +++ b/testing/src/main/java/dev/cel/testing/BUILD.bazel @@ -87,6 +87,7 @@ java_library( "//common/types:message_type_provider", "//common/types:type_providers", "//common/values:cel_byte_string", + "//extensions", "//extensions:optional_library", "//runtime", "//runtime:function_binding", diff --git a/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java b/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java index 42cb5e41b..bda56a19e 100644 --- a/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java +++ b/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java @@ -74,6 +74,7 @@ import dev.cel.expr.conformance.proto3.TestAllTypes; import dev.cel.expr.conformance.proto3.TestAllTypes.NestedEnum; import dev.cel.expr.conformance.proto3.TestAllTypes.NestedMessage; +import dev.cel.extensions.CelExtensions; import dev.cel.extensions.CelOptionalLibrary; import dev.cel.runtime.CelAttributePattern; import dev.cel.runtime.CelEvaluationException; @@ -153,7 +154,7 @@ protected void prepareCompiler(CelTypeProvider typeProvider) { this.celCompiler = celCompiler .toCompilerBuilder() - .addLibraries(CelOptionalLibrary.INSTANCE) + .addLibraries(CelOptionalLibrary.INSTANCE, CelExtensions.bindings()) .setOptions(celOptions) .build(); } From 46bae721769da226f2d1fc560d4e8f8dce47e2f0 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Thu, 2 Apr 2026 13:30:20 -0700 Subject: [PATCH 28/66] Fix constant folding to not error when sub-asts contain unbound variables CEL-Java fix for xref: https://github.com/google/cel-go/issues/1296 PiperOrigin-RevId: 893670998 --- .../dev/cel/optimizer/optimizers/BUILD.bazel | 2 + .../optimizers/ConstantFoldingOptimizer.java | 22 ++++- .../dev/cel/optimizer/optimizers/BUILD.bazel | 1 + .../ConstantFoldingOptimizerTest.java | 97 ++++++++++++++----- .../java/dev/cel/runtime/PartialVars.java | 7 +- 5 files changed, 97 insertions(+), 32 deletions(-) diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel b/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel index 7984cf3ba..c887f3d15 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel @@ -35,6 +35,8 @@ java_library( "//optimizer:mutable_ast", "//optimizer:optimization_exception", "//runtime", + "//runtime:partial_vars", + "//runtime:unknown_attributes", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", ], diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java index ada73ce56..c017911f9 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java @@ -30,7 +30,6 @@ import dev.cel.common.CelValidationException; import dev.cel.common.Operator; import dev.cel.common.ast.CelConstant; -import dev.cel.common.ast.CelExpr; import dev.cel.common.ast.CelExpr.ExprKind.Kind; import dev.cel.common.ast.CelMutableExpr; import dev.cel.common.ast.CelMutableExpr.CelMutableCall; @@ -47,7 +46,10 @@ import dev.cel.optimizer.AstMutator; import dev.cel.optimizer.CelAstOptimizer; import dev.cel.optimizer.CelOptimizationException; +import dev.cel.runtime.CelAttribute.Qualifier; +import dev.cel.runtime.CelAttributePattern; import dev.cel.runtime.CelEvaluationException; +import dev.cel.runtime.PartialVars; import java.time.Duration; import java.time.Instant; import java.util.ArrayList; @@ -282,7 +284,7 @@ private Optional maybeFold( throws CelOptimizationException { Object result; try { - result = evaluateExpr(cel, CelMutableExprConverter.fromMutableExpr(node.expr())); + result = evaluateExpr(cel, node); } catch (CelValidationException | CelEvaluationException e) { throw new CelOptimizationException( "Constant folding failure. Failed to evaluate subtree due to: " + e.getMessage(), e); @@ -674,13 +676,23 @@ private CelMutableAst pruneOptionalStructElements(CelMutableAst ast, CelMutableE } @CanIgnoreReturnValue - private static Object evaluateExpr(Cel cel, CelExpr expr) + private static Object evaluateExpr(Cel cel, CelNavigableMutableExpr navigableMutableExpr) throws CelValidationException, CelEvaluationException { + ImmutableList attributePatterns = + navigableMutableExpr + .allNodes() + .filter(node -> node.getKind().equals(Kind.IDENT)) + .map(node -> node.expr().ident().name()) + .filter(Qualifier::isLegalIdentifier) + .map(CelAttributePattern::create) + .collect(toImmutableList()); CelAbstractSyntaxTree ast = - CelAbstractSyntaxTree.newParsedAst(expr, CelSource.newBuilder().build()); + CelAbstractSyntaxTree.newParsedAst( + CelMutableExprConverter.fromMutableExpr(navigableMutableExpr.expr()), + CelSource.newBuilder().build()); ast = cel.check(ast).getAst(); - return cel.createProgram(ast).eval(); + return cel.createProgram(ast).eval(PartialVars.of(attributePatterns)); } /** Options to configure how Constant Folding behave. */ diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel b/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel index d91e48f54..b0c48682a 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel @@ -11,6 +11,7 @@ java_library( deps = [ # "//java/com/google/testing/testsize:annotations", "//bundle:cel", + "//bundle:cel_experimental_factory", "//common:cel_ast", "//common:cel_source", "//common:compiler_common", diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java index a8cadf83a..bbb5c6e7e 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java @@ -18,9 +18,12 @@ import static org.junit.Assert.assertThrows; import com.google.common.collect.ImmutableList; +import com.google.testing.junit.testparameterinjector.TestParameter; import com.google.testing.junit.testparameterinjector.TestParameterInjector; import com.google.testing.junit.testparameterinjector.TestParameters; import dev.cel.bundle.Cel; +import dev.cel.bundle.CelBuilder; +import dev.cel.bundle.CelExperimentalFactory; import dev.cel.bundle.CelFactory; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelContainer; @@ -47,9 +50,23 @@ @RunWith(TestParameterInjector.class) public class ConstantFoldingOptimizerTest { private static final CelOptions CEL_OPTIONS = - CelOptions.current().populateMacroCalls(true).build(); - private static final Cel CEL = - CelFactory.standardCelBuilder() + CelOptions.current() + .populateMacroCalls(true) + .enableHeterogeneousNumericComparisons(true) + .build(); + + private static final CelUnparser CEL_UNPARSER = CelUnparserFactory.newUnparser(); + + @SuppressWarnings("ImmutableEnumChecker") // test only + private enum RuntimeEnv { + LEGACY(setupEnv(CelFactory.standardCelBuilder())), + PLANNER(setupEnv(CelExperimentalFactory.plannerCelBuilder())); + + private final Cel cel; + private final CelOptimizer celOptimizer; + + private static Cel setupEnv(CelBuilder celBuilder) { + return celBuilder .addVar("x", SimpleType.DYN) .addVar("y", SimpleType.DYN) .addVar("list_var", ListType.create(SimpleType.STRING)) @@ -84,13 +101,28 @@ public class ConstantFoldingOptimizerTest { CelExtensions.sets(CEL_OPTIONS), CelExtensions.encoders(CEL_OPTIONS)) .build(); + } + + RuntimeEnv(Cel cel) { + this.cel = cel; + this.celOptimizer = + CelOptimizerFactory.standardCelOptimizerBuilder(cel) + .addAstOptimizers(ConstantFoldingOptimizer.getInstance()) + .build(); + } + + private CelBuilder newCelBuilder() { + switch (this) { + case LEGACY: + return CelFactory.standardCelBuilder(); + case PLANNER: + return CelExperimentalFactory.plannerCelBuilder(); + } + throw new AssertionError("Unknown RuntimeEnv: " + this); + } + } - private static final CelOptimizer CEL_OPTIMIZER = - CelOptimizerFactory.standardCelOptimizerBuilder(CEL) - .addAstOptimizers(ConstantFoldingOptimizer.getInstance()) - .build(); - - private static final CelUnparser CEL_UNPARSER = CelUnparserFactory.newUnparser(); + @TestParameter RuntimeEnv runtimeEnv; @Test @TestParameters("{source: 'null', expected: 'null'}") @@ -238,9 +270,9 @@ public class ConstantFoldingOptimizerTest { // TODO: Support folding lists with mixed types. This requires mutable lists. // @TestParameters("{source: 'dyn([1]) + [1.0]'}") public void constantFold_success(String source, String expected) throws Exception { - CelAbstractSyntaxTree ast = CEL.compile(source).getAst(); + CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(source).getAst(); - CelAbstractSyntaxTree optimizedAst = CEL_OPTIMIZER.optimize(ast); + CelAbstractSyntaxTree optimizedAst = runtimeEnv.celOptimizer.optimize(ast); assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(expected); } @@ -285,12 +317,13 @@ public void constantFold_success(String source, String expected) throws Exceptio public void constantFold_macros_macroCallMetadataPopulated(String source, String expected) throws Exception { Cel cel = - CelFactory.standardCelBuilder() + runtimeEnv + .newCelBuilder() .addVar("x", SimpleType.DYN) .addVar("y", SimpleType.DYN) .addMessageTypes(TestAllTypes.getDescriptor()) .setStandardMacros(CelStandardMacro.STANDARD_MACROS) - .setOptions(CelOptions.current().populateMacroCalls(true).build()) + .setOptions(CEL_OPTIONS) .addCompilerLibraries( CelExtensions.bindings(), CelExtensions.optional(), CelExtensions.comprehensions()) .addRuntimeLibraries(CelExtensions.optional(), CelExtensions.comprehensions()) @@ -330,12 +363,17 @@ public void constantFold_macros_macroCallMetadataPopulated(String source, String @TestParameters("{source: 'false ? false : cel.bind(a, true, a)'}") public void constantFold_macros_withoutMacroCallMetadata(String source) throws Exception { Cel cel = - CelFactory.standardCelBuilder() + runtimeEnv + .newCelBuilder() .addVar("x", SimpleType.DYN) .addVar("y", SimpleType.DYN) .addMessageTypes(TestAllTypes.getDescriptor()) .setStandardMacros(CelStandardMacro.STANDARD_MACROS) - .setOptions(CelOptions.current().populateMacroCalls(false).build()) + .setOptions( + CelOptions.current() + .enableHeterogeneousNumericComparisons(true) + .populateMacroCalls(false) + .build()) .addCompilerLibraries(CelExtensions.bindings(), CelOptionalLibrary.INSTANCE) .addRuntimeLibraries(CelOptionalLibrary.INSTANCE) .build(); @@ -378,21 +416,22 @@ public void constantFold_macros_withoutMacroCallMetadata(String source) throws E @TestParameters("{source: 'duration(\"1h\")'}") @TestParameters("{source: '[true].exists(x, x == get_true())'}") @TestParameters("{source: 'get_list([1, 2]).map(x, x * 2)'}") + @TestParameters("{source: '[(x - 1 > 3) ? (x - 1) : 5].exists(x, x - 1 > 3)'}") public void constantFold_noOp(String source) throws Exception { - CelAbstractSyntaxTree ast = CEL.compile(source).getAst(); + CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(source).getAst(); - CelAbstractSyntaxTree optimizedAst = CEL_OPTIMIZER.optimize(ast); + CelAbstractSyntaxTree optimizedAst = runtimeEnv.celOptimizer.optimize(ast); assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(source); } @Test public void constantFold_addFoldableFunction_success() throws Exception { - CelAbstractSyntaxTree ast = CEL.compile("get_true() == get_true()").getAst(); + CelAbstractSyntaxTree ast = runtimeEnv.cel.compile("get_true() == get_true()").getAst(); ConstantFoldingOptions options = ConstantFoldingOptions.newBuilder().addFoldableFunctions("get_true").build(); CelOptimizer optimizer = - CelOptimizerFactory.standardCelOptimizerBuilder(CEL) + CelOptimizerFactory.standardCelOptimizerBuilder(runtimeEnv.cel) .addAstOptimizers(ConstantFoldingOptimizer.newInstance(options)) .build(); @@ -403,7 +442,7 @@ public void constantFold_addFoldableFunction_success() throws Exception { @Test public void constantFold_withExpectedResultTypeSet_success() throws Exception { - Cel cel = CelFactory.standardCelBuilder().setResultType(SimpleType.STRING).build(); + Cel cel = runtimeEnv.newCelBuilder().setResultType(SimpleType.STRING).build(); CelOptimizer optimizer = CelOptimizerFactory.standardCelOptimizerBuilder(cel) .addAstOptimizers(ConstantFoldingOptimizer.getInstance()) @@ -419,10 +458,11 @@ public void constantFold_withExpectedResultTypeSet_success() throws Exception { public void constantFold_withMacroCallPopulated_comprehensionsAreReplacedWithNotSet() throws Exception { Cel cel = - CelFactory.standardCelBuilder() + runtimeEnv + .newCelBuilder() .addVar("x", SimpleType.DYN) .setStandardMacros(CelStandardMacro.STANDARD_MACROS) - .setOptions(CelOptions.current().populateMacroCalls(true).build()) + .setOptions(CEL_OPTIONS) .build(); CelOptimizer celOptimizer = CelOptimizerFactory.standardCelOptimizerBuilder(cel) @@ -492,9 +532,9 @@ public void constantFold_withMacroCallPopulated_comprehensionsAreReplacedWithNot @Test public void constantFold_astProducesConsistentlyNumberedIds() throws Exception { - CelAbstractSyntaxTree ast = CEL.compile("[1] + [2] + [3]").getAst(); + CelAbstractSyntaxTree ast = runtimeEnv.cel.compile("[1] + [2] + [3]").getAst(); - CelAbstractSyntaxTree optimizedAst = CEL_OPTIMIZER.optimize(ast); + CelAbstractSyntaxTree optimizedAst = runtimeEnv.celOptimizer.optimize(ast); assertThat(optimizedAst.getExpr().toString()) .isEqualTo( @@ -515,8 +555,13 @@ public void iterationLimitReached_throws() throws Exception { sb.append(" + ").append(i); } // 0 + 1 + 2 + 3 + ... 200 Cel cel = - CelFactory.standardCelBuilder() - .setOptions(CelOptions.current().maxParseRecursionDepth(200).build()) + runtimeEnv + .newCelBuilder() + .setOptions( + CelOptions.current() + .enableHeterogeneousNumericComparisons(true) + .maxParseRecursionDepth(200) + .build()) .build(); CelAbstractSyntaxTree ast = cel.compile(sb.toString()).getAst(); CelOptimizer optimizer = diff --git a/runtime/src/main/java/dev/cel/runtime/PartialVars.java b/runtime/src/main/java/dev/cel/runtime/PartialVars.java index 1cd081040..f195880d0 100644 --- a/runtime/src/main/java/dev/cel/runtime/PartialVars.java +++ b/runtime/src/main/java/dev/cel/runtime/PartialVars.java @@ -37,7 +37,12 @@ public abstract class PartialVars { /** Constructs a new {@code PartialVars} from one or more {@link CelAttributePattern}s. */ public static PartialVars of(CelAttributePattern... unknownAttributes) { - return of((unused) -> Optional.empty(), ImmutableList.copyOf(unknownAttributes)); + return of(ImmutableList.copyOf(unknownAttributes)); + } + + /** Constructs a new {@code PartialVars} from a list of {@link CelAttributePattern}s. */ + public static PartialVars of(Iterable unknownAttributes) { + return of((unused) -> Optional.empty(), unknownAttributes); } /** From 7d73658c7529d7308294e30aee51e23a36d5028a Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Fri, 3 Apr 2026 10:17:16 -0700 Subject: [PATCH 29/66] Reject invalid unicode literals in the parser PiperOrigin-RevId: 894137619 --- .../dev/cel/common/internal/Constants.java | 6 ++++++ .../cel/parser/CelParserParameterizedTest.java | 4 ++++ .../src/test/resources/parser_errors.baseline | 18 ++++++++++++++++++ 3 files changed, 28 insertions(+) diff --git a/common/src/main/java/dev/cel/common/internal/Constants.java b/common/src/main/java/dev/cel/common/internal/Constants.java index d2c0719ec..49bca7489 100644 --- a/common/src/main/java/dev/cel/common/internal/Constants.java +++ b/common/src/main/java/dev/cel/common/internal/Constants.java @@ -207,6 +207,9 @@ private static void decodeString( continue; } skipNewline = false; + if (codePoint >= MIN_SURROGATE && codePoint <= MAX_SURROGATE) { + throw new ParseException("Invalid unicode code point", seqOffset); + } buffer.appendCodePoint(codePoint); } else { // Normalize '\r' and '\r\n' to '\n'. @@ -231,6 +234,9 @@ private static void decodeString( // For raw literals, all escapes are valid and those characters come through literally in // the string. buffer.appendCodePoint('\\'); + if (codePoint >= MIN_SURROGATE && codePoint <= MAX_SURROGATE) { + throw new ParseException("Invalid unicode code point", seqOffset); + } buffer.appendCodePoint(codePoint); continue; } diff --git a/parser/src/test/java/dev/cel/parser/CelParserParameterizedTest.java b/parser/src/test/java/dev/cel/parser/CelParserParameterizedTest.java index 58b45ddab..b7474041d 100644 --- a/parser/src/test/java/dev/cel/parser/CelParserParameterizedTest.java +++ b/parser/src/test/java/dev/cel/parser/CelParserParameterizedTest.java @@ -248,6 +248,10 @@ public void parser_errors() { runTest(PARSER, "1 + +"); runTest(PARSER, "\"\\xFh\""); runTest(PARSER, "\"\\a\\b\\f\\n\\r\\t\\v\\'\\\"\\\\\\? Illegal escape \\>\""); + runTest(PARSER, "'\uD800'"); + runTest(PARSER, "'\uDFFF'"); + runTest(PARSER, "r\"\\\uD800\""); + runTest(PARSER, "as"); runTest(PARSER, "break"); runTest(PARSER, "const"); diff --git a/parser/src/test/resources/parser_errors.baseline b/parser/src/test/resources/parser_errors.baseline index 9f4b96825..998bbd487 100644 --- a/parser/src/test/resources/parser_errors.baseline +++ b/parser/src/test/resources/parser_errors.baseline @@ -85,6 +85,24 @@ ERROR: :1:43: mismatched input '' expecting {'[', '{', '(', '.', '-' | "\a\b\f\n\r\t\v\'\"\\\? Illegal escape \>" | ..........................................^ +I: '?' +=====> +E: ERROR: :1:1: Invalid unicode code point + | '?' + | ^ + +I: '?' +=====> +E: ERROR: :1:1: Invalid unicode code point + | '?' + | ^ + +I: r"\?" +=====> +E: ERROR: :1:1: Invalid unicode code point + | r"\?" + | ^ + I: as =====> E: ERROR: :1:1: reserved identifier: as From 664c31b9b58b7086df87ec24f066878431c3fc2d Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Fri, 3 Apr 2026 12:53:48 -0700 Subject: [PATCH 30/66] Implement cel.@block for planner PiperOrigin-RevId: 894210773 --- .../main/java/dev/cel/extensions/BUILD.bazel | 1 + .../cel/extensions/CelBindingsExtensions.java | 14 +- .../extensions/CelBindingsExtensionsTest.java | 3 +- .../dev/cel/optimizer/optimizers/BUILD.bazel | 3 + .../SubexpressionOptimizerBaselineTest.java | 159 +++++++++++------- .../SubexpressionOptimizerTest.java | 140 ++++++++++++--- ...old_before_subexpression_unparsed.baseline | 2 +- .../resources/subexpression_unparsed.baseline | 2 +- .../java/dev/cel/runtime/planner/BUILD.bazel | 62 +++---- .../cel/runtime/planner/BlockMemoizer.java | 72 ++++++++ .../dev/cel/runtime/planner/EvalBlock.java | 67 ++++++++ .../cel/runtime/planner/ExecutionFrame.java | 12 ++ .../cel/runtime/planner/ProgramPlanner.java | 59 ++++++- 13 files changed, 466 insertions(+), 130 deletions(-) create mode 100644 runtime/src/main/java/dev/cel/runtime/planner/BlockMemoizer.java create mode 100644 runtime/src/main/java/dev/cel/runtime/planner/EvalBlock.java diff --git a/extensions/src/main/java/dev/cel/extensions/BUILD.bazel b/extensions/src/main/java/dev/cel/extensions/BUILD.bazel index ed2d19d6f..77663f2fa 100644 --- a/extensions/src/main/java/dev/cel/extensions/BUILD.bazel +++ b/extensions/src/main/java/dev/cel/extensions/BUILD.bazel @@ -142,6 +142,7 @@ java_library( deps = [ "//common:compiler_common", "//common/ast", + "//common/types", "//compiler:compiler_builder", "//extensions:extension_library", "//parser:macro", diff --git a/extensions/src/main/java/dev/cel/extensions/CelBindingsExtensions.java b/extensions/src/main/java/dev/cel/extensions/CelBindingsExtensions.java index 5eb2c2e8c..0e6537334 100644 --- a/extensions/src/main/java/dev/cel/extensions/CelBindingsExtensions.java +++ b/extensions/src/main/java/dev/cel/extensions/CelBindingsExtensions.java @@ -22,7 +22,11 @@ import com.google.errorprone.annotations.Immutable; import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelIssue; +import dev.cel.common.CelOverloadDecl; import dev.cel.common.ast.CelExpr; +import dev.cel.common.types.ListType; +import dev.cel.common.types.SimpleType; +import dev.cel.common.types.TypeParamType; import dev.cel.compiler.CelCompilerLibrary; import dev.cel.parser.CelMacro; import dev.cel.parser.CelMacroExprFactory; @@ -62,7 +66,15 @@ public int version() { @Override public ImmutableSet functions() { - return ImmutableSet.of(); + // TODO: Add bindings for block once decorator support is available. + return ImmutableSet.of( + CelFunctionDecl.newFunctionDeclaration( + "cel.@block", + CelOverloadDecl.newGlobalOverload( + "cel_block_list", + TypeParamType.create("T"), + ListType.create(SimpleType.DYN), + TypeParamType.create("T")))); } @Override diff --git a/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java index bc98c9816..ff9e31432 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java @@ -63,7 +63,8 @@ public void library() { CelExtensions.getExtensionLibrary("bindings", CelOptions.DEFAULT); assertThat(library.name()).isEqualTo("bindings"); assertThat(library.latest().version()).isEqualTo(0); - assertThat(library.version(0).functions()).isEmpty(); + assertThat(library.version(0).functions().stream().map(CelFunctionDecl::name)) + .containsExactly("cel.@block"); assertThat(library.version(0).macros().stream().map(CelMacro::getFunction)) .containsExactly("bind"); } diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel b/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel index b0c48682a..734aa6879 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel @@ -33,6 +33,9 @@ java_library( "//parser:unparser", "//runtime", "//runtime:function_binding", + "//runtime:partial_vars", + "//runtime:program", + "//runtime:unknown_attributes", "//testing:baseline_test_case", "@maven//:junit_junit", "@maven//:com_google_testparameterinjector_test_parameter_injector", diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerBaselineTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerBaselineTest.java index 802ef3037..74e3b5b32 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerBaselineTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerBaselineTest.java @@ -24,6 +24,7 @@ // import com.google.testing.testsize.MediumTest; import dev.cel.bundle.Cel; import dev.cel.bundle.CelBuilder; +import dev.cel.bundle.CelExperimentalFactory; import dev.cel.bundle.CelFactory; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelContainer; @@ -43,6 +44,7 @@ import dev.cel.parser.CelUnparserFactory; import dev.cel.runtime.CelFunctionBinding; import dev.cel.testing.BaselineTestCase; +import java.util.EnumSet; import java.util.Optional; import org.junit.Before; import org.junit.Test; @@ -51,6 +53,50 @@ // @MediumTest @RunWith(TestParameterInjector.class) public class SubexpressionOptimizerBaselineTest extends BaselineTestCase { + private enum RuntimeEnv { + LEGACY(setupCelEnv(CelFactory.standardCelBuilder())), + PLANNER(setupCelEnv(CelExperimentalFactory.plannerCelBuilder())); + + private final Cel cel; + + private static Cel setupCelEnv(CelBuilder celBuilder) { + return celBuilder + .addMessageTypes(TestAllTypes.getDescriptor()) + .setContainer(CelContainer.ofName("cel.expr.conformance.proto3")) + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .setOptions( + CelOptions.current() + .populateMacroCalls(true) + .enableHeterogeneousNumericComparisons(true) + .build()) + .addCompilerLibraries( + CelExtensions.optional(), CelExtensions.bindings(), CelExtensions.comprehensions()) + .addRuntimeLibraries(CelExtensions.optional(), CelExtensions.comprehensions()) + .addFunctionDeclarations( + CelFunctionDecl.newFunctionDeclaration( + "pure_custom_func", + newGlobalOverload("pure_custom_func_overload", SimpleType.INT, SimpleType.INT)), + CelFunctionDecl.newFunctionDeclaration( + "non_pure_custom_func", + newGlobalOverload( + "non_pure_custom_func_overload", SimpleType.INT, SimpleType.INT))) + .addFunctionBindings( + // This is pure, but for the purposes of excluding it as a CSE candidate, pretend that + // it isn't. + CelFunctionBinding.from("non_pure_custom_func_overload", Long.class, val -> val), + CelFunctionBinding.from("pure_custom_func_overload", Long.class, val -> val)) + .addVar("x", SimpleType.DYN) + .addVar("y", SimpleType.DYN) + .addVar("opt_x", OptionalType.create(SimpleType.DYN)) + .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())) + .build(); + } + + RuntimeEnv(Cel cel) { + this.cel = cel; + } + } + private static final CelUnparser CEL_UNPARSER = CelUnparserFactory.newUnparser(); private static final TestAllTypes TEST_ALL_TYPES_INPUT = TestAllTypes.newBuilder() @@ -67,7 +113,6 @@ public class SubexpressionOptimizerBaselineTest extends BaselineTestCase { .putMapInt32Int64(2, 2) .putMapStringString("key", "A"))) .build(); - private static final Cel CEL = newCelBuilder().build(); private static final SubexpressionOptimizerOptions OPTIMIZER_COMMON_OPTIONS = SubexpressionOptimizerOptions.newBuilder() @@ -90,45 +135,49 @@ protected String baselineFileName() { return overriddenBaseFilePath; } + @TestParameter RuntimeEnv runtimeEnv; + @Test public void allOptimizers_producesSameEvaluationResult( @TestParameter CseTestOptimizer cseTestOptimizer, @TestParameter CseTestCase cseTestCase) throws Exception { skipBaselineVerification(); - CelAbstractSyntaxTree ast = CEL.compile(cseTestCase.source).getAst(); + CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(cseTestCase.source).getAst(); ImmutableMap inputMap = ImmutableMap.of("msg", TEST_ALL_TYPES_INPUT, "x", 5L, "y", 6L, "opt_x", Optional.of(5L)); - Object expectedEvalResult = CEL.createProgram(ast).eval(inputMap); + Object expectedEvalResult = runtimeEnv.cel.createProgram(ast).eval(inputMap); - CelAbstractSyntaxTree optimizedAst = cseTestOptimizer.cseOptimizer.optimize(ast); + CelAbstractSyntaxTree optimizedAst = cseTestOptimizer.newCseOptimizer(runtimeEnv).optimize(ast); - Object optimizedEvalResult = CEL.createProgram(optimizedAst).eval(inputMap); + Object optimizedEvalResult = runtimeEnv.cel.createProgram(optimizedAst).eval(inputMap); assertThat(optimizedEvalResult).isEqualTo(expectedEvalResult); } @Test public void subexpression_unparsed() throws Exception { - for (CseTestCase cseTestCase : CseTestCase.values()) { + for (CseTestCase cseTestCase : EnumSet.allOf(CseTestCase.class)) { testOutput().println("Test case: " + cseTestCase.name()); testOutput().println("Source: " + cseTestCase.source); testOutput().println("=====>"); - CelAbstractSyntaxTree ast = CEL.compile(cseTestCase.source).getAst(); + CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(cseTestCase.source).getAst(); boolean resultPrinted = false; for (CseTestOptimizer cseTestOptimizer : CseTestOptimizer.values()) { String optimizerName = cseTestOptimizer.name(); CelAbstractSyntaxTree optimizedAst; try { - optimizedAst = cseTestOptimizer.cseOptimizer.optimize(ast); + optimizedAst = cseTestOptimizer.newCseOptimizer(runtimeEnv).optimize(ast); } catch (Exception e) { testOutput().printf("[%s]: Optimization Error: %s", optimizerName, e); continue; } if (!resultPrinted) { Object optimizedEvalResult = - CEL.createProgram(optimizedAst) + runtimeEnv + .cel + .createProgram(optimizedAst) .eval( ImmutableMap.of( - "msg", TEST_ALL_TYPES_INPUT, "x", 5L, "opt_x", Optional.of(5L))); + "msg", TEST_ALL_TYPES_INPUT, "x", 5L, "y", 6L, "opt_x", Optional.of(5L))); testOutput().println("Result: " + optimizedEvalResult); resultPrinted = true; } @@ -145,22 +194,24 @@ public void subexpression_unparsed() throws Exception { @Test public void constfold_before_subexpression_unparsed() throws Exception { - for (CseTestCase cseTestCase : CseTestCase.values()) { + for (CseTestCase cseTestCase : EnumSet.allOf(CseTestCase.class)) { testOutput().println("Test case: " + cseTestCase.name()); testOutput().println("Source: " + cseTestCase.source); testOutput().println("=====>"); - CelAbstractSyntaxTree ast = CEL.compile(cseTestCase.source).getAst(); + CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(cseTestCase.source).getAst(); boolean resultPrinted = false; - for (CseTestOptimizer cseTestOptimizer : CseTestOptimizer.values()) { + for (CseTestOptimizer cseTestOptimizer : EnumSet.allOf(CseTestOptimizer.class)) { String optimizerName = cseTestOptimizer.name(); CelAbstractSyntaxTree optimizedAst = - cseTestOptimizer.cseWithConstFoldingOptimizer.optimize(ast); + cseTestOptimizer.newCseWithConstFoldingOptimizer(runtimeEnv).optimize(ast); if (!resultPrinted) { Object optimizedEvalResult = - CEL.createProgram(optimizedAst) + runtimeEnv + .cel + .createProgram(optimizedAst) .eval( ImmutableMap.of( - "msg", TEST_ALL_TYPES_INPUT, "x", 5L, "opt_x", Optional.of(5L))); + "msg", TEST_ALL_TYPES_INPUT, "x", 5L, "y", 6L, "opt_x", Optional.of(5L))); testOutput().println("Result: " + optimizedEvalResult); resultPrinted = true; } @@ -179,12 +230,13 @@ public void constfold_before_subexpression_unparsed() throws Exception { public void subexpression_ast(@TestParameter CseTestOptimizer cseTestOptimizer) throws Exception { String testBasefileName = "subexpression_ast_" + Ascii.toLowerCase(cseTestOptimizer.name()); overriddenBaseFilePath = String.format("%s%s.baseline", testdataDir(), testBasefileName); - for (CseTestCase cseTestCase : CseTestCase.values()) { + for (CseTestCase cseTestCase : EnumSet.allOf(CseTestCase.class)) { testOutput().println("Test case: " + cseTestCase.name()); testOutput().println("Source: " + cseTestCase.source); testOutput().println("=====>"); - CelAbstractSyntaxTree ast = CEL.compile(cseTestCase.source).getAst(); - CelAbstractSyntaxTree optimizedAst = cseTestOptimizer.cseOptimizer.optimize(ast); + CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(cseTestCase.source).getAst(); + CelAbstractSyntaxTree optimizedAst = + newCseOptimizer(runtimeEnv.cel, cseTestOptimizer.option).optimize(ast); testOutput().println(optimizedAst.getExpr()); } } @@ -193,7 +245,8 @@ public void subexpression_ast(@TestParameter CseTestOptimizer cseTestOptimizer) public void large_expressions_block_common_subexpr() throws Exception { CelOptimizer celOptimizer = newCseOptimizer( - CEL, SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build()); + runtimeEnv.cel, + SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build()); runLargeTestCases(celOptimizer); } @@ -202,7 +255,7 @@ public void large_expressions_block_common_subexpr() throws Exception { public void large_expressions_block_recursion_depth_1() throws Exception { CelOptimizer celOptimizer = newCseOptimizer( - CEL, + runtimeEnv.cel, SubexpressionOptimizerOptions.newBuilder() .populateMacroCalls(true) .subexpressionMaxRecursionDepth(1) @@ -215,7 +268,7 @@ public void large_expressions_block_recursion_depth_1() throws Exception { public void large_expressions_block_recursion_depth_2() throws Exception { CelOptimizer celOptimizer = newCseOptimizer( - CEL, + runtimeEnv.cel, SubexpressionOptimizerOptions.newBuilder() .populateMacroCalls(true) .subexpressionMaxRecursionDepth(2) @@ -228,7 +281,7 @@ public void large_expressions_block_recursion_depth_2() throws Exception { public void large_expressions_block_recursion_depth_3() throws Exception { CelOptimizer celOptimizer = newCseOptimizer( - CEL, + runtimeEnv.cel, SubexpressionOptimizerOptions.newBuilder() .populateMacroCalls(true) .subexpressionMaxRecursionDepth(3) @@ -238,15 +291,16 @@ public void large_expressions_block_recursion_depth_3() throws Exception { } private void runLargeTestCases(CelOptimizer celOptimizer) throws Exception { - for (CseLargeTestCase cseTestCase : CseLargeTestCase.values()) { + for (CseLargeTestCase cseTestCase : EnumSet.allOf(CseLargeTestCase.class)) { testOutput().println("Test case: " + cseTestCase.name()); testOutput().println("Source: " + cseTestCase.source); testOutput().println("=====>"); - CelAbstractSyntaxTree ast = CEL.compile(cseTestCase.source).getAst(); - + CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(cseTestCase.source).getAst(); CelAbstractSyntaxTree optimizedAst = celOptimizer.optimize(ast); Object optimizedEvalResult = - CEL.createProgram(optimizedAst) + runtimeEnv + .cel + .createProgram(optimizedAst) .eval( ImmutableMap.of("msg", TEST_ALL_TYPES_INPUT, "x", 5L, "opt_x", Optional.of(5L))); testOutput().println("Result: " + optimizedEvalResult); @@ -260,33 +314,6 @@ private void runLargeTestCases(CelOptimizer celOptimizer) throws Exception { } } - private static CelBuilder newCelBuilder() { - return CelFactory.standardCelBuilder() - .addMessageTypes(TestAllTypes.getDescriptor()) - .setContainer(CelContainer.ofName("cel.expr.conformance.proto3")) - .setStandardMacros(CelStandardMacro.STANDARD_MACROS) - .setOptions(CelOptions.current().populateMacroCalls(true).build()) - .addCompilerLibraries( - CelExtensions.optional(), CelExtensions.bindings(), CelExtensions.comprehensions()) - .addRuntimeLibraries(CelExtensions.optional(), CelExtensions.comprehensions()) - .addFunctionDeclarations( - CelFunctionDecl.newFunctionDeclaration( - "pure_custom_func", - newGlobalOverload("pure_custom_func_overload", SimpleType.INT, SimpleType.INT)), - CelFunctionDecl.newFunctionDeclaration( - "non_pure_custom_func", - newGlobalOverload("non_pure_custom_func_overload", SimpleType.INT, SimpleType.INT))) - .addFunctionBindings( - // This is pure, but for the purposes of excluding it as a CSE candidate, pretend that - // it isn't. - CelFunctionBinding.from("non_pure_custom_func_overload", Long.class, val -> val), - CelFunctionBinding.from("pure_custom_func_overload", Long.class, val -> val)) - .addVar("x", SimpleType.DYN) - .addVar("y", SimpleType.DYN) - .addVar("opt_x", OptionalType.create(SimpleType.DYN)) - .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())); - } - private static CelOptimizer newCseOptimizer(Cel cel, SubexpressionOptimizerOptions options) { return CelOptimizerFactory.standardCelOptimizerBuilder(cel) .addAstOptimizers(SubexpressionOptimizer.newInstance(options)) @@ -315,17 +342,23 @@ private enum CseTestOptimizer { BLOCK_RECURSION_DEPTH_9( OPTIMIZER_COMMON_OPTIONS.toBuilder().subexpressionMaxRecursionDepth(9).build()); - private final CelOptimizer cseOptimizer; - private final CelOptimizer cseWithConstFoldingOptimizer; + private final SubexpressionOptimizerOptions option; CseTestOptimizer(SubexpressionOptimizerOptions option) { - this.cseOptimizer = newCseOptimizer(CEL, option); - this.cseWithConstFoldingOptimizer = - CelOptimizerFactory.standardCelOptimizerBuilder(CEL) - .addAstOptimizers( - ConstantFoldingOptimizer.getInstance(), - SubexpressionOptimizer.newInstance(option)) - .build(); + this.option = option; + } + + // Defers building the optimizer until the test runs + private CelOptimizer newCseOptimizer(RuntimeEnv env) { + return SubexpressionOptimizerBaselineTest.newCseOptimizer(env.cel, option); + } + + // Defers building the optimizer until the test runs + private CelOptimizer newCseWithConstFoldingOptimizer(RuntimeEnv env) { + return CelOptimizerFactory.standardCelOptimizerBuilder(env.cel) + .addAstOptimizers( + ConstantFoldingOptimizer.getInstance(), SubexpressionOptimizer.newInstance(option)) + .build(); } } diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java index 2289a7d4a..6e39bab28 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java @@ -26,6 +26,7 @@ import com.google.testing.junit.testparameterinjector.TestParameters; import dev.cel.bundle.Cel; import dev.cel.bundle.CelBuilder; +import dev.cel.bundle.CelExperimentalFactory; import dev.cel.bundle.CelFactory; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelFunctionDecl; @@ -52,10 +53,14 @@ import dev.cel.parser.CelStandardMacro; import dev.cel.parser.CelUnparser; import dev.cel.parser.CelUnparserFactory; +import dev.cel.runtime.CelAttributePattern; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.CelFunctionBinding; import dev.cel.runtime.CelRuntime; import dev.cel.runtime.CelRuntimeFactory; +import dev.cel.runtime.CelUnknownSet; +import dev.cel.runtime.PartialVars; +import dev.cel.runtime.Program; import java.util.List; import java.util.concurrent.atomic.AtomicInteger; import org.junit.Test; @@ -64,10 +69,40 @@ @RunWith(TestParameterInjector.class) public class SubexpressionOptimizerTest { - private static final Cel CEL = newCelBuilder().build(); + private enum RuntimeEnv { + LEGACY( + setupCelEnv(CelFactory.standardCelBuilder()), + setupCelForEvaluatingBlock(CelFactory.standardCelBuilder())), + PLANNER( + setupCelEnv(CelExperimentalFactory.plannerCelBuilder()), + setupCelForEvaluatingBlock(CelExperimentalFactory.plannerCelBuilder())); - private static final Cel CEL_FOR_EVALUATING_BLOCK = - CelFactory.standardCelBuilder() + private final Cel cel; + private final Cel celForEvaluatingBlock; + + private static Cel setupCelEnv(CelBuilder celBuilder) { + return celBuilder + .addMessageTypes(TestAllTypes.getDescriptor()) + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .setOptions( + CelOptions.current() + .populateMacroCalls(true) + .enableHeterogeneousNumericComparisons(true) + .build()) + .addCompilerLibraries(CelExtensions.bindings(), CelExtensions.strings()) + .addRuntimeLibraries(CelExtensions.strings()) + .addFunctionDeclarations( + CelFunctionDecl.newFunctionDeclaration( + "non_pure_custom_func", + newGlobalOverload( + "non_pure_custom_func_overload", SimpleType.INT, SimpleType.INT))) + .addVar("x", SimpleType.DYN) + .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())) + .build(); + } + + private static Cel setupCelForEvaluatingBlock(CelBuilder celBuilder) { + return celBuilder .setStandardMacros(CelStandardMacro.STANDARD_MACROS) .addFunctionDeclarations( // These are test only declarations, as the actual function is made internal using @ @@ -98,6 +133,15 @@ public class SubexpressionOptimizerTest { .addMessageTypes(TestAllTypes.getDescriptor()) .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())) .build(); + } + + RuntimeEnv(Cel cel, Cel celForEvaluatingBlock) { + this.cel = cel; + this.celForEvaluatingBlock = celForEvaluatingBlock; + } + } + + @TestParameter RuntimeEnv runtimeEnv; private static final CelUnparser CEL_UNPARSER = CelUnparserFactory.newUnparser(); @@ -115,8 +159,8 @@ private static CelBuilder newCelBuilder() { .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())); } - private static CelOptimizer newCseOptimizer(SubexpressionOptimizerOptions options) { - return CelOptimizerFactory.standardCelOptimizerBuilder(CEL) + private CelOptimizer newCseOptimizer(SubexpressionOptimizerOptions options) { + return CelOptimizerFactory.standardCelOptimizerBuilder(runtimeEnv.cel) .addAstOptimizers(SubexpressionOptimizer.newInstance(options)) .build(); } @@ -130,15 +174,56 @@ public void cse_resultTypeSet_celBlockOptimizationSuccess() throws Exception { SubexpressionOptimizer.newInstance( SubexpressionOptimizerOptions.newBuilder().build())) .build(); - CelAbstractSyntaxTree ast = CEL.compile("size('a') + size('a') == 2").getAst(); + CelAbstractSyntaxTree ast = runtimeEnv.cel.compile("size('a') + size('a') == 2").getAst(); CelAbstractSyntaxTree optimizedAst = celOptimizer.optimize(ast); - assertThat(CEL.createProgram(optimizedAst).eval()).isEqualTo(true); + assertThat(runtimeEnv.cel.createProgram(optimizedAst).eval()).isEqualTo(true); assertThat(CEL_UNPARSER.unparse(optimizedAst)) .isEqualTo("cel.@block([size(\"a\")], @index0 + @index0 == 2)"); } + @Test + public void cse_indexEvaluationErrors_throws() throws Exception { + CelAbstractSyntaxTree ast = + runtimeEnv.cel.compile("\"abc\".charAt(10) + \"abc\".charAt(10)").getAst(); + CelOptimizer optimizedOptimizer = + CelOptimizerFactory.standardCelOptimizerBuilder(runtimeEnv.cel) + .addAstOptimizers(SubexpressionOptimizer.getInstance()) + .build(); + + CelAbstractSyntaxTree optimizedAst = optimizedOptimizer.optimize(ast); + + String unparsed = CEL_UNPARSER.unparse(optimizedAst); + assertThat(unparsed).isEqualTo("cel.@block([\"abc\".charAt(10)], @index0 + @index0)"); + + Program program = runtimeEnv.cel.createProgram(optimizedAst); + CelEvaluationException e = + assertThrows(CelEvaluationException.class, () -> program.eval(ImmutableMap.of())); + assertThat(e).hasMessageThat().contains("charAt failure: Index out of range: 10"); + } + + @Test + public void cse_withUnknownAttributes() throws Exception { + CelAbstractSyntaxTree ast = runtimeEnv.cel.compile("size(\"a\") == 1 ? x.y : x.y").getAst(); + CelOptimizer optimizer = + CelOptimizerFactory.standardCelOptimizerBuilder(runtimeEnv.cel) + .addAstOptimizers(SubexpressionOptimizer.getInstance()) + .build(); + + CelAbstractSyntaxTree optimizedAst = optimizer.optimize(ast); + + assertThat(CEL_UNPARSER.unparse(optimizedAst)) + .isEqualTo("cel.@block([x.y], (size(\"a\") == 1) ? @index0 : @index0)"); + + Object result = + runtimeEnv + .cel + .createProgram(optimizedAst) + .eval(PartialVars.of(CelAttributePattern.fromQualifiedIdentifier("x"))); + assertThat(result).isInstanceOf(CelUnknownSet.class); + } + private enum CseNoOpTestCase { // Nothing to optimize NO_COMMON_SUBEXPR("size(\"hello\")"), @@ -169,7 +254,7 @@ private enum CseNoOpTestCase { @Test public void cse_withCelBind_noop(@TestParameter CseNoOpTestCase testCase) throws Exception { - CelAbstractSyntaxTree ast = CEL.compile(testCase.source).getAst(); + CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(testCase.source).getAst(); CelAbstractSyntaxTree optimizedAst = newCseOptimizer(SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build()) @@ -181,7 +266,7 @@ public void cse_withCelBind_noop(@TestParameter CseNoOpTestCase testCase) throws @Test public void cse_withCelBlock_noop(@TestParameter CseNoOpTestCase testCase) throws Exception { - CelAbstractSyntaxTree ast = CEL.compile(testCase.source).getAst(); + CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(testCase.source).getAst(); CelAbstractSyntaxTree optimizedAst = newCseOptimizer(SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build()) @@ -194,7 +279,7 @@ public void cse_withCelBlock_noop(@TestParameter CseNoOpTestCase testCase) throw @Test public void cse_withComprehensionStructureRetained() throws Exception { CelAbstractSyntaxTree ast = - CEL.compile("['foo'].map(x, [x+x]) + ['foo'].map(x, [x+x, x+x])").getAst(); + runtimeEnv.cel.compile("['foo'].map(x, [x+x]) + ['foo'].map(x, [x+x, x+x])").getAst(); CelOptimizer celOptimizer = newCseOptimizer( SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build()); @@ -210,10 +295,12 @@ public void cse_withComprehensionStructureRetained() throws Exception { @Test public void cse_applyConstFoldingBefore() throws Exception { CelAbstractSyntaxTree ast = - CEL.compile("size([1+1+1]) + size([1+1+1]) + size([1,1+1+1]) + size([1,1+1+1]) + x") + runtimeEnv + .cel + .compile("size([1+1+1]) + size([1+1+1]) + size([1,1+1+1]) + size([1,1+1+1]) + x") .getAst(); CelOptimizer optimizer = - CelOptimizerFactory.standardCelOptimizerBuilder(CEL) + CelOptimizerFactory.standardCelOptimizerBuilder(runtimeEnv.cel) .addAstOptimizers( ConstantFoldingOptimizer.getInstance(), SubexpressionOptimizer.newInstance( @@ -228,10 +315,12 @@ public void cse_applyConstFoldingBefore() throws Exception { @Test public void cse_applyConstFoldingAfter() throws Exception { CelAbstractSyntaxTree ast = - CEL.compile("size([1+1+1]) + size([1+1+1]) + size([1,1+1+1]) + size([1,1+1+1]) + x") + runtimeEnv + .cel + .compile("size([1+1+1]) + size([1+1+1]) + size([1,1+1+1]) + size([1,1+1+1]) + x") .getAst(); CelOptimizer optimizer = - CelOptimizerFactory.standardCelOptimizerBuilder(CEL) + CelOptimizerFactory.standardCelOptimizerBuilder(runtimeEnv.cel) .addAstOptimizers( SubexpressionOptimizer.newInstance( SubexpressionOptimizerOptions.newBuilder().build()), @@ -246,9 +335,9 @@ public void cse_applyConstFoldingAfter() throws Exception { @Test public void cse_applyConstFoldingAfter_nothingToFold() throws Exception { - CelAbstractSyntaxTree ast = CEL.compile("size(x) + size(x)").getAst(); + CelAbstractSyntaxTree ast = runtimeEnv.cel.compile("size(x) + size(x)").getAst(); CelOptimizer optimizer = - CelOptimizerFactory.standardCelOptimizerBuilder(CEL) + CelOptimizerFactory.standardCelOptimizerBuilder(runtimeEnv.cel) .addAstOptimizers( SubexpressionOptimizer.newInstance( SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build()), @@ -271,7 +360,7 @@ public void iterationLimitReached_throws() throws Exception { largeExprBuilder.append("+"); } } - CelAbstractSyntaxTree ast = CEL.compile(largeExprBuilder.toString()).getAst(); + CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(largeExprBuilder.toString()).getAst(); CelOptimizationException e = assertThrows( @@ -287,9 +376,9 @@ public void iterationLimitReached_throws() throws Exception { @Test public void celBlock_astExtensionTagged() throws Exception { - CelAbstractSyntaxTree ast = CEL.compile("size(x) + size(x)").getAst(); + CelAbstractSyntaxTree ast = runtimeEnv.cel.compile("size(x) + size(x)").getAst(); CelOptimizer optimizer = - CelOptimizerFactory.standardCelOptimizerBuilder(CEL) + CelOptimizerFactory.standardCelOptimizerBuilder(runtimeEnv.cel) .addAstOptimizers( SubexpressionOptimizer.newInstance( SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build()), @@ -322,7 +411,7 @@ private enum BlockTestCase { public void block_success(@TestParameter BlockTestCase testCase) throws Exception { CelAbstractSyntaxTree ast = compileUsingInternalFunctions(testCase.source); - Object evaluatedResult = CEL_FOR_EVALUATING_BLOCK.createProgram(ast).eval(); + Object evaluatedResult = runtimeEnv.celForEvaluatingBlock.createProgram(ast).eval(); assertThat(evaluatedResult).isNotNull(); } @@ -584,7 +673,7 @@ public void block_containsCycle_throws() throws Exception { CelAbstractSyntaxTree ast = compileUsingInternalFunctions("cel.block([index1,index0],index0)"); CelEvaluationException e = - assertThrows(CelEvaluationException.class, () -> CEL.createProgram(ast).eval()); + assertThrows(CelEvaluationException.class, () -> runtimeEnv.cel.createProgram(ast).eval()); assertThat(e).hasMessageThat().contains("Cycle detected: @index0"); } @@ -595,7 +684,7 @@ public void block_lazyEvaluationContainsError_cleansUpCycleState() throws Except "cel.block([1/0 > 0], (index0 && false) || (index0 && true))"); CelEvaluationException e = - assertThrows(CelEvaluationException.class, () -> CEL.createProgram(ast).eval()); + assertThrows(CelEvaluationException.class, () -> runtimeEnv.cel.createProgram(ast).eval()); assertThat(e).hasMessageThat().contains("/ by zero"); assertThat(e).hasMessageThat().doesNotContain("Cycle detected"); @@ -605,9 +694,10 @@ public void block_lazyEvaluationContainsError_cleansUpCycleState() throws Except * Converts AST containing cel.block related test functions to internal functions (e.g: cel.block * -> cel.@block) */ - private static CelAbstractSyntaxTree compileUsingInternalFunctions(String expression) + private CelAbstractSyntaxTree compileUsingInternalFunctions(String expression) throws CelValidationException { - CelAbstractSyntaxTree astToModify = CEL_FOR_EVALUATING_BLOCK.compile(expression).getAst(); + CelAbstractSyntaxTree astToModify = + runtimeEnv.celForEvaluatingBlock.compile(expression).getAst(); CelMutableAst mutableAst = CelMutableAst.fromCelAst(astToModify); CelNavigableMutableAst.fromAst(mutableAst) .getRoot() @@ -629,6 +719,6 @@ private static CelAbstractSyntaxTree compileUsingInternalFunctions(String expres indexExpr.ident().setName(internalIdentName); }); - return CEL_FOR_EVALUATING_BLOCK.check(mutableAst.toParsedAst()).getAst(); + return runtimeEnv.celForEvaluatingBlock.check(mutableAst.toParsedAst()).getAst(); } } diff --git a/optimizer/src/test/resources/constfold_before_subexpression_unparsed.baseline b/optimizer/src/test/resources/constfold_before_subexpression_unparsed.baseline index 55da856cd..9139c7a35 100644 --- a/optimizer/src/test/resources/constfold_before_subexpression_unparsed.baseline +++ b/optimizer/src/test/resources/constfold_before_subexpression_unparsed.baseline @@ -526,7 +526,7 @@ Result: [[[foofoo, foofoo, foofoo, foofoo], [foofoo, foofoo, foofoo, foofoo]], [ Test case: MACRO_SHADOWED_VARIABLE_COMP_V2_1 Source: [x - y - 1 > 3 ? x - y - 1 : 5].exists(x, y, x - y - 1 > 3) || x - y - 1 > 3 =====> -Result: CelUnknownSet{attributes=[], unknownExprIds=[6]} +Result: false [BLOCK_COMMON_SUBEXPR_ONLY]: cel.@block([x - y - 1, @index0 > 3], [@index1 ? @index0 : 5].exists(@it:0:0, @it2:0:0, @it:0:0 - @it2:0:0 - 1 > 3) || @index1) [BLOCK_RECURSION_DEPTH_1]: cel.@block([x - y, @index0 - 1, @index1 > 3, @index2 ? @index1 : 5, [@index3]], @index4.exists(@it:0:0, @it2:0:0, @it:0:0 - @it2:0:0 - 1 > 3) || @index2) [BLOCK_RECURSION_DEPTH_2]: cel.@block([x - y - 1, @index0 > 3, [@index1 ? @index0 : 5]], @index2.exists(@it:0:0, @it2:0:0, @it:0:0 - @it2:0:0 - 1 > 3) || @index1) diff --git a/optimizer/src/test/resources/subexpression_unparsed.baseline b/optimizer/src/test/resources/subexpression_unparsed.baseline index e0edc8987..780664a14 100644 --- a/optimizer/src/test/resources/subexpression_unparsed.baseline +++ b/optimizer/src/test/resources/subexpression_unparsed.baseline @@ -526,7 +526,7 @@ Result: [[[foofoo, foofoo, foofoo, foofoo], [foofoo, foofoo, foofoo, foofoo]], [ Test case: MACRO_SHADOWED_VARIABLE_COMP_V2_1 Source: [x - y - 1 > 3 ? x - y - 1 : 5].exists(x, y, x - y - 1 > 3) || x - y - 1 > 3 =====> -Result: CelUnknownSet{attributes=[], unknownExprIds=[6]} +Result: false [BLOCK_COMMON_SUBEXPR_ONLY]: cel.@block([x - y - 1, @index0 > 3], [@index1 ? @index0 : 5].exists(@it:0:0, @it2:0:0, @it:0:0 - @it2:0:0 - 1 > 3) || @index1) [BLOCK_RECURSION_DEPTH_1]: cel.@block([x - y, @index0 - 1, @index1 > 3, @index2 ? @index1 : 5, [@index3]], @index4.exists(@it:0:0, @it2:0:0, @it:0:0 - @it2:0:0 - 1 > 3) || @index2) [BLOCK_RECURSION_DEPTH_2]: cel.@block([x - y - 1, @index0 > 3, [@index1 ? @index0 : 5]], @index2.exists(@it:0:0, @it2:0:0, @it:0:0 - @it2:0:0 - 1 > 3) || @index1) diff --git a/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel b/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel index fc70118e4..cb2ad5a82 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel +++ b/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel @@ -18,6 +18,7 @@ java_library( ":eval_and", ":eval_attribute", ":eval_binary", + ":eval_block", ":eval_conditional", ":eval_const", ":eval_create_list", @@ -67,7 +68,6 @@ java_library( srcs = ["PlannedProgram.java"], deps = [ ":error_metadata", - ":execution_frame", ":localized_evaluation_exception", ":planned_interpretable", "//:auto_value", @@ -92,11 +92,9 @@ java_library( name = "eval_const", srcs = ["EvalConstant.java"], deps = [ - ":execution_frame", ":planned_interpretable", "//runtime:interpretable", "@maven//:com_google_errorprone_error_prone_annotations", - "@maven//:com_google_guava_guava", ], ) @@ -123,7 +121,6 @@ java_library( deps = [ ":activation_wrapper", ":eval_helpers", - ":execution_frame", ":planned_interpretable", ":qualifier", "//common:container", @@ -183,8 +180,8 @@ java_library( srcs = ["EvalAttribute.java"], deps = [ ":attribute", - ":execution_frame", ":interpretable_attribute", + ":planned_interpretable", ":qualifier", "//runtime:interpretable", "@maven//:com_google_errorprone_error_prone_annotations", @@ -195,8 +192,8 @@ java_library( name = "eval_test_only", srcs = ["EvalTestOnly.java"], deps = [ - ":execution_frame", ":interpretable_attribute", + ":planned_interpretable", ":presence_test_qualifier", ":qualifier", "//runtime:evaluation_exception", @@ -210,7 +207,6 @@ java_library( srcs = ["EvalZeroArity.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//common/values", "//runtime:evaluation_exception", @@ -224,7 +220,6 @@ java_library( srcs = ["EvalUnary.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//common/values", "//runtime:evaluation_exception", @@ -238,7 +233,6 @@ java_library( srcs = ["EvalBinary.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//common/values", "//runtime:accumulated_unknowns", @@ -253,7 +247,6 @@ java_library( srcs = ["EvalVarArgsCall.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//common/values", "//runtime:accumulated_unknowns", @@ -268,7 +261,6 @@ java_library( srcs = ["EvalLateBoundCall.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//common/exceptions:overload_not_found", "//common/values", @@ -285,7 +277,6 @@ java_library( srcs = ["EvalOr.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//common/values", "//runtime:accumulated_unknowns", @@ -299,7 +290,6 @@ java_library( srcs = ["EvalAnd.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//common/values", "//runtime:accumulated_unknowns", @@ -312,7 +302,6 @@ java_library( name = "eval_conditional", srcs = ["EvalConditional.java"], deps = [ - ":execution_frame", ":planned_interpretable", "//runtime:accumulated_unknowns", "//runtime:evaluation_exception", @@ -326,7 +315,6 @@ java_library( srcs = ["EvalCreateStruct.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//common/types:type_providers", "//common/values", @@ -344,7 +332,6 @@ java_library( srcs = ["EvalCreateList.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//runtime:accumulated_unknowns", "//runtime:evaluation_exception", @@ -359,7 +346,6 @@ java_library( srcs = ["EvalCreateMap.java"], deps = [ ":eval_helpers", - ":execution_frame", ":localized_evaluation_exception", ":planned_interpretable", "//common/exceptions:duplicate_key", @@ -377,7 +363,6 @@ java_library( srcs = ["EvalFold.java"], deps = [ ":activation_wrapper", - ":execution_frame", ":planned_interpretable", "//runtime:accumulated_unknowns", "//runtime:concatenated_list_view", @@ -389,24 +374,10 @@ java_library( ], ) -java_library( - name = "execution_frame", - srcs = ["ExecutionFrame.java"], - deps = [ - "//common:options", - "//common/exceptions:iteration_budget_exceeded", - "//runtime:evaluation_exception", - "//runtime:function_resolver", - "//runtime:partial_vars", - "//runtime:resolved_overload", - ], -) - java_library( name = "eval_helpers", srcs = ["EvalHelpers.java"], deps = [ - ":execution_frame", ":localized_evaluation_exception", ":planned_interpretable", "//common:error_codes", @@ -440,11 +411,20 @@ java_library( java_library( name = "planned_interpretable", - srcs = ["PlannedInterpretable.java"], + srcs = [ + "BlockMemoizer.java", + "ExecutionFrame.java", + "PlannedInterpretable.java", + ], deps = [ - ":execution_frame", + ":localized_evaluation_exception", + "//common:options", + "//common/exceptions:iteration_budget_exceeded", "//runtime:evaluation_exception", + "//runtime:function_resolver", "//runtime:interpretable", + "//runtime:partial_vars", + "//runtime:resolved_overload", "@maven//:com_google_errorprone_error_prone_annotations", ], ) @@ -454,7 +434,6 @@ java_library( srcs = ["EvalOptionalOr.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//common/exceptions:overload_not_found", "//runtime:accumulated_unknowns", @@ -469,7 +448,6 @@ java_library( srcs = ["EvalOptionalOrValue.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//common/exceptions:overload_not_found", "//runtime:accumulated_unknowns", @@ -484,7 +462,6 @@ java_library( srcs = ["EvalOptionalSelectField.java"], deps = [ ":eval_helpers", - ":execution_frame", ":planned_interpretable", "//common/values", "//runtime:accumulated_unknowns", @@ -493,3 +470,14 @@ java_library( "@maven//:com_google_guava_guava", ], ) + +java_library( + name = "eval_block", + srcs = ["EvalBlock.java"], + deps = [ + ":planned_interpretable", + "//runtime:evaluation_exception", + "//runtime:interpretable", + "@maven//:com_google_errorprone_error_prone_annotations", + ], +) diff --git a/runtime/src/main/java/dev/cel/runtime/planner/BlockMemoizer.java b/runtime/src/main/java/dev/cel/runtime/planner/BlockMemoizer.java new file mode 100644 index 000000000..978029b3d --- /dev/null +++ b/runtime/src/main/java/dev/cel/runtime/planner/BlockMemoizer.java @@ -0,0 +1,72 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.runtime.planner; + +import dev.cel.runtime.CelEvaluationException; +import dev.cel.runtime.GlobalResolver; +import java.util.Arrays; + +/** Handles memoization, lazy evaluation, and cycle detection for cel.@block slots. */ +final class BlockMemoizer { + + private static final Object IN_PROGRESS = new Object(); + private static final Object UNSET = new Object(); + + private final PlannedInterpretable[] slotExprs; + private final Object[] slotVals; + private final ExecutionFrame frame; + + static BlockMemoizer create(PlannedInterpretable[] slotExprs, ExecutionFrame frame) { + return new BlockMemoizer(slotExprs, frame); + } + + private BlockMemoizer(PlannedInterpretable[] slotExprs, ExecutionFrame frame) { + this.slotExprs = slotExprs; + this.frame = frame; + this.slotVals = new Object[slotExprs.length]; + Arrays.fill(this.slotVals, UNSET); + } + + Object resolveSlot(int idx, GlobalResolver resolver) { + Object val = slotVals[idx]; + + // Already evaluated + if (val != UNSET && val != IN_PROGRESS) { + if (val instanceof RuntimeException) { + throw (RuntimeException) val; + } + return val; + } + + if (val == IN_PROGRESS) { + throw new IllegalStateException("Cycle detected: @index" + idx); + } + + slotVals[idx] = IN_PROGRESS; + try { + Object result = slotExprs[idx].eval(resolver, frame); + slotVals[idx] = result; + return result; + } catch (CelEvaluationException e) { + LocalizedEvaluationException localizedException = + new LocalizedEvaluationException(e, e.getErrorCode(), slotExprs[idx].exprId()); + slotVals[idx] = localizedException; + throw localizedException; + } catch (RuntimeException e) { + slotVals[idx] = e; + throw e; + } + } +} diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalBlock.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalBlock.java new file mode 100644 index 000000000..41ad4034e --- /dev/null +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalBlock.java @@ -0,0 +1,67 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.runtime.planner; + +import com.google.errorprone.annotations.Immutable; +import dev.cel.runtime.CelEvaluationException; +import dev.cel.runtime.GlobalResolver; + +/** Eval implementation of {@code cel.@block}. */ +@Immutable +final class EvalBlock extends PlannedInterpretable { + + @SuppressWarnings("Immutable") // Array not mutated after creation + private final PlannedInterpretable[] slotExprs; + + private final PlannedInterpretable resultExpr; + + static EvalBlock create( + long exprId, PlannedInterpretable[] slotExprs, PlannedInterpretable resultExpr) { + return new EvalBlock(exprId, slotExprs, resultExpr); + } + + private EvalBlock( + long exprId, PlannedInterpretable[] slotExprs, PlannedInterpretable resultExpr) { + super(exprId); + this.slotExprs = slotExprs; + this.resultExpr = resultExpr; + } + + @Override + public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { + BlockMemoizer memoizer = BlockMemoizer.create(slotExprs, frame); + frame.setBlockMemoizer(memoizer); + return resultExpr.eval(resolver, frame); + } + + @Immutable + static final class EvalBlockSlot extends PlannedInterpretable { + private final int slotIndex; + + static EvalBlockSlot create(long exprId, int slotIndex) { + return new EvalBlockSlot(exprId, slotIndex); + } + + private EvalBlockSlot(long exprId, int slotIndex) { + super(exprId); + this.slotIndex = slotIndex; + } + + @Override + public Object eval(GlobalResolver resolver, ExecutionFrame frame) { + return frame.getBlockMemoizer().resolveSlot(slotIndex, resolver); + } + } +} diff --git a/runtime/src/main/java/dev/cel/runtime/planner/ExecutionFrame.java b/runtime/src/main/java/dev/cel/runtime/planner/ExecutionFrame.java index e29c68dd8..282b7c83a 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/ExecutionFrame.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/ExecutionFrame.java @@ -30,6 +30,7 @@ final class ExecutionFrame { private final CelFunctionResolver functionResolver; private final PartialVars partialVars; private int iterationCount; + private BlockMemoizer blockMemoizer; Optional findOverload( String functionName, Collection overloadIds, Object[] args) @@ -49,6 +50,17 @@ void incrementIterations() { } } + void setBlockMemoizer(BlockMemoizer blockMemoizer) { + if (this.blockMemoizer != null) { + throw new IllegalStateException("BlockMemoizer is already initialized"); + } + this.blockMemoizer = blockMemoizer; + } + + BlockMemoizer getBlockMemoizer() { + return blockMemoizer; + } + static ExecutionFrame create( CelFunctionResolver functionResolver, PartialVars partialVars, CelOptions celOptions) { return new ExecutionFrame( diff --git a/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java b/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java index add918f64..9bd5f3ecd 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java @@ -164,6 +164,11 @@ private PlannedInterpretable planIdent(CelExpr celExpr, PlannerContext ctx) { } String identName = celExpr.ident().name(); + PlannedInterpretable blockSlot = maybeInterceptBlockSlot(celExpr.id(), identName).orElse(null); + if (blockSlot != null) { + return blockSlot; + } + if (ctx.isLocalVar(identName)) { return EvalAttribute.create(celExpr.id(), attributeFactory.newAbsoluteAttribute(identName)); } @@ -196,11 +201,42 @@ private PlannedInterpretable planCheckedIdent( return EvalConstant.create(id, identType); } + String identName = identRef.name(); + PlannedInterpretable blockSlot = maybeInterceptBlockSlot(id, identName).orElse(null); + if (blockSlot != null) { + return blockSlot; + } + return EvalAttribute.create(id, attributeFactory.newAbsoluteAttribute(identRef.name())); } + private Optional maybeInterceptBlockSlot(long id, String identName) { + if (!identName.startsWith("@index")) { + return Optional.empty(); + } + if (identName.length() <= 6) { + throw new IllegalArgumentException("Malformed block slot identifier: " + identName); + } + try { + int slotIndex = Integer.parseInt(identName.substring(6)); + if (slotIndex < 0) { + throw new IllegalArgumentException("Negative block slot index: " + identName); + } + return Optional.of(EvalBlock.EvalBlockSlot.create(id, slotIndex)); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid block slot index: " + identName, e); + } + } + private PlannedInterpretable planCall(CelExpr expr, PlannerContext ctx) { ResolvedFunction resolvedFunction = resolveFunction(expr, ctx.referenceMap()); + String functionName = resolvedFunction.functionName(); + + PlannedInterpretable blockCall = maybeInterceptBlockCall(functionName, expr, ctx).orElse(null); + if (blockCall != null) { + return blockCall; + } + CelExpr target = resolvedFunction.target().orElse(null); int argCount = expr.call().args().size(); if (target != null) { @@ -220,7 +256,6 @@ private PlannedInterpretable planCall(CelExpr expr, PlannerContext ctx) { evaluatedArgs[argIndex + offset] = plan(args.get(argIndex), ctx); } - String functionName = resolvedFunction.functionName(); Operator operator = Operator.findReverse(functionName).orElse(null); if (operator != null) { switch (operator) { @@ -285,6 +320,28 @@ private PlannedInterpretable planCall(CelExpr expr, PlannerContext ctx) { } } + private Optional maybeInterceptBlockCall( + String functionName, CelExpr expr, PlannerContext ctx) { + if (!functionName.equals("cel.@block")) { + return Optional.empty(); + } + + CelCall blockCall = expr.call(); + + if (blockCall.args().size() != 2) { + throw new IllegalArgumentException( + "Expected 2 arguments for cel.@block call. Got: " + blockCall.args().size()); + } + + CelList exprList = blockCall.args().get(0).list(); + PlannedInterpretable[] slotExprs = new PlannedInterpretable[exprList.elements().size()]; + for (int i = 0; i < slotExprs.length; i++) { + slotExprs[i] = plan(exprList.elements().get(i), ctx); + } + PlannedInterpretable resultExpr = plan(blockCall.args().get(1), ctx); + return Optional.of(EvalBlock.create(expr.id(), slotExprs, resultExpr)); + } + /** * Intercepts a potential optional function call. * From 288c3b935d7c1a3df47921a73b3605e347649b02 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Mon, 13 Apr 2026 17:41:17 -0700 Subject: [PATCH 31/66] Add cel.@block test coverage for parsed-only mode PiperOrigin-RevId: 899267021 --- .../CelComprehensionsExtensions.java | 35 +++++++------------ .../SubexpressionOptimizerBaselineTest.java | 31 ++++++++++++++-- .../SubexpressionOptimizerTest.java | 23 +++++++++++- 3 files changed, 63 insertions(+), 26 deletions(-) diff --git a/extensions/src/main/java/dev/cel/extensions/CelComprehensionsExtensions.java b/extensions/src/main/java/dev/cel/extensions/CelComprehensionsExtensions.java index 23663f02e..7c298a773 100644 --- a/extensions/src/main/java/dev/cel/extensions/CelComprehensionsExtensions.java +++ b/extensions/src/main/java/dev/cel/extensions/CelComprehensionsExtensions.java @@ -118,29 +118,18 @@ public void setRuntimeOptions(CelRuntimeBuilder runtimeBuilder) { @Override public void setRuntimeOptions( CelRuntimeBuilder runtimeBuilder, RuntimeEquality runtimeEquality, CelOptions celOptions) { - for (Function function : functions) { - for (CelOverloadDecl overload : function.functionDecl.overloads()) { - switch (overload.overloadId()) { - case MAP_INSERT_OVERLOAD_MAP_MAP: - runtimeBuilder.addFunctionBindings( - CelFunctionBinding.from( - MAP_INSERT_OVERLOAD_MAP_MAP, - Map.class, - Map.class, - (map1, map2) -> mapInsertMap(map1, map2, runtimeEquality))); - break; - case MAP_INSERT_OVERLOAD_KEY_VALUE: - runtimeBuilder.addFunctionBindings( - CelFunctionBinding.from( - MAP_INSERT_OVERLOAD_KEY_VALUE, - ImmutableList.of(Map.class, Object.class, Object.class), - args -> mapInsertKeyValue(args, runtimeEquality))); - break; - default: - // Nothing to add. - } - } - } + runtimeBuilder.addFunctionBindings( + CelFunctionBinding.fromOverloads( + MAP_INSERT_FUNCTION, + CelFunctionBinding.from( + MAP_INSERT_OVERLOAD_MAP_MAP, + Map.class, + Map.class, + (map1, map2) -> mapInsertMap(map1, map2, runtimeEquality)), + CelFunctionBinding.from( + MAP_INSERT_OVERLOAD_KEY_VALUE, + ImmutableList.of(Map.class, Object.class, Object.class), + args -> mapInsertKeyValue(args, runtimeEquality)))); } @Override diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerBaselineTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerBaselineTest.java index 74e3b5b32..9db04ceac 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerBaselineTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerBaselineTest.java @@ -83,8 +83,13 @@ private static Cel setupCelEnv(CelBuilder celBuilder) { .addFunctionBindings( // This is pure, but for the purposes of excluding it as a CSE candidate, pretend that // it isn't. - CelFunctionBinding.from("non_pure_custom_func_overload", Long.class, val -> val), - CelFunctionBinding.from("pure_custom_func_overload", Long.class, val -> val)) + CelFunctionBinding.fromOverloads( + "non_pure_custom_func", + CelFunctionBinding.from("non_pure_custom_func_overload", Long.class, val -> val))) + .addFunctionBindings( + CelFunctionBinding.fromOverloads( + "pure_custom_func", + CelFunctionBinding.from("pure_custom_func_overload", Long.class, val -> val))) .addVar("x", SimpleType.DYN) .addVar("y", SimpleType.DYN) .addVar("opt_x", OptionalType.create(SimpleType.DYN)) @@ -153,6 +158,28 @@ public void allOptimizers_producesSameEvaluationResult( assertThat(optimizedEvalResult).isEqualTo(expectedEvalResult); } + @Test + public void allOptimizers_producesSameEvaluationResult_parsedOnly( + @TestParameter CseTestCase cseTestCase, @TestParameter CseTestOptimizer cseTestOptimizer) + throws Exception { + skipBaselineVerification(); + if (runtimeEnv == RuntimeEnv.LEGACY) { + return; + } + CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(cseTestCase.source).getAst(); + ImmutableMap inputMap = + ImmutableMap.of("msg", TEST_ALL_TYPES_INPUT, "x", 5L, "y", 6L, "opt_x", Optional.of(5L)); + Object expectedEvalResult = runtimeEnv.cel.createProgram(ast).eval(inputMap); + + CelAbstractSyntaxTree optimizedAst = cseTestOptimizer.newCseOptimizer(runtimeEnv).optimize(ast); + CelAbstractSyntaxTree parsedOnlyOptimizedAst = + CelAbstractSyntaxTree.newParsedAst(optimizedAst.getExpr(), optimizedAst.getSource()); + + Object optimizedEvalResult = + runtimeEnv.cel.createProgram(parsedOnlyOptimizedAst).eval(inputMap); + assertThat(optimizedEvalResult).isEqualTo(expectedEvalResult); + } + @Test public void subexpression_unparsed() throws Exception { for (CseTestCase cseTestCase : EnumSet.allOf(CseTestCase.class)) { diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java index 6e39bab28..23459e5d8 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java @@ -416,6 +416,19 @@ public void block_success(@TestParameter BlockTestCase testCase) throws Exceptio assertThat(evaluatedResult).isNotNull(); } + @Test + public void block_success_parsedOnly(@TestParameter BlockTestCase testCase) throws Exception { + if (runtimeEnv == RuntimeEnv.LEGACY) { + return; + } + CelAbstractSyntaxTree ast = + compileUsingInternalFunctions(testCase.source, /* parsedOnly= */ true); + + Object evaluatedResult = runtimeEnv.celForEvaluatingBlock.createProgram(ast).eval(); + + assertThat(evaluatedResult).isNotNull(); + } + @Test @SuppressWarnings("Immutable") // Test only public void lazyEval_blockIndexNeverReferenced() throws Exception { @@ -694,7 +707,7 @@ public void block_lazyEvaluationContainsError_cleansUpCycleState() throws Except * Converts AST containing cel.block related test functions to internal functions (e.g: cel.block * -> cel.@block) */ - private CelAbstractSyntaxTree compileUsingInternalFunctions(String expression) + private CelAbstractSyntaxTree compileUsingInternalFunctions(String expression, boolean parsedOnly) throws CelValidationException { CelAbstractSyntaxTree astToModify = runtimeEnv.celForEvaluatingBlock.compile(expression).getAst(); @@ -719,6 +732,14 @@ private CelAbstractSyntaxTree compileUsingInternalFunctions(String expression) indexExpr.ident().setName(internalIdentName); }); + if (parsedOnly) { + return mutableAst.toParsedAst(); + } return runtimeEnv.celForEvaluatingBlock.check(mutableAst.toParsedAst()).getAst(); } + + private CelAbstractSyntaxTree compileUsingInternalFunctions(String expression) + throws CelValidationException { + return compileUsingInternalFunctions(expression, false); + } } From 207dca5fa5be53a60099f5f47167d57bbf134013 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Mon, 13 Apr 2026 17:50:57 -0700 Subject: [PATCH 32/66] Refactor tests to inject a runtime environment to invoke planner and legacy tests PiperOrigin-RevId: 899271038 --- .../src/test/java/dev/cel/bundle/BUILD.bazel | 2 +- .../test/java/dev/cel/bundle/CelImplTest.java | 60 +++--- .../dev/cel/optimizer/optimizers/BUILD.bazel | 2 +- .../ConstantFoldingOptimizerTest.java | 151 ++++++------- .../SubexpressionOptimizerBaselineTest.java | 152 ++++++-------- .../SubexpressionOptimizerTest.java | 198 ++++++++---------- testing/BUILD.bazel | 5 + .../src/main/java/dev/cel/testing/BUILD.bazel | 9 + .../dev/cel/testing/CelRuntimeFlavor.java | 38 ++++ 9 files changed, 306 insertions(+), 311 deletions(-) create mode 100644 testing/src/main/java/dev/cel/testing/CelRuntimeFlavor.java diff --git a/bundle/src/test/java/dev/cel/bundle/BUILD.bazel b/bundle/src/test/java/dev/cel/bundle/BUILD.bazel index 2901e1ff9..265f6d89c 100644 --- a/bundle/src/test/java/dev/cel/bundle/BUILD.bazel +++ b/bundle/src/test/java/dev/cel/bundle/BUILD.bazel @@ -17,7 +17,6 @@ java_library( deps = [ "//:java_truth", "//bundle:cel", - "//bundle:cel_experimental_factory", "//bundle:cel_impl", "//bundle:environment", "//bundle:environment_exception", @@ -56,6 +55,7 @@ java_library( "//runtime:evaluation_listener", "//runtime:function_binding", "//runtime:unknown_attributes", + "//testing:cel_runtime_flavor", "//testing/protos:single_file_extension_java_proto", "//testing/protos:single_file_java_proto", "@cel_spec//proto/cel/expr:checked_java_proto", diff --git a/bundle/src/test/java/dev/cel/bundle/CelImplTest.java b/bundle/src/test/java/dev/cel/bundle/CelImplTest.java index 22ef7e2f4..a3ad60d40 100644 --- a/bundle/src/test/java/dev/cel/bundle/CelImplTest.java +++ b/bundle/src/test/java/dev/cel/bundle/CelImplTest.java @@ -114,6 +114,7 @@ import dev.cel.runtime.CelUnknownSet; import dev.cel.runtime.CelVariableResolver; import dev.cel.runtime.UnknownContext; +import dev.cel.testing.CelRuntimeFlavor; import dev.cel.testing.testdata.SingleFile; import dev.cel.testing.testdata.SingleFileExtensionsProto; import dev.cel.testing.testdata.proto3.StandaloneGlobalEnum; @@ -2144,8 +2145,9 @@ public void toBuilder_isImmutable() { } @Test - public void eval_withJsonFieldName(@TestParameter RuntimeEnv runtimeEnv) throws Exception { - Cel cel = runtimeEnv.cel; + public void eval_withJsonFieldName(@TestParameter CelRuntimeFlavor runtimeFlavor) + throws Exception { + Cel cel = setupEnv(runtimeFlavor.builder()); CelAbstractSyntaxTree ast = cel.compile( "file.int32_snake_case_json_name == 1 && " @@ -2176,8 +2178,9 @@ public void eval_withJsonFieldName(@TestParameter RuntimeEnv runtimeEnv) throws } @Test - public void eval_withJsonFieldName_fieldsFallBack(@TestParameter RuntimeEnv runtimeEnv) throws Exception { - Cel cel = runtimeEnv.cel; + public void eval_withJsonFieldName_fieldsFallBack(@TestParameter CelRuntimeFlavor runtimeFlavor) + throws Exception { + Cel cel = setupEnv(runtimeFlavor.builder()); CelAbstractSyntaxTree ast = cel.compile( "dyn(file).int32_snake_case_json_name == 1 && " @@ -2206,8 +2209,9 @@ public void eval_withJsonFieldName_fieldsFallBack(@TestParameter RuntimeEnv runt } @Test - public void eval_withJsonFieldName_extensionFields(@TestParameter RuntimeEnv runtimeEnv) throws Exception { - Cel cel = runtimeEnv.cel; + public void eval_withJsonFieldName_extensionFields(@TestParameter CelRuntimeFlavor runtimeFlavor) + throws Exception { + Cel cel = setupEnv(runtimeFlavor.builder()); CelAbstractSyntaxTree ast = cel.compile( "proto.getExt(file, dev.cel.testing.testdata.int64CamelCaseJsonName) == 5 &&" @@ -2317,33 +2321,21 @@ private static TypeProvider aliasingProvider(ImmutableMap typeAlia }; } - private enum RuntimeEnv { - LEGACY(setupEnv(CelFactory.standardCelBuilder())), - PLANNER(setupEnv(CelExperimentalFactory.plannerCelBuilder())) - ; - - private final Cel cel; - - private static Cel setupEnv(CelBuilder celBuilder) { - ExtensionRegistry extensionRegistry = ExtensionRegistry.newInstance(); - SingleFileExtensionsProto.registerAllExtensions(extensionRegistry); - return celBuilder - .addVar("file", StructTypeReference.create(SingleFile.getDescriptor().getFullName())) - .addMessageTypes(SingleFile.getDescriptor()) - .addFileTypes(SingleFileExtensionsProto.getDescriptor()) - .addCompilerLibraries(CelExtensions.protos()) - .setExtensionRegistry(extensionRegistry) - .setOptions( - CelOptions.current() - .enableJsonFieldNames(true) - .enableHeterogeneousNumericComparisons(true) - .enableQuotedIdentifierSyntax(true) - .build()) - .build(); - } - - RuntimeEnv(Cel cel) { - this.cel = cel; - } + private static Cel setupEnv(CelBuilder celBuilder) { + ExtensionRegistry extensionRegistry = ExtensionRegistry.newInstance(); + SingleFileExtensionsProto.registerAllExtensions(extensionRegistry); + return celBuilder + .addVar("file", StructTypeReference.create(SingleFile.getDescriptor().getFullName())) + .addMessageTypes(SingleFile.getDescriptor()) + .addFileTypes(SingleFileExtensionsProto.getDescriptor()) + .addCompilerLibraries(CelExtensions.protos()) + .setExtensionRegistry(extensionRegistry) + .setOptions( + CelOptions.current() + .enableJsonFieldNames(true) + .enableHeterogeneousNumericComparisons(true) + .enableQuotedIdentifierSyntax(true) + .build()) + .build(); } } diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel b/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel index 734aa6879..d1220a41a 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel @@ -11,7 +11,6 @@ java_library( deps = [ # "//java/com/google/testing/testsize:annotations", "//bundle:cel", - "//bundle:cel_experimental_factory", "//common:cel_ast", "//common:cel_source", "//common:compiler_common", @@ -37,6 +36,7 @@ java_library( "//runtime:program", "//runtime:unknown_attributes", "//testing:baseline_test_case", + "//testing:cel_runtime_flavor", "@maven//:junit_junit", "@maven//:com_google_testparameterinjector_test_parameter_injector", "//:java_truth", diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java index bbb5c6e7e..33dc2d941 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java @@ -23,8 +23,6 @@ import com.google.testing.junit.testparameterinjector.TestParameters; import dev.cel.bundle.Cel; import dev.cel.bundle.CelBuilder; -import dev.cel.bundle.CelExperimentalFactory; -import dev.cel.bundle.CelFactory; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelContainer; import dev.cel.common.CelFunctionDecl; @@ -44,6 +42,8 @@ import dev.cel.parser.CelUnparser; import dev.cel.parser.CelUnparserFactory; import dev.cel.runtime.CelFunctionBinding; +import dev.cel.testing.CelRuntimeFlavor; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -57,72 +57,57 @@ public class ConstantFoldingOptimizerTest { private static final CelUnparser CEL_UNPARSER = CelUnparserFactory.newUnparser(); - @SuppressWarnings("ImmutableEnumChecker") // test only - private enum RuntimeEnv { - LEGACY(setupEnv(CelFactory.standardCelBuilder())), - PLANNER(setupEnv(CelExperimentalFactory.plannerCelBuilder())); - - private final Cel cel; - private final CelOptimizer celOptimizer; - - private static Cel setupEnv(CelBuilder celBuilder) { - return celBuilder - .addVar("x", SimpleType.DYN) - .addVar("y", SimpleType.DYN) - .addVar("list_var", ListType.create(SimpleType.STRING)) - .addVar("map_var", MapType.create(SimpleType.STRING, SimpleType.STRING)) - .setStandardMacros(CelStandardMacro.STANDARD_MACROS) - .addFunctionDeclarations( - CelFunctionDecl.newFunctionDeclaration( - "get_true", - CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL)), - CelFunctionDecl.newFunctionDeclaration( - "get_list", - CelOverloadDecl.newGlobalOverload( - "get_list_overload", - ListType.create(SimpleType.INT), - ListType.create(SimpleType.INT)))) - .addFunctionBindings( - CelFunctionBinding.from("get_true_overload", ImmutableList.of(), unused -> true)) - .addMessageTypes(TestAllTypes.getDescriptor()) - .setContainer(CelContainer.ofName("cel.expr.conformance.proto3")) - .setOptions(CEL_OPTIONS) - .addCompilerLibraries( - CelExtensions.bindings(), - CelOptionalLibrary.INSTANCE, - CelExtensions.math(CEL_OPTIONS), - CelExtensions.strings(), - CelExtensions.sets(CEL_OPTIONS), - CelExtensions.encoders(CEL_OPTIONS)) - .addRuntimeLibraries( - CelOptionalLibrary.INSTANCE, - CelExtensions.math(CEL_OPTIONS), - CelExtensions.strings(), - CelExtensions.sets(CEL_OPTIONS), - CelExtensions.encoders(CEL_OPTIONS)) - .build(); - } - - RuntimeEnv(Cel cel) { - this.cel = cel; - this.celOptimizer = - CelOptimizerFactory.standardCelOptimizerBuilder(cel) - .addAstOptimizers(ConstantFoldingOptimizer.getInstance()) - .build(); - } - - private CelBuilder newCelBuilder() { - switch (this) { - case LEGACY: - return CelFactory.standardCelBuilder(); - case PLANNER: - return CelExperimentalFactory.plannerCelBuilder(); - } - throw new AssertionError("Unknown RuntimeEnv: " + this); - } + @TestParameter CelRuntimeFlavor runtimeFlavor; + + private Cel cel; + private CelOptimizer celOptimizer; + + @Before + public void setUp() { + this.cel = setupEnv(runtimeFlavor.builder()); + this.celOptimizer = + CelOptimizerFactory.standardCelOptimizerBuilder(this.cel) + .addAstOptimizers(ConstantFoldingOptimizer.getInstance()) + .build(); } - @TestParameter RuntimeEnv runtimeEnv; + private static Cel setupEnv(CelBuilder celBuilder) { + return celBuilder + .addVar("x", SimpleType.DYN) + .addVar("y", SimpleType.DYN) + .addVar("list_var", ListType.create(SimpleType.STRING)) + .addVar("map_var", MapType.create(SimpleType.STRING, SimpleType.STRING)) + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .addFunctionDeclarations( + CelFunctionDecl.newFunctionDeclaration( + "get_true", + CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL)), + CelFunctionDecl.newFunctionDeclaration( + "get_list", + CelOverloadDecl.newGlobalOverload( + "get_list_overload", + ListType.create(SimpleType.INT), + ListType.create(SimpleType.INT)))) + .addFunctionBindings( + CelFunctionBinding.from("get_true_overload", ImmutableList.of(), unused -> true)) + .addMessageTypes(TestAllTypes.getDescriptor()) + .setContainer(CelContainer.ofName("cel.expr.conformance.proto3")) + .setOptions(CEL_OPTIONS) + .addCompilerLibraries( + CelExtensions.bindings(), + CelOptionalLibrary.INSTANCE, + CelExtensions.math(CEL_OPTIONS), + CelExtensions.strings(), + CelExtensions.sets(CEL_OPTIONS), + CelExtensions.encoders(CEL_OPTIONS)) + .addRuntimeLibraries( + CelOptionalLibrary.INSTANCE, + CelExtensions.math(CEL_OPTIONS), + CelExtensions.strings(), + CelExtensions.sets(CEL_OPTIONS), + CelExtensions.encoders(CEL_OPTIONS)) + .build(); + } @Test @TestParameters("{source: 'null', expected: 'null'}") @@ -270,9 +255,9 @@ private CelBuilder newCelBuilder() { // TODO: Support folding lists with mixed types. This requires mutable lists. // @TestParameters("{source: 'dyn([1]) + [1.0]'}") public void constantFold_success(String source, String expected) throws Exception { - CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(source).getAst(); + CelAbstractSyntaxTree ast = cel.compile(source).getAst(); - CelAbstractSyntaxTree optimizedAst = runtimeEnv.celOptimizer.optimize(ast); + CelAbstractSyntaxTree optimizedAst = celOptimizer.optimize(ast); assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(expected); } @@ -317,8 +302,8 @@ public void constantFold_success(String source, String expected) throws Exceptio public void constantFold_macros_macroCallMetadataPopulated(String source, String expected) throws Exception { Cel cel = - runtimeEnv - .newCelBuilder() + runtimeFlavor + .builder() .addVar("x", SimpleType.DYN) .addVar("y", SimpleType.DYN) .addMessageTypes(TestAllTypes.getDescriptor()) @@ -363,8 +348,8 @@ public void constantFold_macros_macroCallMetadataPopulated(String source, String @TestParameters("{source: 'false ? false : cel.bind(a, true, a)'}") public void constantFold_macros_withoutMacroCallMetadata(String source) throws Exception { Cel cel = - runtimeEnv - .newCelBuilder() + runtimeFlavor + .builder() .addVar("x", SimpleType.DYN) .addVar("y", SimpleType.DYN) .addMessageTypes(TestAllTypes.getDescriptor()) @@ -418,20 +403,20 @@ public void constantFold_macros_withoutMacroCallMetadata(String source) throws E @TestParameters("{source: 'get_list([1, 2]).map(x, x * 2)'}") @TestParameters("{source: '[(x - 1 > 3) ? (x - 1) : 5].exists(x, x - 1 > 3)'}") public void constantFold_noOp(String source) throws Exception { - CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(source).getAst(); + CelAbstractSyntaxTree ast = cel.compile(source).getAst(); - CelAbstractSyntaxTree optimizedAst = runtimeEnv.celOptimizer.optimize(ast); + CelAbstractSyntaxTree optimizedAst = celOptimizer.optimize(ast); assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(source); } @Test public void constantFold_addFoldableFunction_success() throws Exception { - CelAbstractSyntaxTree ast = runtimeEnv.cel.compile("get_true() == get_true()").getAst(); + CelAbstractSyntaxTree ast = cel.compile("get_true() == get_true()").getAst(); ConstantFoldingOptions options = ConstantFoldingOptions.newBuilder().addFoldableFunctions("get_true").build(); CelOptimizer optimizer = - CelOptimizerFactory.standardCelOptimizerBuilder(runtimeEnv.cel) + CelOptimizerFactory.standardCelOptimizerBuilder(cel) .addAstOptimizers(ConstantFoldingOptimizer.newInstance(options)) .build(); @@ -442,7 +427,7 @@ public void constantFold_addFoldableFunction_success() throws Exception { @Test public void constantFold_withExpectedResultTypeSet_success() throws Exception { - Cel cel = runtimeEnv.newCelBuilder().setResultType(SimpleType.STRING).build(); + Cel cel = runtimeFlavor.builder().setResultType(SimpleType.STRING).build(); CelOptimizer optimizer = CelOptimizerFactory.standardCelOptimizerBuilder(cel) .addAstOptimizers(ConstantFoldingOptimizer.getInstance()) @@ -458,8 +443,8 @@ public void constantFold_withExpectedResultTypeSet_success() throws Exception { public void constantFold_withMacroCallPopulated_comprehensionsAreReplacedWithNotSet() throws Exception { Cel cel = - runtimeEnv - .newCelBuilder() + runtimeFlavor + .builder() .addVar("x", SimpleType.DYN) .setStandardMacros(CelStandardMacro.STANDARD_MACROS) .setOptions(CEL_OPTIONS) @@ -532,9 +517,9 @@ public void constantFold_withMacroCallPopulated_comprehensionsAreReplacedWithNot @Test public void constantFold_astProducesConsistentlyNumberedIds() throws Exception { - CelAbstractSyntaxTree ast = runtimeEnv.cel.compile("[1] + [2] + [3]").getAst(); + CelAbstractSyntaxTree ast = cel.compile("[1] + [2] + [3]").getAst(); - CelAbstractSyntaxTree optimizedAst = runtimeEnv.celOptimizer.optimize(ast); + CelAbstractSyntaxTree optimizedAst = celOptimizer.optimize(ast); assertThat(optimizedAst.getExpr().toString()) .isEqualTo( @@ -555,8 +540,8 @@ public void iterationLimitReached_throws() throws Exception { sb.append(" + ").append(i); } // 0 + 1 + 2 + 3 + ... 200 Cel cel = - runtimeEnv - .newCelBuilder() + runtimeFlavor + .builder() .setOptions( CelOptions.current() .enableHeterogeneousNumericComparisons(true) diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerBaselineTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerBaselineTest.java index 9db04ceac..04e4e6a1d 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerBaselineTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerBaselineTest.java @@ -24,8 +24,6 @@ // import com.google.testing.testsize.MediumTest; import dev.cel.bundle.Cel; import dev.cel.bundle.CelBuilder; -import dev.cel.bundle.CelExperimentalFactory; -import dev.cel.bundle.CelFactory; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelContainer; import dev.cel.common.CelFunctionDecl; @@ -44,6 +42,7 @@ import dev.cel.parser.CelUnparserFactory; import dev.cel.runtime.CelFunctionBinding; import dev.cel.testing.BaselineTestCase; +import dev.cel.testing.CelRuntimeFlavor; import java.util.EnumSet; import java.util.Optional; import org.junit.Before; @@ -53,53 +52,41 @@ // @MediumTest @RunWith(TestParameterInjector.class) public class SubexpressionOptimizerBaselineTest extends BaselineTestCase { - private enum RuntimeEnv { - LEGACY(setupCelEnv(CelFactory.standardCelBuilder())), - PLANNER(setupCelEnv(CelExperimentalFactory.plannerCelBuilder())); - - private final Cel cel; - - private static Cel setupCelEnv(CelBuilder celBuilder) { - return celBuilder - .addMessageTypes(TestAllTypes.getDescriptor()) - .setContainer(CelContainer.ofName("cel.expr.conformance.proto3")) - .setStandardMacros(CelStandardMacro.STANDARD_MACROS) - .setOptions( - CelOptions.current() - .populateMacroCalls(true) - .enableHeterogeneousNumericComparisons(true) - .build()) - .addCompilerLibraries( - CelExtensions.optional(), CelExtensions.bindings(), CelExtensions.comprehensions()) - .addRuntimeLibraries(CelExtensions.optional(), CelExtensions.comprehensions()) - .addFunctionDeclarations( - CelFunctionDecl.newFunctionDeclaration( - "pure_custom_func", - newGlobalOverload("pure_custom_func_overload", SimpleType.INT, SimpleType.INT)), - CelFunctionDecl.newFunctionDeclaration( - "non_pure_custom_func", - newGlobalOverload( - "non_pure_custom_func_overload", SimpleType.INT, SimpleType.INT))) - .addFunctionBindings( - // This is pure, but for the purposes of excluding it as a CSE candidate, pretend that - // it isn't. - CelFunctionBinding.fromOverloads( - "non_pure_custom_func", - CelFunctionBinding.from("non_pure_custom_func_overload", Long.class, val -> val))) - .addFunctionBindings( - CelFunctionBinding.fromOverloads( - "pure_custom_func", - CelFunctionBinding.from("pure_custom_func_overload", Long.class, val -> val))) - .addVar("x", SimpleType.DYN) - .addVar("y", SimpleType.DYN) - .addVar("opt_x", OptionalType.create(SimpleType.DYN)) - .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())) - .build(); - } - - RuntimeEnv(Cel cel) { - this.cel = cel; - } + private static Cel setupCelEnv(CelBuilder celBuilder) { + return celBuilder + .addMessageTypes(TestAllTypes.getDescriptor()) + .setContainer(CelContainer.ofName("cel.expr.conformance.proto3")) + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .setOptions( + CelOptions.current() + .populateMacroCalls(true) + .enableHeterogeneousNumericComparisons(true) + .build()) + .addCompilerLibraries( + CelExtensions.optional(), CelExtensions.bindings(), CelExtensions.comprehensions()) + .addRuntimeLibraries(CelExtensions.optional(), CelExtensions.comprehensions()) + .addFunctionDeclarations( + CelFunctionDecl.newFunctionDeclaration( + "pure_custom_func", + newGlobalOverload("pure_custom_func_overload", SimpleType.INT, SimpleType.INT)), + CelFunctionDecl.newFunctionDeclaration( + "non_pure_custom_func", + newGlobalOverload("non_pure_custom_func_overload", SimpleType.INT, SimpleType.INT))) + .addFunctionBindings( + // This is pure, but for the purposes of excluding it as a CSE candidate, pretend that + // it isn't. + CelFunctionBinding.fromOverloads( + "non_pure_custom_func", + CelFunctionBinding.from("non_pure_custom_func_overload", Long.class, val -> val))) + .addFunctionBindings( + CelFunctionBinding.fromOverloads( + "pure_custom_func", + CelFunctionBinding.from("pure_custom_func_overload", Long.class, val -> val))) + .addVar("x", SimpleType.DYN) + .addVar("y", SimpleType.DYN) + .addVar("opt_x", OptionalType.create(SimpleType.DYN)) + .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())) + .build(); } private static final CelUnparser CEL_UNPARSER = CelUnparserFactory.newUnparser(); @@ -129,6 +116,7 @@ private static Cel setupCelEnv(CelBuilder celBuilder) { @Before public void setUp() { + this.cel = setupCelEnv(runtimeFlavor.builder()); overriddenBaseFilePath = ""; } @@ -140,21 +128,23 @@ protected String baselineFileName() { return overriddenBaseFilePath; } - @TestParameter RuntimeEnv runtimeEnv; + @TestParameter CelRuntimeFlavor runtimeFlavor; + + private Cel cel; @Test public void allOptimizers_producesSameEvaluationResult( @TestParameter CseTestOptimizer cseTestOptimizer, @TestParameter CseTestCase cseTestCase) throws Exception { skipBaselineVerification(); - CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(cseTestCase.source).getAst(); + CelAbstractSyntaxTree ast = cel.compile(cseTestCase.source).getAst(); ImmutableMap inputMap = ImmutableMap.of("msg", TEST_ALL_TYPES_INPUT, "x", 5L, "y", 6L, "opt_x", Optional.of(5L)); - Object expectedEvalResult = runtimeEnv.cel.createProgram(ast).eval(inputMap); + Object expectedEvalResult = cel.createProgram(ast).eval(inputMap); - CelAbstractSyntaxTree optimizedAst = cseTestOptimizer.newCseOptimizer(runtimeEnv).optimize(ast); + CelAbstractSyntaxTree optimizedAst = cseTestOptimizer.newCseOptimizer(cel).optimize(ast); - Object optimizedEvalResult = runtimeEnv.cel.createProgram(optimizedAst).eval(inputMap); + Object optimizedEvalResult = cel.createProgram(optimizedAst).eval(inputMap); assertThat(optimizedEvalResult).isEqualTo(expectedEvalResult); } @@ -163,20 +153,19 @@ public void allOptimizers_producesSameEvaluationResult_parsedOnly( @TestParameter CseTestCase cseTestCase, @TestParameter CseTestOptimizer cseTestOptimizer) throws Exception { skipBaselineVerification(); - if (runtimeEnv == RuntimeEnv.LEGACY) { + if (runtimeFlavor.equals(CelRuntimeFlavor.LEGACY)) { return; } - CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(cseTestCase.source).getAst(); + CelAbstractSyntaxTree ast = cel.compile(cseTestCase.source).getAst(); ImmutableMap inputMap = ImmutableMap.of("msg", TEST_ALL_TYPES_INPUT, "x", 5L, "y", 6L, "opt_x", Optional.of(5L)); - Object expectedEvalResult = runtimeEnv.cel.createProgram(ast).eval(inputMap); + Object expectedEvalResult = cel.createProgram(ast).eval(inputMap); - CelAbstractSyntaxTree optimizedAst = cseTestOptimizer.newCseOptimizer(runtimeEnv).optimize(ast); + CelAbstractSyntaxTree optimizedAst = cseTestOptimizer.newCseOptimizer(cel).optimize(ast); CelAbstractSyntaxTree parsedOnlyOptimizedAst = CelAbstractSyntaxTree.newParsedAst(optimizedAst.getExpr(), optimizedAst.getSource()); - Object optimizedEvalResult = - runtimeEnv.cel.createProgram(parsedOnlyOptimizedAst).eval(inputMap); + Object optimizedEvalResult = cel.createProgram(parsedOnlyOptimizedAst).eval(inputMap); assertThat(optimizedEvalResult).isEqualTo(expectedEvalResult); } @@ -186,22 +175,20 @@ public void subexpression_unparsed() throws Exception { testOutput().println("Test case: " + cseTestCase.name()); testOutput().println("Source: " + cseTestCase.source); testOutput().println("=====>"); - CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(cseTestCase.source).getAst(); + CelAbstractSyntaxTree ast = cel.compile(cseTestCase.source).getAst(); boolean resultPrinted = false; for (CseTestOptimizer cseTestOptimizer : CseTestOptimizer.values()) { String optimizerName = cseTestOptimizer.name(); CelAbstractSyntaxTree optimizedAst; try { - optimizedAst = cseTestOptimizer.newCseOptimizer(runtimeEnv).optimize(ast); + optimizedAst = cseTestOptimizer.newCseOptimizer(cel).optimize(ast); } catch (Exception e) { testOutput().printf("[%s]: Optimization Error: %s", optimizerName, e); continue; } if (!resultPrinted) { Object optimizedEvalResult = - runtimeEnv - .cel - .createProgram(optimizedAst) + cel.createProgram(optimizedAst) .eval( ImmutableMap.of( "msg", TEST_ALL_TYPES_INPUT, "x", 5L, "y", 6L, "opt_x", Optional.of(5L))); @@ -225,17 +212,15 @@ public void constfold_before_subexpression_unparsed() throws Exception { testOutput().println("Test case: " + cseTestCase.name()); testOutput().println("Source: " + cseTestCase.source); testOutput().println("=====>"); - CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(cseTestCase.source).getAst(); + CelAbstractSyntaxTree ast = cel.compile(cseTestCase.source).getAst(); boolean resultPrinted = false; for (CseTestOptimizer cseTestOptimizer : EnumSet.allOf(CseTestOptimizer.class)) { String optimizerName = cseTestOptimizer.name(); CelAbstractSyntaxTree optimizedAst = - cseTestOptimizer.newCseWithConstFoldingOptimizer(runtimeEnv).optimize(ast); + cseTestOptimizer.newCseWithConstFoldingOptimizer(cel).optimize(ast); if (!resultPrinted) { Object optimizedEvalResult = - runtimeEnv - .cel - .createProgram(optimizedAst) + cel.createProgram(optimizedAst) .eval( ImmutableMap.of( "msg", TEST_ALL_TYPES_INPUT, "x", 5L, "y", 6L, "opt_x", Optional.of(5L))); @@ -261,9 +246,9 @@ public void subexpression_ast(@TestParameter CseTestOptimizer cseTestOptimizer) testOutput().println("Test case: " + cseTestCase.name()); testOutput().println("Source: " + cseTestCase.source); testOutput().println("=====>"); - CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(cseTestCase.source).getAst(); + CelAbstractSyntaxTree ast = cel.compile(cseTestCase.source).getAst(); CelAbstractSyntaxTree optimizedAst = - newCseOptimizer(runtimeEnv.cel, cseTestOptimizer.option).optimize(ast); + newCseOptimizer(cel, cseTestOptimizer.option).optimize(ast); testOutput().println(optimizedAst.getExpr()); } } @@ -272,8 +257,7 @@ public void subexpression_ast(@TestParameter CseTestOptimizer cseTestOptimizer) public void large_expressions_block_common_subexpr() throws Exception { CelOptimizer celOptimizer = newCseOptimizer( - runtimeEnv.cel, - SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build()); + cel, SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build()); runLargeTestCases(celOptimizer); } @@ -282,7 +266,7 @@ public void large_expressions_block_common_subexpr() throws Exception { public void large_expressions_block_recursion_depth_1() throws Exception { CelOptimizer celOptimizer = newCseOptimizer( - runtimeEnv.cel, + cel, SubexpressionOptimizerOptions.newBuilder() .populateMacroCalls(true) .subexpressionMaxRecursionDepth(1) @@ -295,7 +279,7 @@ public void large_expressions_block_recursion_depth_1() throws Exception { public void large_expressions_block_recursion_depth_2() throws Exception { CelOptimizer celOptimizer = newCseOptimizer( - runtimeEnv.cel, + cel, SubexpressionOptimizerOptions.newBuilder() .populateMacroCalls(true) .subexpressionMaxRecursionDepth(2) @@ -308,7 +292,7 @@ public void large_expressions_block_recursion_depth_2() throws Exception { public void large_expressions_block_recursion_depth_3() throws Exception { CelOptimizer celOptimizer = newCseOptimizer( - runtimeEnv.cel, + cel, SubexpressionOptimizerOptions.newBuilder() .populateMacroCalls(true) .subexpressionMaxRecursionDepth(3) @@ -322,12 +306,10 @@ private void runLargeTestCases(CelOptimizer celOptimizer) throws Exception { testOutput().println("Test case: " + cseTestCase.name()); testOutput().println("Source: " + cseTestCase.source); testOutput().println("=====>"); - CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(cseTestCase.source).getAst(); + CelAbstractSyntaxTree ast = cel.compile(cseTestCase.source).getAst(); CelAbstractSyntaxTree optimizedAst = celOptimizer.optimize(ast); Object optimizedEvalResult = - runtimeEnv - .cel - .createProgram(optimizedAst) + cel.createProgram(optimizedAst) .eval( ImmutableMap.of("msg", TEST_ALL_TYPES_INPUT, "x", 5L, "opt_x", Optional.of(5L))); testOutput().println("Result: " + optimizedEvalResult); @@ -376,13 +358,13 @@ private enum CseTestOptimizer { } // Defers building the optimizer until the test runs - private CelOptimizer newCseOptimizer(RuntimeEnv env) { - return SubexpressionOptimizerBaselineTest.newCseOptimizer(env.cel, option); + private CelOptimizer newCseOptimizer(Cel cel) { + return SubexpressionOptimizerBaselineTest.newCseOptimizer(cel, option); } // Defers building the optimizer until the test runs - private CelOptimizer newCseWithConstFoldingOptimizer(RuntimeEnv env) { - return CelOptimizerFactory.standardCelOptimizerBuilder(env.cel) + private CelOptimizer newCseWithConstFoldingOptimizer(Cel cel) { + return CelOptimizerFactory.standardCelOptimizerBuilder(cel) .addAstOptimizers( ConstantFoldingOptimizer.getInstance(), SubexpressionOptimizer.newInstance(option)) .build(); diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java index 23459e5d8..e7387d7d8 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java @@ -26,7 +26,6 @@ import com.google.testing.junit.testparameterinjector.TestParameters; import dev.cel.bundle.Cel; import dev.cel.bundle.CelBuilder; -import dev.cel.bundle.CelExperimentalFactory; import dev.cel.bundle.CelFactory; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelFunctionDecl; @@ -61,87 +60,80 @@ import dev.cel.runtime.CelUnknownSet; import dev.cel.runtime.PartialVars; import dev.cel.runtime.Program; +import dev.cel.testing.CelRuntimeFlavor; import java.util.List; import java.util.concurrent.atomic.AtomicInteger; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @RunWith(TestParameterInjector.class) public class SubexpressionOptimizerTest { - private enum RuntimeEnv { - LEGACY( - setupCelEnv(CelFactory.standardCelBuilder()), - setupCelForEvaluatingBlock(CelFactory.standardCelBuilder())), - PLANNER( - setupCelEnv(CelExperimentalFactory.plannerCelBuilder()), - setupCelForEvaluatingBlock(CelExperimentalFactory.plannerCelBuilder())); - - private final Cel cel; - private final Cel celForEvaluatingBlock; - - private static Cel setupCelEnv(CelBuilder celBuilder) { - return celBuilder - .addMessageTypes(TestAllTypes.getDescriptor()) - .setStandardMacros(CelStandardMacro.STANDARD_MACROS) - .setOptions( - CelOptions.current() - .populateMacroCalls(true) - .enableHeterogeneousNumericComparisons(true) - .build()) - .addCompilerLibraries(CelExtensions.bindings(), CelExtensions.strings()) - .addRuntimeLibraries(CelExtensions.strings()) - .addFunctionDeclarations( - CelFunctionDecl.newFunctionDeclaration( - "non_pure_custom_func", - newGlobalOverload( - "non_pure_custom_func_overload", SimpleType.INT, SimpleType.INT))) - .addVar("x", SimpleType.DYN) - .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())) - .build(); - } - - private static Cel setupCelForEvaluatingBlock(CelBuilder celBuilder) { - return celBuilder - .setStandardMacros(CelStandardMacro.STANDARD_MACROS) - .addFunctionDeclarations( - // These are test only declarations, as the actual function is made internal using @ - // symbol. - // If the main function declaration needs updating, be sure to update the test - // declaration as well. - CelFunctionDecl.newFunctionDeclaration( - "cel.block", - CelOverloadDecl.newGlobalOverload( - "block_test_only_overload", - SimpleType.DYN, - ListType.create(SimpleType.DYN), - SimpleType.DYN)), - SubexpressionOptimizer.newCelBlockFunctionDecl(SimpleType.DYN), - CelFunctionDecl.newFunctionDeclaration( - "get_true", - CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL))) - // Similarly, this is a test only decl (index0 -> @index0) - .addVarDeclarations( - CelVarDecl.newVarDeclaration("c0", SimpleType.DYN), - CelVarDecl.newVarDeclaration("c1", SimpleType.DYN), - CelVarDecl.newVarDeclaration("index0", SimpleType.DYN), - CelVarDecl.newVarDeclaration("index1", SimpleType.DYN), - CelVarDecl.newVarDeclaration("index2", SimpleType.DYN), - CelVarDecl.newVarDeclaration("@index0", SimpleType.DYN), - CelVarDecl.newVarDeclaration("@index1", SimpleType.DYN), - CelVarDecl.newVarDeclaration("@index2", SimpleType.DYN)) - .addMessageTypes(TestAllTypes.getDescriptor()) - .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())) - .build(); - } + private static Cel setupCelEnv(CelBuilder celBuilder) { + return celBuilder + .addMessageTypes(TestAllTypes.getDescriptor()) + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .setOptions( + CelOptions.current() + .populateMacroCalls(true) + .enableHeterogeneousNumericComparisons(true) + .build()) + .addCompilerLibraries(CelExtensions.bindings(), CelExtensions.strings()) + .addRuntimeLibraries(CelExtensions.strings()) + .addFunctionDeclarations( + CelFunctionDecl.newFunctionDeclaration( + "non_pure_custom_func", + newGlobalOverload("non_pure_custom_func_overload", SimpleType.INT, SimpleType.INT))) + .addVar("x", SimpleType.DYN) + .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())) + .build(); + } - RuntimeEnv(Cel cel, Cel celForEvaluatingBlock) { - this.cel = cel; - this.celForEvaluatingBlock = celForEvaluatingBlock; - } + private static Cel setupCelForEvaluatingBlock(CelBuilder celBuilder) { + return celBuilder + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .addFunctionDeclarations( + // These are test only declarations, as the actual function is made internal using @ + // symbol. + // If the main function declaration needs updating, be sure to update the test + // declaration as well. + CelFunctionDecl.newFunctionDeclaration( + "cel.block", + CelOverloadDecl.newGlobalOverload( + "block_test_only_overload", + SimpleType.DYN, + ListType.create(SimpleType.DYN), + SimpleType.DYN)), + SubexpressionOptimizer.newCelBlockFunctionDecl(SimpleType.DYN), + CelFunctionDecl.newFunctionDeclaration( + "get_true", + CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL))) + // Similarly, this is a test only decl (index0 -> @index0) + .addVarDeclarations( + CelVarDecl.newVarDeclaration("c0", SimpleType.DYN), + CelVarDecl.newVarDeclaration("c1", SimpleType.DYN), + CelVarDecl.newVarDeclaration("index0", SimpleType.DYN), + CelVarDecl.newVarDeclaration("index1", SimpleType.DYN), + CelVarDecl.newVarDeclaration("index2", SimpleType.DYN), + CelVarDecl.newVarDeclaration("@index0", SimpleType.DYN), + CelVarDecl.newVarDeclaration("@index1", SimpleType.DYN), + CelVarDecl.newVarDeclaration("@index2", SimpleType.DYN)) + .addMessageTypes(TestAllTypes.getDescriptor()) + .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())) + .build(); } - @TestParameter RuntimeEnv runtimeEnv; + @TestParameter CelRuntimeFlavor runtimeFlavor; + + private Cel cel; + private Cel celForEvaluatingBlock; + + @Before + public void setUp() { + this.cel = setupCelEnv(runtimeFlavor.builder()); + this.celForEvaluatingBlock = setupCelForEvaluatingBlock(runtimeFlavor.builder()); + } private static final CelUnparser CEL_UNPARSER = CelUnparserFactory.newUnparser(); @@ -160,7 +152,7 @@ private static CelBuilder newCelBuilder() { } private CelOptimizer newCseOptimizer(SubexpressionOptimizerOptions options) { - return CelOptimizerFactory.standardCelOptimizerBuilder(runtimeEnv.cel) + return CelOptimizerFactory.standardCelOptimizerBuilder(cel) .addAstOptimizers(SubexpressionOptimizer.newInstance(options)) .build(); } @@ -174,21 +166,20 @@ public void cse_resultTypeSet_celBlockOptimizationSuccess() throws Exception { SubexpressionOptimizer.newInstance( SubexpressionOptimizerOptions.newBuilder().build())) .build(); - CelAbstractSyntaxTree ast = runtimeEnv.cel.compile("size('a') + size('a') == 2").getAst(); + CelAbstractSyntaxTree ast = cel.compile("size('a') + size('a') == 2").getAst(); CelAbstractSyntaxTree optimizedAst = celOptimizer.optimize(ast); - assertThat(runtimeEnv.cel.createProgram(optimizedAst).eval()).isEqualTo(true); + assertThat(cel.createProgram(optimizedAst).eval()).isEqualTo(true); assertThat(CEL_UNPARSER.unparse(optimizedAst)) .isEqualTo("cel.@block([size(\"a\")], @index0 + @index0 == 2)"); } @Test public void cse_indexEvaluationErrors_throws() throws Exception { - CelAbstractSyntaxTree ast = - runtimeEnv.cel.compile("\"abc\".charAt(10) + \"abc\".charAt(10)").getAst(); + CelAbstractSyntaxTree ast = cel.compile("\"abc\".charAt(10) + \"abc\".charAt(10)").getAst(); CelOptimizer optimizedOptimizer = - CelOptimizerFactory.standardCelOptimizerBuilder(runtimeEnv.cel) + CelOptimizerFactory.standardCelOptimizerBuilder(cel) .addAstOptimizers(SubexpressionOptimizer.getInstance()) .build(); @@ -197,7 +188,7 @@ public void cse_indexEvaluationErrors_throws() throws Exception { String unparsed = CEL_UNPARSER.unparse(optimizedAst); assertThat(unparsed).isEqualTo("cel.@block([\"abc\".charAt(10)], @index0 + @index0)"); - Program program = runtimeEnv.cel.createProgram(optimizedAst); + Program program = cel.createProgram(optimizedAst); CelEvaluationException e = assertThrows(CelEvaluationException.class, () -> program.eval(ImmutableMap.of())); assertThat(e).hasMessageThat().contains("charAt failure: Index out of range: 10"); @@ -205,9 +196,9 @@ public void cse_indexEvaluationErrors_throws() throws Exception { @Test public void cse_withUnknownAttributes() throws Exception { - CelAbstractSyntaxTree ast = runtimeEnv.cel.compile("size(\"a\") == 1 ? x.y : x.y").getAst(); + CelAbstractSyntaxTree ast = cel.compile("size(\"a\") == 1 ? x.y : x.y").getAst(); CelOptimizer optimizer = - CelOptimizerFactory.standardCelOptimizerBuilder(runtimeEnv.cel) + CelOptimizerFactory.standardCelOptimizerBuilder(cel) .addAstOptimizers(SubexpressionOptimizer.getInstance()) .build(); @@ -217,9 +208,7 @@ public void cse_withUnknownAttributes() throws Exception { .isEqualTo("cel.@block([x.y], (size(\"a\") == 1) ? @index0 : @index0)"); Object result = - runtimeEnv - .cel - .createProgram(optimizedAst) + cel.createProgram(optimizedAst) .eval(PartialVars.of(CelAttributePattern.fromQualifiedIdentifier("x"))); assertThat(result).isInstanceOf(CelUnknownSet.class); } @@ -254,7 +243,7 @@ private enum CseNoOpTestCase { @Test public void cse_withCelBind_noop(@TestParameter CseNoOpTestCase testCase) throws Exception { - CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(testCase.source).getAst(); + CelAbstractSyntaxTree ast = cel.compile(testCase.source).getAst(); CelAbstractSyntaxTree optimizedAst = newCseOptimizer(SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build()) @@ -266,7 +255,7 @@ public void cse_withCelBind_noop(@TestParameter CseNoOpTestCase testCase) throws @Test public void cse_withCelBlock_noop(@TestParameter CseNoOpTestCase testCase) throws Exception { - CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(testCase.source).getAst(); + CelAbstractSyntaxTree ast = cel.compile(testCase.source).getAst(); CelAbstractSyntaxTree optimizedAst = newCseOptimizer(SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build()) @@ -279,7 +268,7 @@ public void cse_withCelBlock_noop(@TestParameter CseNoOpTestCase testCase) throw @Test public void cse_withComprehensionStructureRetained() throws Exception { CelAbstractSyntaxTree ast = - runtimeEnv.cel.compile("['foo'].map(x, [x+x]) + ['foo'].map(x, [x+x, x+x])").getAst(); + cel.compile("['foo'].map(x, [x+x]) + ['foo'].map(x, [x+x, x+x])").getAst(); CelOptimizer celOptimizer = newCseOptimizer( SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build()); @@ -295,12 +284,10 @@ public void cse_withComprehensionStructureRetained() throws Exception { @Test public void cse_applyConstFoldingBefore() throws Exception { CelAbstractSyntaxTree ast = - runtimeEnv - .cel - .compile("size([1+1+1]) + size([1+1+1]) + size([1,1+1+1]) + size([1,1+1+1]) + x") + cel.compile("size([1+1+1]) + size([1+1+1]) + size([1,1+1+1]) + size([1,1+1+1]) + x") .getAst(); CelOptimizer optimizer = - CelOptimizerFactory.standardCelOptimizerBuilder(runtimeEnv.cel) + CelOptimizerFactory.standardCelOptimizerBuilder(cel) .addAstOptimizers( ConstantFoldingOptimizer.getInstance(), SubexpressionOptimizer.newInstance( @@ -315,12 +302,10 @@ public void cse_applyConstFoldingBefore() throws Exception { @Test public void cse_applyConstFoldingAfter() throws Exception { CelAbstractSyntaxTree ast = - runtimeEnv - .cel - .compile("size([1+1+1]) + size([1+1+1]) + size([1,1+1+1]) + size([1,1+1+1]) + x") + cel.compile("size([1+1+1]) + size([1+1+1]) + size([1,1+1+1]) + size([1,1+1+1]) + x") .getAst(); CelOptimizer optimizer = - CelOptimizerFactory.standardCelOptimizerBuilder(runtimeEnv.cel) + CelOptimizerFactory.standardCelOptimizerBuilder(cel) .addAstOptimizers( SubexpressionOptimizer.newInstance( SubexpressionOptimizerOptions.newBuilder().build()), @@ -335,9 +320,9 @@ public void cse_applyConstFoldingAfter() throws Exception { @Test public void cse_applyConstFoldingAfter_nothingToFold() throws Exception { - CelAbstractSyntaxTree ast = runtimeEnv.cel.compile("size(x) + size(x)").getAst(); + CelAbstractSyntaxTree ast = cel.compile("size(x) + size(x)").getAst(); CelOptimizer optimizer = - CelOptimizerFactory.standardCelOptimizerBuilder(runtimeEnv.cel) + CelOptimizerFactory.standardCelOptimizerBuilder(cel) .addAstOptimizers( SubexpressionOptimizer.newInstance( SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build()), @@ -360,7 +345,7 @@ public void iterationLimitReached_throws() throws Exception { largeExprBuilder.append("+"); } } - CelAbstractSyntaxTree ast = runtimeEnv.cel.compile(largeExprBuilder.toString()).getAst(); + CelAbstractSyntaxTree ast = cel.compile(largeExprBuilder.toString()).getAst(); CelOptimizationException e = assertThrows( @@ -376,9 +361,9 @@ public void iterationLimitReached_throws() throws Exception { @Test public void celBlock_astExtensionTagged() throws Exception { - CelAbstractSyntaxTree ast = runtimeEnv.cel.compile("size(x) + size(x)").getAst(); + CelAbstractSyntaxTree ast = cel.compile("size(x) + size(x)").getAst(); CelOptimizer optimizer = - CelOptimizerFactory.standardCelOptimizerBuilder(runtimeEnv.cel) + CelOptimizerFactory.standardCelOptimizerBuilder(cel) .addAstOptimizers( SubexpressionOptimizer.newInstance( SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build()), @@ -411,20 +396,20 @@ private enum BlockTestCase { public void block_success(@TestParameter BlockTestCase testCase) throws Exception { CelAbstractSyntaxTree ast = compileUsingInternalFunctions(testCase.source); - Object evaluatedResult = runtimeEnv.celForEvaluatingBlock.createProgram(ast).eval(); + Object evaluatedResult = celForEvaluatingBlock.createProgram(ast).eval(); assertThat(evaluatedResult).isNotNull(); } @Test public void block_success_parsedOnly(@TestParameter BlockTestCase testCase) throws Exception { - if (runtimeEnv == RuntimeEnv.LEGACY) { + if (runtimeFlavor.equals(CelRuntimeFlavor.LEGACY)) { return; } CelAbstractSyntaxTree ast = compileUsingInternalFunctions(testCase.source, /* parsedOnly= */ true); - Object evaluatedResult = runtimeEnv.celForEvaluatingBlock.createProgram(ast).eval(); + Object evaluatedResult = celForEvaluatingBlock.createProgram(ast).eval(); assertThat(evaluatedResult).isNotNull(); } @@ -686,7 +671,7 @@ public void block_containsCycle_throws() throws Exception { CelAbstractSyntaxTree ast = compileUsingInternalFunctions("cel.block([index1,index0],index0)"); CelEvaluationException e = - assertThrows(CelEvaluationException.class, () -> runtimeEnv.cel.createProgram(ast).eval()); + assertThrows(CelEvaluationException.class, () -> cel.createProgram(ast).eval()); assertThat(e).hasMessageThat().contains("Cycle detected: @index0"); } @@ -697,7 +682,7 @@ public void block_lazyEvaluationContainsError_cleansUpCycleState() throws Except "cel.block([1/0 > 0], (index0 && false) || (index0 && true))"); CelEvaluationException e = - assertThrows(CelEvaluationException.class, () -> runtimeEnv.cel.createProgram(ast).eval()); + assertThrows(CelEvaluationException.class, () -> cel.createProgram(ast).eval()); assertThat(e).hasMessageThat().contains("/ by zero"); assertThat(e).hasMessageThat().doesNotContain("Cycle detected"); @@ -709,8 +694,7 @@ public void block_lazyEvaluationContainsError_cleansUpCycleState() throws Except */ private CelAbstractSyntaxTree compileUsingInternalFunctions(String expression, boolean parsedOnly) throws CelValidationException { - CelAbstractSyntaxTree astToModify = - runtimeEnv.celForEvaluatingBlock.compile(expression).getAst(); + CelAbstractSyntaxTree astToModify = celForEvaluatingBlock.compile(expression).getAst(); CelMutableAst mutableAst = CelMutableAst.fromCelAst(astToModify); CelNavigableMutableAst.fromAst(mutableAst) .getRoot() @@ -735,7 +719,7 @@ private CelAbstractSyntaxTree compileUsingInternalFunctions(String expression, b if (parsedOnly) { return mutableAst.toParsedAst(); } - return runtimeEnv.celForEvaluatingBlock.check(mutableAst.toParsedAst()).getAst(); + return celForEvaluatingBlock.check(mutableAst.toParsedAst()).getAst(); } private CelAbstractSyntaxTree compileUsingInternalFunctions(String expression) diff --git a/testing/BUILD.bazel b/testing/BUILD.bazel index c1b2a92b4..b9e68f003 100644 --- a/testing/BUILD.bazel +++ b/testing/BUILD.bazel @@ -11,6 +11,11 @@ java_library( exports = ["//testing/src/main/java/dev/cel/testing:adorner"], ) +java_library( + name = "cel_runtime_flavor", + exports = ["//testing/src/main/java/dev/cel/testing:cel_runtime_flavor"], +) + java_library( name = "line_differ", exports = ["//testing/src/main/java/dev/cel/testing:line_differ"], diff --git a/testing/src/main/java/dev/cel/testing/BUILD.bazel b/testing/src/main/java/dev/cel/testing/BUILD.bazel index 5ee142200..0d94bc8fc 100644 --- a/testing/src/main/java/dev/cel/testing/BUILD.bazel +++ b/testing/src/main/java/dev/cel/testing/BUILD.bazel @@ -105,3 +105,12 @@ java_library( "@maven_android//:com_google_protobuf_protobuf_javalite", ], ) + +java_library( + name = "cel_runtime_flavor", + srcs = ["CelRuntimeFlavor.java"], + deps = [ + "//bundle:cel", + "//bundle:cel_experimental_factory", + ], +) diff --git a/testing/src/main/java/dev/cel/testing/CelRuntimeFlavor.java b/testing/src/main/java/dev/cel/testing/CelRuntimeFlavor.java new file mode 100644 index 000000000..576e0c1d3 --- /dev/null +++ b/testing/src/main/java/dev/cel/testing/CelRuntimeFlavor.java @@ -0,0 +1,38 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.testing; + +import dev.cel.bundle.CelBuilder; +import dev.cel.bundle.CelExperimentalFactory; +import dev.cel.bundle.CelFactory; + +/** Enumeration of supported CEL runtime environments for testing. */ +public enum CelRuntimeFlavor { + LEGACY { + @Override + public CelBuilder builder() { + return CelFactory.standardCelBuilder(); + } + }, + PLANNER { + @Override + public CelBuilder builder() { + return CelExperimentalFactory.plannerCelBuilder(); + } + }; + + /** Returns a new {@link CelBuilder} instance for this runtime flavor. */ + public abstract CelBuilder builder(); +} From 8342a01ef2d64eb5790ea7b0c5a239042b53bc5c Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Tue, 14 Apr 2026 11:14:04 -0700 Subject: [PATCH 33/66] Add parsed-only evaluation test coverage to Regex Extensions PiperOrigin-RevId: 899682453 --- .../test/java/dev/cel/extensions/BUILD.bazel | 1 + .../extensions/CelRegexExtensionsTest.java | 89 ++++++++----------- .../src/main/java/dev/cel/testing/BUILD.bazel | 2 + 3 files changed, 42 insertions(+), 50 deletions(-) diff --git a/extensions/src/test/java/dev/cel/extensions/BUILD.bazel b/extensions/src/test/java/dev/cel/extensions/BUILD.bazel index a9dbfaca2..0b6502410 100644 --- a/extensions/src/test/java/dev/cel/extensions/BUILD.bazel +++ b/extensions/src/test/java/dev/cel/extensions/BUILD.bazel @@ -40,6 +40,7 @@ java_library( "//runtime:lite_runtime_factory", "//runtime:partial_vars", "//runtime:unknown_attributes", + "//testing:cel_runtime_flavor", "@cel_spec//proto/cel/expr/conformance/proto2:test_all_types_java_proto", "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto", "@cel_spec//proto/cel/expr/conformance/test:simple_java_proto", diff --git a/extensions/src/test/java/dev/cel/extensions/CelRegexExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelRegexExtensionsTest.java index 8a1bef014..924344b25 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelRegexExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelRegexExtensionsTest.java @@ -20,25 +20,38 @@ import com.google.testing.junit.testparameterinjector.TestParameter; import com.google.testing.junit.testparameterinjector.TestParameterInjector; import com.google.testing.junit.testparameterinjector.TestParameters; +import dev.cel.bundle.Cel; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelOptions; -import dev.cel.compiler.CelCompiler; -import dev.cel.compiler.CelCompilerFactory; import dev.cel.runtime.CelEvaluationException; -import dev.cel.runtime.CelRuntime; -import dev.cel.runtime.CelRuntimeFactory; +import dev.cel.testing.CelRuntimeFlavor; import java.util.Optional; +import org.junit.Assume; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @RunWith(TestParameterInjector.class) public final class CelRegexExtensionsTest { - private static final CelCompiler COMPILER = - CelCompilerFactory.standardCelCompilerBuilder().addLibraries(CelExtensions.regex()).build(); - private static final CelRuntime RUNTIME = - CelRuntimeFactory.standardCelRuntimeBuilder().addLibraries(CelExtensions.regex()).build(); + @TestParameter public CelRuntimeFlavor runtimeFlavor; + @TestParameter public boolean isParseOnly; + + private Cel cel; + + @Before + public void setUp() { + // Legacy runtime does not support parsed-only evaluation mode. + Assume.assumeFalse(runtimeFlavor.equals(CelRuntimeFlavor.LEGACY) && isParseOnly); + this.cel = + runtimeFlavor + .builder() + .addCompilerLibraries(CelExtensions.regex()) + .addRuntimeLibraries(CelExtensions.regex()) + .build(); + } + @Test public void library() { @@ -80,11 +93,7 @@ public void library() { public void replaceAll_success(String target, String regex, String replaceStr, String res) throws Exception { String expr = String.format("regex.replace('%s', '%s', '%s')", target, regex, replaceStr); - CelRuntime.Program program = RUNTIME.createProgram(COMPILER.compile(expr).getAst()); - - Object result = program.eval(); - - assertThat(result).isEqualTo(res); + assertThat(eval(expr)).isEqualTo(res); } @Test @@ -93,11 +102,7 @@ public void replace_nested_success() throws Exception { "regex.replace(" + " regex.replace('%(foo) %(bar) %2','%\\\\((\\\\w+)\\\\)','${\\\\1}')," + " '%(\\\\d+)', '$\\\\1')"; - CelRuntime.Program program = RUNTIME.createProgram(COMPILER.compile(expr).getAst()); - - Object result = program.eval(); - - assertThat(result).isEqualTo("${foo} ${bar} $2"); + assertThat(eval(expr)).isEqualTo("${foo} ${bar} $2"); } @Test @@ -118,11 +123,7 @@ public void replace_nested_success() throws Exception { public void replaceCount_success(String t, String re, String rep, long i, String res) throws Exception { String expr = String.format("regex.replace('%s', '%s', '%s', %d)", t, re, rep, i); - CelRuntime.Program program = RUNTIME.createProgram(COMPILER.compile(expr).getAst()); - - Object result = program.eval(); - - assertThat(result).isEqualTo(res); + assertThat(eval(expr)).isEqualTo(res); } @Test @@ -131,10 +132,8 @@ public void replaceCount_success(String t, String re, String rep, long i, String public void replace_invalidRegex_throwsException(String target, String regex, String replaceStr) throws Exception { String expr = String.format("regex.replace('%s', '%s', '%s')", target, regex, replaceStr); - CelAbstractSyntaxTree ast = COMPILER.compile(expr).getAst(); - CelEvaluationException e = - assertThrows(CelEvaluationException.class, () -> RUNTIME.createProgram(ast).eval()); + assertThrows(CelEvaluationException.class, () -> eval(expr)); assertThat(e).hasCauseThat().isInstanceOf(IllegalArgumentException.class); assertThat(e).hasCauseThat().hasMessageThat().contains("Failed to compile regex: "); @@ -143,10 +142,8 @@ public void replace_invalidRegex_throwsException(String target, String regex, St @Test public void replace_invalidCaptureGroupReplaceStr_throwsException() throws Exception { String expr = "regex.replace('test', '(.)', '\\\\2')"; - CelAbstractSyntaxTree ast = COMPILER.compile(expr).getAst(); - CelEvaluationException e = - assertThrows(CelEvaluationException.class, () -> RUNTIME.createProgram(ast).eval()); + assertThrows(CelEvaluationException.class, () -> eval(expr)); assertThat(e).hasCauseThat().isInstanceOf(IllegalArgumentException.class); assertThat(e) @@ -158,10 +155,8 @@ public void replace_invalidCaptureGroupReplaceStr_throwsException() throws Excep @Test public void replace_trailingBackslashReplaceStr_throwsException() throws Exception { String expr = "regex.replace('id=123', 'id=(?P\\\\d+)', '\\\\')"; - CelAbstractSyntaxTree ast = COMPILER.compile(expr).getAst(); - CelEvaluationException e = - assertThrows(CelEvaluationException.class, () -> RUNTIME.createProgram(ast).eval()); + assertThrows(CelEvaluationException.class, () -> eval(expr)); assertThat(e).hasCauseThat().isInstanceOf(IllegalArgumentException.class); assertThat(e) @@ -173,10 +168,8 @@ public void replace_trailingBackslashReplaceStr_throwsException() throws Excepti @Test public void replace_invalidGroupReferenceReplaceStr_throwsException() throws Exception { String expr = "regex.replace('id=123', 'id=(?P\\\\d+)', '\\\\a')"; - CelAbstractSyntaxTree ast = COMPILER.compile(expr).getAst(); - CelEvaluationException e = - assertThrows(CelEvaluationException.class, () -> RUNTIME.createProgram(ast).eval()); + assertThrows(CelEvaluationException.class, () -> eval(expr)); assertThat(e).hasCauseThat().isInstanceOf(IllegalArgumentException.class); assertThat(e) @@ -199,9 +192,7 @@ public void replace_invalidGroupReferenceReplaceStr_throwsException() throws Exc @TestParameters("{target: 'brand', regex: 'brand', expectedResult: 'brand'}") public void extract_success(String target, String regex, String expectedResult) throws Exception { String expr = String.format("regex.extract('%s', '%s')", target, regex); - CelRuntime.Program program = RUNTIME.createProgram(COMPILER.compile(expr).getAst()); - - Object result = program.eval(); + Object result = eval(expr); assertThat(result).isInstanceOf(Optional.class); assertThat((Optional) result).hasValue(expectedResult); @@ -213,9 +204,7 @@ public void extract_success(String target, String regex, String expectedResult) @TestParameters("{target: '', regex: '\\\\w+'}") public void extract_no_match(String target, String regex) throws Exception { String expr = String.format("regex.extract('%s', '%s')", target, regex); - CelRuntime.Program program = RUNTIME.createProgram(COMPILER.compile(expr).getAst()); - - Object result = program.eval(); + Object result = eval(expr); assertThat(result).isInstanceOf(Optional.class); assertThat((Optional) result).isEmpty(); @@ -227,10 +216,8 @@ public void extract_no_match(String target, String regex) throws Exception { public void extract_multipleCaptureGroups_throwsException(String target, String regex) throws Exception { String expr = String.format("regex.extract('%s', '%s')", target, regex); - CelAbstractSyntaxTree ast = COMPILER.compile(expr).getAst(); - CelEvaluationException e = - assertThrows(CelEvaluationException.class, () -> RUNTIME.createProgram(ast).eval()); + assertThrows(CelEvaluationException.class, () -> eval(expr)); assertThat(e).hasCauseThat().isInstanceOf(IllegalArgumentException.class); assertThat(e) @@ -263,9 +250,7 @@ private enum ExtractAllTestCase { @Test public void extractAll_success(@TestParameter ExtractAllTestCase testCase) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile(testCase.expr).getAst(); - - Object result = RUNTIME.createProgram(ast).eval(); + Object result = eval(testCase.expr); assertThat(result).isEqualTo(testCase.expectedResult); } @@ -281,10 +266,8 @@ public void extractAll_success(@TestParameter ExtractAllTestCase testCase) throw public void extractAll_multipleCaptureGroups_throwsException(String target, String regex) throws Exception { String expr = String.format("regex.extractAll('%s', '%s')", target, regex); - CelAbstractSyntaxTree ast = COMPILER.compile(expr).getAst(); - CelEvaluationException e = - assertThrows(CelEvaluationException.class, () -> RUNTIME.createProgram(ast).eval()); + assertThrows(CelEvaluationException.class, () -> eval(expr)); assertThat(e).hasCauseThat().isInstanceOf(IllegalArgumentException.class); assertThat(e) @@ -292,4 +275,10 @@ public void extractAll_multipleCaptureGroups_throwsException(String target, Stri .hasMessageThat() .contains("Regular expression has more than one capturing group:"); } + + private Object eval(String expr) throws Exception { + CelAbstractSyntaxTree ast = + isParseOnly ? cel.parse(expr).getAst() : cel.compile(expr).getAst(); + return cel.createProgram(ast).eval(); + } } diff --git a/testing/src/main/java/dev/cel/testing/BUILD.bazel b/testing/src/main/java/dev/cel/testing/BUILD.bazel index 0d94bc8fc..b52026ec4 100644 --- a/testing/src/main/java/dev/cel/testing/BUILD.bazel +++ b/testing/src/main/java/dev/cel/testing/BUILD.bazel @@ -109,6 +109,8 @@ java_library( java_library( name = "cel_runtime_flavor", srcs = ["CelRuntimeFlavor.java"], + tags = [ + ], deps = [ "//bundle:cel", "//bundle:cel_experimental_factory", From 74ffbb3a269964c96006fab24e6cdc687adc7170 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Tue, 14 Apr 2026 15:08:15 -0700 Subject: [PATCH 34/66] Fix accu_init to be lazily initialized in folder. Add parsed-only evaluation test coverage to Bindings Extensions PiperOrigin-RevId: 899792272 --- .../extensions/CelBindingsExtensionsTest.java | 287 ++++++++++-------- .../java/dev/cel/runtime/planner/BUILD.bazel | 1 + .../dev/cel/runtime/planner/EvalFold.java | 52 +++- 3 files changed, 203 insertions(+), 137 deletions(-) diff --git a/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java index ff9e31432..b87967d0e 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java @@ -22,40 +22,51 @@ import com.google.testing.junit.testparameterinjector.TestParameter; import com.google.testing.junit.testparameterinjector.TestParameterInjector; import com.google.testing.junit.testparameterinjector.TestParameters; +import dev.cel.bundle.Cel; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelOptions; import dev.cel.common.CelOverloadDecl; import dev.cel.common.CelValidationException; +import dev.cel.common.exceptions.CelDivideByZeroException; import dev.cel.common.types.SimpleType; import dev.cel.common.types.StructTypeReference; -import dev.cel.compiler.CelCompiler; -import dev.cel.compiler.CelCompilerFactory; import dev.cel.expr.conformance.proto3.TestAllTypes; import dev.cel.parser.CelMacro; import dev.cel.parser.CelStandardMacro; +import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.CelFunctionBinding; -import dev.cel.runtime.CelRuntime; -import dev.cel.runtime.CelRuntimeFactory; +import dev.cel.testing.CelRuntimeFlavor; import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; +import org.junit.Assume; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @RunWith(TestParameterInjector.class) public final class CelBindingsExtensionsTest { - private static final CelCompiler COMPILER = - CelCompilerFactory.standardCelCompilerBuilder() - .setStandardMacros(CelStandardMacro.STANDARD_MACROS) - .addLibraries(CelOptionalLibrary.INSTANCE, CelExtensions.bindings()) - .build(); - - private static final CelRuntime RUNTIME = - CelRuntimeFactory.standardCelRuntimeBuilder() - .addLibraries(CelOptionalLibrary.INSTANCE) - .build(); + @TestParameter public CelRuntimeFlavor runtimeFlavor; + @TestParameter public boolean isParseOnly; + + private Cel cel; + + @Before + public void setUp() { + // Legacy runtime does not support parsed-only evaluation mode. + Assume.assumeFalse(runtimeFlavor.equals(CelRuntimeFlavor.LEGACY) && isParseOnly); + cel = + runtimeFlavor + .builder() + .setOptions(CelOptions.current().enableHeterogeneousNumericComparisons(true).build()) + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .addCompilerLibraries(CelOptionalLibrary.INSTANCE, CelExtensions.bindings()) + .addRuntimeLibraries(CelOptionalLibrary.INSTANCE) + .build(); + } @Test public void library() { @@ -93,9 +104,7 @@ private enum BindingTestCase { @Test public void binding_success(@TestParameter BindingTestCase testCase) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile(testCase.source).getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - boolean evaluatedResult = (boolean) program.eval(); + boolean evaluatedResult = (boolean) eval(testCase.source); assertThat(evaluatedResult).isTrue(); } @@ -103,9 +112,11 @@ public void binding_success(@TestParameter BindingTestCase testCase) throws Exce @Test @TestParameters("{expr: 'false.bind(false, false, false)'}") public void binding_nonCelNamespace_success(String expr) throws Exception { - CelCompiler celCompiler = - CelCompilerFactory.standardCelCompilerBuilder() - .addLibraries(CelExtensions.bindings()) + Cel customCel = + runtimeFlavor + .builder() + .setOptions(CelOptions.current().enableHeterogeneousNumericComparisons(true).build()) + .addCompilerLibraries(CelExtensions.bindings()) .addFunctionDeclarations( CelFunctionDecl.newFunctionDeclaration( "bind", @@ -116,18 +127,16 @@ public void binding_nonCelNamespace_success(String expr) throws Exception { SimpleType.BOOL, SimpleType.BOOL, SimpleType.BOOL))) - .build(); - CelRuntime celRuntime = - CelRuntimeFactory.standardCelRuntimeBuilder() .addFunctionBindings( - CelFunctionBinding.from( - "bool_bind_bool_bool_bool", - Arrays.asList(Boolean.class, Boolean.class, Boolean.class, Boolean.class), - (args) -> true)) + CelFunctionBinding.fromOverloads( + "bind", + CelFunctionBinding.from( + "bool_bind_bool_bool_bool", + Arrays.asList(Boolean.class, Boolean.class, Boolean.class, Boolean.class), + (args) -> true))) .build(); - CelAbstractSyntaxTree ast = celCompiler.compile(expr).getAst(); - boolean result = (boolean) celRuntime.createProgram(ast).eval(); + boolean result = (boolean) eval(customCel, expr); assertThat(result).isTrue(); } @@ -135,7 +144,7 @@ public void binding_nonCelNamespace_success(String expr) throws Exception { @TestParameters("{expr: 'cel.bind(bad.name, true, bad.name)'}") public void binding_throwsCompilationException(String expr) throws Exception { CelValidationException e = - assertThrows(CelValidationException.class, () -> COMPILER.compile(expr).getAst()); + assertThrows(CelValidationException.class, () -> cel.compile(expr).getAst()); assertThat(e).hasMessageThat().contains("cel.bind() variable name must be a simple identifier"); } @@ -143,70 +152,76 @@ public void binding_throwsCompilationException(String expr) throws Exception { @Test @SuppressWarnings("Immutable") // Test only public void lazyBinding_bindingVarNeverReferenced() throws Exception { - CelCompiler celCompiler = - CelCompilerFactory.standardCelCompilerBuilder() + + AtomicInteger invocation = new AtomicInteger(); + Cel customCel = + runtimeFlavor + .builder() + .setOptions(CelOptions.current().enableHeterogeneousNumericComparisons(true).build()) .setStandardMacros(CelStandardMacro.HAS) .addMessageTypes(TestAllTypes.getDescriptor()) .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())) - .addLibraries(CelExtensions.bindings()) + .addCompilerLibraries(CelExtensions.bindings()) .addFunctionDeclarations( CelFunctionDecl.newFunctionDeclaration( "get_true", CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL))) - .build(); - AtomicInteger invocation = new AtomicInteger(); - CelRuntime celRuntime = - CelRuntimeFactory.standardCelRuntimeBuilder() - .addMessageTypes(TestAllTypes.getDescriptor()) .addFunctionBindings( - CelFunctionBinding.from( - "get_true_overload", - ImmutableList.of(), - arg -> { - invocation.getAndIncrement(); - return true; - })) + CelFunctionBinding.fromOverloads( + "get_true", + CelFunctionBinding.from( + "get_true_overload", + ImmutableList.of(), + arg -> { + invocation.getAndIncrement(); + return true; + }))) .build(); - CelAbstractSyntaxTree ast = - celCompiler.compile("cel.bind(t, get_true(), has(msg.single_int64) ? t : false)").getAst(); - boolean result = (boolean) - celRuntime - .createProgram(ast) - .eval(ImmutableMap.of("msg", TestAllTypes.getDefaultInstance())); + eval( + customCel, + "cel.bind(t, get_true(), has(msg.single_int64) ? t : false)", + ImmutableMap.of("msg", TestAllTypes.getDefaultInstance())); assertThat(result).isFalse(); assertThat(invocation.get()).isEqualTo(0); } + @Test + public void lazyBinding_throwsEvaluationException() throws Exception { + CelEvaluationException e = + assertThrows(CelEvaluationException.class, () -> eval(cel, "cel.bind(t, 1 / 0, t)")); + + assertThat(e).hasMessageThat().contains("/ by zero"); + assertThat(e).hasCauseThat().isInstanceOf(CelDivideByZeroException.class); + } + @Test @SuppressWarnings("Immutable") // Test only public void lazyBinding_accuInitEvaluatedOnce() throws Exception { - CelCompiler celCompiler = - CelCompilerFactory.standardCelCompilerBuilder() - .addLibraries(CelExtensions.bindings()) + AtomicInteger invocation = new AtomicInteger(); + Cel customCel = + runtimeFlavor + .builder() + .setOptions(CelOptions.current().enableHeterogeneousNumericComparisons(true).build()) + .addCompilerLibraries(CelExtensions.bindings()) .addFunctionDeclarations( CelFunctionDecl.newFunctionDeclaration( "get_true", CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL))) - .build(); - AtomicInteger invocation = new AtomicInteger(); - CelRuntime celRuntime = - CelRuntimeFactory.standardCelRuntimeBuilder() .addFunctionBindings( - CelFunctionBinding.from( - "get_true_overload", - ImmutableList.of(), - arg -> { - invocation.getAndIncrement(); - return true; - })) + CelFunctionBinding.fromOverloads( + "get_true", + CelFunctionBinding.from( + "get_true_overload", + ImmutableList.of(), + arg -> { + invocation.getAndIncrement(); + return true; + }))) .build(); - CelAbstractSyntaxTree ast = - celCompiler.compile("cel.bind(t, get_true(), t && t && t && t)").getAst(); - - boolean result = (boolean) celRuntime.createProgram(ast).eval(); + boolean result = (boolean) eval(customCel, "cel.bind(t, get_true(), t && t && t && t)"); assertThat(result).isTrue(); assertThat(invocation.get()).isEqualTo(1); @@ -215,32 +230,32 @@ public void lazyBinding_accuInitEvaluatedOnce() throws Exception { @Test @SuppressWarnings("Immutable") // Test only public void lazyBinding_withNestedBinds() throws Exception { - CelCompiler celCompiler = - CelCompilerFactory.standardCelCompilerBuilder() - .addLibraries(CelExtensions.bindings()) + AtomicInteger invocation = new AtomicInteger(); + Cel customCel = + runtimeFlavor + .builder() + .setOptions(CelOptions.current().enableHeterogeneousNumericComparisons(true).build()) + .addCompilerLibraries(CelExtensions.bindings()) .addFunctionDeclarations( CelFunctionDecl.newFunctionDeclaration( "get_true", CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL))) - .build(); - AtomicInteger invocation = new AtomicInteger(); - CelRuntime celRuntime = - CelRuntimeFactory.standardCelRuntimeBuilder() .addFunctionBindings( - CelFunctionBinding.from( - "get_true_overload", - ImmutableList.of(), - arg -> { - invocation.getAndIncrement(); - return true; - })) + CelFunctionBinding.fromOverloads( + "get_true", + CelFunctionBinding.from( + "get_true_overload", + ImmutableList.of(), + arg -> { + invocation.getAndIncrement(); + return true; + }))) .build(); - CelAbstractSyntaxTree ast = - celCompiler - .compile("cel.bind(t1, get_true(), cel.bind(t2, get_true(), t1 && t2 && t1 && t2))") - .getAst(); - - boolean result = (boolean) celRuntime.createProgram(ast).eval(); + boolean result = + (boolean) + eval( + customCel, + "cel.bind(t1, get_true(), cel.bind(t2, get_true(), t1 && t2 && t1 && t2))"); assertThat(result).isTrue(); assertThat(invocation.get()).isEqualTo(2); @@ -249,32 +264,31 @@ public void lazyBinding_withNestedBinds() throws Exception { @Test @SuppressWarnings({"Immutable", "unchecked"}) // Test only public void lazyBinding_boundAttributeInComprehension() throws Exception { - CelCompiler celCompiler = - CelCompilerFactory.standardCelCompilerBuilder() + AtomicInteger invocation = new AtomicInteger(); + Cel customCel = + runtimeFlavor + .builder() + .setOptions(CelOptions.current().enableHeterogeneousNumericComparisons(true).build()) .setStandardMacros(CelStandardMacro.MAP) - .addLibraries(CelExtensions.bindings()) + .addCompilerLibraries(CelExtensions.bindings()) .addFunctionDeclarations( CelFunctionDecl.newFunctionDeclaration( "get_true", CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL))) - .build(); - AtomicInteger invocation = new AtomicInteger(); - CelRuntime celRuntime = - CelRuntimeFactory.standardCelRuntimeBuilder() .addFunctionBindings( - CelFunctionBinding.from( - "get_true_overload", - ImmutableList.of(), - arg -> { - invocation.getAndIncrement(); - return true; - })) + CelFunctionBinding.fromOverloads( + "get_true", + CelFunctionBinding.from( + "get_true_overload", + ImmutableList.of(), + arg -> { + invocation.getAndIncrement(); + return true; + }))) .build(); - CelAbstractSyntaxTree ast = - celCompiler.compile("cel.bind(x, get_true(), [1,2,3].map(y, y < 0 || x))").getAst(); - - List result = (List) celRuntime.createProgram(ast).eval(); + List result = + (List) eval(customCel, "cel.bind(x, get_true(), [1,2,3].map(y, y < 0 || x))"); assertThat(result).containsExactly(true, true, true); assertThat(invocation.get()).isEqualTo(1); @@ -283,38 +297,55 @@ public void lazyBinding_boundAttributeInComprehension() throws Exception { @Test @SuppressWarnings({"Immutable"}) // Test only public void lazyBinding_boundAttributeInNestedComprehension() throws Exception { - CelCompiler celCompiler = - CelCompilerFactory.standardCelCompilerBuilder() + AtomicInteger invocation = new AtomicInteger(); + Cel customCel = + runtimeFlavor + .builder() + .setOptions(CelOptions.current().enableHeterogeneousNumericComparisons(true).build()) .setStandardMacros(CelStandardMacro.EXISTS) - .addLibraries(CelExtensions.bindings()) + .addCompilerLibraries(CelExtensions.bindings()) .addFunctionDeclarations( CelFunctionDecl.newFunctionDeclaration( "get_true", CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL))) - .build(); - AtomicInteger invocation = new AtomicInteger(); - CelRuntime celRuntime = - CelRuntimeFactory.standardCelRuntimeBuilder() .addFunctionBindings( - CelFunctionBinding.from( - "get_true_overload", - ImmutableList.of(), - arg -> { - invocation.getAndIncrement(); - return true; - })) + CelFunctionBinding.fromOverloads( + "get_true", + CelFunctionBinding.from( + "get_true_overload", + ImmutableList.of(), + arg -> { + invocation.getAndIncrement(); + return true; + }))) .build(); - CelAbstractSyntaxTree ast = - celCompiler - .compile( + boolean result = + (boolean) + eval( + customCel, "cel.bind(x, get_true(), [1,2,3].exists(unused, x && " - + "['a','b','c'].exists(unused_2, x)))") - .getAst(); - - boolean result = (boolean) celRuntime.createProgram(ast).eval(); + + "['a','b','c'].exists(unused_2, x)))"); assertThat(result).isTrue(); assertThat(invocation.get()).isEqualTo(1); } + + private Object eval(Cel cel, String expression) throws Exception { + return eval(cel, expression, ImmutableMap.of()); + } + + private Object eval(Cel cel, String expression, Map variables) throws Exception { + CelAbstractSyntaxTree ast; + if (isParseOnly) { + ast = cel.parse(expression).getAst(); + } else { + ast = cel.compile(expression).getAst(); + } + return cel.createProgram(ast).eval(variables); + } + + private Object eval(String expression) throws Exception { + return eval(this.cel, expression, ImmutableMap.of()); + } } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel b/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel index cb2ad5a82..824c918d8 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel +++ b/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel @@ -364,6 +364,7 @@ java_library( deps = [ ":activation_wrapper", ":planned_interpretable", + "//common/exceptions:runtime_exception", "//runtime:accumulated_unknowns", "//runtime:concatenated_list_view", "//runtime:evaluation_exception", diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalFold.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalFold.java index 2631bf0b9..2eb30671e 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalFold.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalFold.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.Immutable; +import dev.cel.common.exceptions.CelRuntimeException; import dev.cel.runtime.AccumulatedUnknowns; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.ConcatenatedListView; @@ -77,8 +78,7 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEval if (iterRangeRaw instanceof AccumulatedUnknowns) { return iterRangeRaw; } - Folder folder = new Folder(resolver, accuVar, iterVar, iterVar2); - folder.accuVal = maybeWrapAccumulator(accuInit.eval(folder, frame)); + Folder folder = new Folder(resolver, frame, accuInit, accuVar, iterVar, iterVar2); Object result; if (iterRangeRaw instanceof Map) { @@ -104,11 +104,14 @@ private Object evalMap(Map iterRange, Folder folder, ExecutionFrame frame) boolean cond = (boolean) condition.eval(folder, frame); if (!cond) { + folder.computeResult = true; return result.eval(folder, frame); } folder.accuVal = loopStep.eval(folder, frame); + folder.initialized = true; } + folder.computeResult = true; return result.eval(folder, frame); } @@ -127,12 +130,15 @@ private Object evalList(Collection iterRange, Folder folder, ExecutionFrame f boolean cond = (boolean) condition.eval(folder, frame); if (!cond) { + folder.computeResult = true; return result.eval(folder, frame); } folder.accuVal = loopStep.eval(folder, frame); + folder.initialized = true; index++; } + folder.computeResult = true; return result.eval(folder, frame); } @@ -155,6 +161,8 @@ private static Object maybeUnwrapAccumulator(Object val) { private static class Folder implements ActivationWrapper { private final GlobalResolver resolver; + private final ExecutionFrame frame; + private final PlannedInterpretable accuInit; private final String accuVar; private final String iterVar; private final String iterVar2; @@ -162,9 +170,19 @@ private static class Folder implements ActivationWrapper { private Object iterVarVal; private Object iterVar2Val; private Object accuVal; - - private Folder(GlobalResolver resolver, String accuVar, String iterVar, String iterVar2) { + private boolean initialized = false; + private boolean computeResult = false; + + private Folder( + GlobalResolver resolver, + ExecutionFrame frame, + PlannedInterpretable accuInit, + String accuVar, + String iterVar, + String iterVar2) { this.resolver = resolver; + this.frame = frame; + this.accuInit = accuInit; this.accuVar = accuVar; this.iterVar = iterVar; this.iterVar2 = iterVar2; @@ -183,18 +201,34 @@ public boolean isLocallyBound(String name) { @Override public @Nullable Object resolve(String name) { if (name.equals(accuVar)) { + if (!initialized) { + initialized = true; + try { + accuVal = maybeWrapAccumulator(accuInit.eval(resolver, frame)); + } catch (CelEvaluationException e) { + throw new LazyEvaluationRuntimeException(e); + } + } return accuVal; } - if (name.equals(iterVar)) { - return this.iterVarVal; - } + if (!computeResult) { + if (name.equals(iterVar)) { + return this.iterVarVal; + } - if (name.equals(iterVar2)) { - return this.iterVar2Val; + if (name.equals(iterVar2)) { + return this.iterVar2Val; + } } return resolver.resolve(name); } } + + private static class LazyEvaluationRuntimeException extends CelRuntimeException { + private LazyEvaluationRuntimeException(CelEvaluationException cause) { + super(cause, cause.getErrorCode()); + } + } } From 3075687e905227f5daa4da15df01b480397eb0c1 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Tue, 14 Apr 2026 18:04:42 -0700 Subject: [PATCH 35/66] Add parsed-only evaluation coverage to Proto Extensions PiperOrigin-RevId: 899863532 --- .../extensions/CelProtoExtensionsTest.java | 173 +++++++++--------- 1 file changed, 88 insertions(+), 85 deletions(-) diff --git a/extensions/src/test/java/dev/cel/extensions/CelProtoExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelProtoExtensionsTest.java index 15f6df5be..2e55619db 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelProtoExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelProtoExtensionsTest.java @@ -26,7 +26,6 @@ import com.google.testing.junit.testparameterinjector.TestParameterInjector; import com.google.testing.junit.testparameterinjector.TestParameters; import dev.cel.bundle.Cel; -import dev.cel.bundle.CelFactory; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelContainer; import dev.cel.common.CelFunctionDecl; @@ -35,8 +34,6 @@ import dev.cel.common.CelValidationException; import dev.cel.common.types.SimpleType; import dev.cel.common.types.StructTypeReference; -import dev.cel.compiler.CelCompiler; -import dev.cel.compiler.CelCompilerFactory; import dev.cel.expr.conformance.proto2.Proto2ExtensionScopedMessage; import dev.cel.expr.conformance.proto2.TestAllTypes; import dev.cel.expr.conformance.proto2.TestAllTypes.NestedEnum; @@ -44,27 +41,35 @@ import dev.cel.parser.CelMacro; import dev.cel.parser.CelStandardMacro; import dev.cel.runtime.CelFunctionBinding; -import dev.cel.runtime.CelRuntime; -import dev.cel.runtime.CelRuntimeFactory; +import dev.cel.testing.CelRuntimeFlavor; +import java.util.Map; +import org.junit.Assume; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @RunWith(TestParameterInjector.class) public final class CelProtoExtensionsTest { - private static final CelCompiler CEL_COMPILER = - CelCompilerFactory.standardCelCompilerBuilder() - .addLibraries(CelExtensions.protos()) - .setStandardMacros(CelStandardMacro.STANDARD_MACROS) - .addFileTypes(TestAllTypesExtensions.getDescriptor()) - .addVar("msg", StructTypeReference.create("cel.expr.conformance.proto2.TestAllTypes")) - .setContainer(CelContainer.ofName("cel.expr.conformance.proto2")) - .build(); + @TestParameter public CelRuntimeFlavor runtimeFlavor; + @TestParameter public boolean isParseOnly; - private static final CelRuntime CEL_RUNTIME = - CelRuntimeFactory.standardCelRuntimeBuilder() - .addFileTypes(TestAllTypesExtensions.getDescriptor()) - .build(); + private Cel cel; + + @Before + public void setUp() { + // Legacy runtime does not support parsed-only evaluation mode. + Assume.assumeFalse(runtimeFlavor.equals(CelRuntimeFlavor.LEGACY) && isParseOnly); + this.cel = + runtimeFlavor + .builder() + .addCompilerLibraries(CelExtensions.protos()) + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .addFileTypes(TestAllTypesExtensions.getDescriptor()) + .addVar("msg", StructTypeReference.create("cel.expr.conformance.proto2.TestAllTypes")) + .setContainer(CelContainer.ofName("cel.expr.conformance.proto2")) + .build(); + } private static final TestAllTypes PACKAGE_SCOPED_EXT_MSG = TestAllTypes.newBuilder() @@ -106,10 +111,7 @@ public void library() { "{expr: 'proto.hasExt(msg, cel.expr.conformance.proto2.repeated_test_all_types)'}") @TestParameters("{expr: '!proto.hasExt(msg, cel.expr.conformance.proto2.test_all_types_ext)'}") public void hasExt_packageScoped_success(String expr) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); - boolean result = - (boolean) - CEL_RUNTIME.createProgram(ast).eval(ImmutableMap.of("msg", PACKAGE_SCOPED_EXT_MSG)); + boolean result = (boolean) eval(expr, ImmutableMap.of("msg", PACKAGE_SCOPED_EXT_MSG)); assertThat(result).isTrue(); } @@ -128,10 +130,7 @@ public void hasExt_packageScoped_success(String expr) throws Exception { "{expr: '!proto.hasExt(msg," + " cel.expr.conformance.proto2.Proto2ExtensionScopedMessage.nested_enum_ext)'}") public void hasExt_messageScoped_success(String expr) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); - boolean result = - (boolean) - CEL_RUNTIME.createProgram(ast).eval(ImmutableMap.of("msg", MESSAGE_SCOPED_EXT_MSG)); + boolean result = (boolean) eval(expr, ImmutableMap.of("msg", MESSAGE_SCOPED_EXT_MSG)); assertThat(result).isTrue(); } @@ -142,9 +141,10 @@ public void hasExt_messageScoped_success(String expr) throws Exception { public void hasExt_nonProtoNamespace_success(String expr) throws Exception { StructTypeReference proto2MessageTypeReference = StructTypeReference.create("cel.expr.conformance.proto2.TestAllTypes"); - CelCompiler celCompiler = - CelCompilerFactory.standardCelCompilerBuilder() - .addLibraries(CelExtensions.protos()) + Cel customCel = + runtimeFlavor + .builder() + .addCompilerLibraries(CelExtensions.protos()) .addVar("msg", proto2MessageTypeReference) .addFunctionDeclarations( CelFunctionDecl.newFunctionDeclaration( @@ -154,37 +154,35 @@ public void hasExt_nonProtoNamespace_success(String expr) throws Exception { SimpleType.BOOL, ImmutableList.of( proto2MessageTypeReference, SimpleType.STRING, SimpleType.INT)))) - .build(); - CelRuntime celRuntime = - CelRuntimeFactory.standardCelRuntimeBuilder() .addFunctionBindings( - CelFunctionBinding.from( - "msg_hasExt", - ImmutableList.of(TestAllTypes.class, String.class, Long.class), - (arg) -> { - TestAllTypes msg = (TestAllTypes) arg[0]; - String extensionField = (String) arg[1]; - return msg.getAllFields().keySet().stream() - .anyMatch(fd -> fd.getFullName().equals(extensionField)); - })) + CelFunctionBinding.fromOverloads( + "hasExt", + CelFunctionBinding.from( + "msg_hasExt", + ImmutableList.of(TestAllTypes.class, String.class, Long.class), + (arg) -> { + TestAllTypes msg = (TestAllTypes) arg[0]; + String extensionField = (String) arg[1]; + return msg.getAllFields().keySet().stream() + .anyMatch(fd -> fd.getFullName().equals(extensionField)); + }))) .build(); - CelAbstractSyntaxTree ast = celCompiler.compile(expr).getAst(); boolean result = - (boolean) - celRuntime.createProgram(ast).eval(ImmutableMap.of("msg", PACKAGE_SCOPED_EXT_MSG)); + (boolean) eval(customCel, expr, ImmutableMap.of("msg", PACKAGE_SCOPED_EXT_MSG)); assertThat(result).isTrue(); } @Test public void hasExt_undefinedField_throwsException() { + // This is a type-checking failure + Assume.assumeFalse(isParseOnly); CelValidationException exception = assertThrows( CelValidationException.class, () -> - CEL_COMPILER - .compile("!proto.hasExt(msg, cel.expr.conformance.proto2.undefined_field)") + cel.compile("!proto.hasExt(msg, cel.expr.conformance.proto2.undefined_field)") .getAst()); assertThat(exception) @@ -204,10 +202,7 @@ public void hasExt_undefinedField_throwsException() { "{expr: 'proto.getExt(msg, cel.expr.conformance.proto2.repeated_test_all_types) ==" + " [TestAllTypes{single_string: ''A''}, TestAllTypes{single_string: ''B''}]'}") public void getExt_packageScoped_success(String expr) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); - boolean result = - (boolean) - CEL_RUNTIME.createProgram(ast).eval(ImmutableMap.of("msg", PACKAGE_SCOPED_EXT_MSG)); + boolean result = (boolean) eval(expr, ImmutableMap.of("msg", PACKAGE_SCOPED_EXT_MSG)); assertThat(result).isTrue(); } @@ -221,22 +216,20 @@ public void getExt_packageScoped_success(String expr) throws Exception { "{expr: 'proto.getExt(msg," + " cel.expr.conformance.proto2.Proto2ExtensionScopedMessage.int64_ext) == 1'}") public void getExt_messageScopedSuccess(String expr) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); - boolean result = - (boolean) - CEL_RUNTIME.createProgram(ast).eval(ImmutableMap.of("msg", MESSAGE_SCOPED_EXT_MSG)); + boolean result = (boolean) eval(expr, ImmutableMap.of("msg", MESSAGE_SCOPED_EXT_MSG)); assertThat(result).isTrue(); } @Test public void getExt_undefinedField_throwsException() { + // This is a type-checking failure + Assume.assumeFalse(isParseOnly); CelValidationException exception = assertThrows( CelValidationException.class, () -> - CEL_COMPILER - .compile("!proto.getExt(msg, cel.expr.conformance.proto2.undefined_field)") + cel.compile("!proto.getExt(msg, cel.expr.conformance.proto2.undefined_field)") .getAst()); assertThat(exception) @@ -250,9 +243,10 @@ public void getExt_undefinedField_throwsException() { public void getExt_nonProtoNamespace_success(String expr) throws Exception { StructTypeReference proto2MessageTypeReference = StructTypeReference.create("cel.expr.conformance.proto2.TestAllTypes"); - CelCompiler celCompiler = - CelCompilerFactory.standardCelCompilerBuilder() - .addLibraries(CelExtensions.protos()) + Cel customCel = + runtimeFlavor + .builder() + .addCompilerLibraries(CelExtensions.protos()) .addVar("msg", proto2MessageTypeReference) .addFunctionDeclarations( CelFunctionDecl.newFunctionDeclaration( @@ -262,29 +256,26 @@ public void getExt_nonProtoNamespace_success(String expr) throws Exception { SimpleType.DYN, ImmutableList.of( proto2MessageTypeReference, SimpleType.STRING, SimpleType.INT)))) - .build(); - CelRuntime celRuntime = - CelRuntimeFactory.standardCelRuntimeBuilder() .addFunctionBindings( - CelFunctionBinding.from( - "msg_getExt", - ImmutableList.of(TestAllTypes.class, String.class, Long.class), - (arg) -> { - TestAllTypes msg = (TestAllTypes) arg[0]; - String extensionField = (String) arg[1]; - FieldDescriptor extensionDescriptor = - msg.getAllFields().keySet().stream() - .filter(fd -> fd.getFullName().equals(extensionField)) - .findAny() - .get(); - return msg.getField(extensionDescriptor); - })) + CelFunctionBinding.fromOverloads( + "getExt", + CelFunctionBinding.from( + "msg_getExt", + ImmutableList.of(TestAllTypes.class, String.class, Long.class), + (arg) -> { + TestAllTypes msg = (TestAllTypes) arg[0]; + String extensionField = (String) arg[1]; + FieldDescriptor extensionDescriptor = + msg.getAllFields().keySet().stream() + .filter(fd -> fd.getFullName().equals(extensionField)) + .findAny() + .get(); + return msg.getField(extensionDescriptor); + }))) .build(); - CelAbstractSyntaxTree ast = celCompiler.compile(expr).getAst(); boolean result = - (boolean) - celRuntime.createProgram(ast).eval(ImmutableMap.of("msg", PACKAGE_SCOPED_EXT_MSG)); + (boolean) eval(customCel, expr, ImmutableMap.of("msg", PACKAGE_SCOPED_EXT_MSG)); assertThat(result).isTrue(); } @@ -293,21 +284,24 @@ public void getExt_nonProtoNamespace_success(String expr) throws Exception { public void getExt_onAnyPackedExtensionField_success() throws Exception { ExtensionRegistry extensionRegistry = ExtensionRegistry.newInstance(); TestAllTypesExtensions.registerAllExtensions(extensionRegistry); - Cel cel = - CelFactory.standardCelBuilder() + Cel customCel = + runtimeFlavor + .builder() // CEL-Internal-2 .addCompilerLibraries(CelExtensions.protos()) .addFileTypes(TestAllTypesExtensions.getDescriptor()) .setExtensionRegistry(extensionRegistry) .addVar("msg", StructTypeReference.create("cel.expr.conformance.proto2.TestAllTypes")) .build(); - CelAbstractSyntaxTree ast = - cel.compile("proto.getExt(msg, cel.expr.conformance.proto2.int32_ext)").getAst(); Any anyMsg = Any.pack( TestAllTypes.newBuilder().setExtension(TestAllTypesExtensions.int32Ext, 1).build()); - - Long result = (Long) cel.createProgram(ast).eval(ImmutableMap.of("msg", anyMsg)); + Long result = + (Long) + eval( + customCel, + "proto.getExt(msg, cel.expr.conformance.proto2.int32_ext)", + ImmutableMap.of("msg", anyMsg)); assertThat(result).isEqualTo(1); } @@ -343,9 +337,18 @@ private enum ParseErrorTestCase { @Test public void parseErrors(@TestParameter ParseErrorTestCase testcase) { CelValidationException e = - assertThrows( - CelValidationException.class, () -> CEL_COMPILER.compile(testcase.expr).getAst()); + assertThrows(CelValidationException.class, () -> cel.parse(testcase.expr).getAst()); assertThat(e).hasMessageThat().isEqualTo(testcase.error); } + + private Object eval(String expression, Map variables) throws Exception { + return eval(this.cel, expression, variables); + } + + private Object eval(Cel cel, String expression, Map variables) throws Exception { + CelAbstractSyntaxTree ast = + this.isParseOnly ? cel.parse(expression).getAst() : cel.compile(expression).getAst(); + return cel.createProgram(ast).eval(variables); + } } From 2203ac83b0fb19ad56fe7b9bfb4a963034886bd3 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Tue, 14 Apr 2026 18:59:59 -0700 Subject: [PATCH 36/66] Support parsed-only evaluation for lists extensions, remove check for heterogeneous numeric comparisons for sorting PiperOrigin-RevId: 899880149 --- .../cel/extensions/CelListsExtensions.java | 66 +++----- .../extensions/CelListsExtensionsTest.java | 146 +++++++++--------- 2 files changed, 99 insertions(+), 113 deletions(-) diff --git a/extensions/src/main/java/dev/cel/extensions/CelListsExtensions.java b/extensions/src/main/java/dev/cel/extensions/CelListsExtensions.java index a91edd822..79539b008 100644 --- a/extensions/src/main/java/dev/cel/extensions/CelListsExtensions.java +++ b/extensions/src/main/java/dev/cel/extensions/CelListsExtensions.java @@ -128,7 +128,8 @@ public enum Function { "list_sort", "Sorts a list with comparable elements.", ListType.create(TypeParamType.create("T")), - ListType.create(TypeParamType.create("T"))))), + ListType.create(TypeParamType.create("T")))), + CelFunctionBinding.from("list_sort", Collection.class, CelListsExtensions::sort)), SORT_BY( CelFunctionDecl.newFunctionDeclaration( "lists.@sortByAssociatedKeys", @@ -136,7 +137,11 @@ public enum Function { "list_sortByAssociatedKeys", "Sorts a list by a key value. Used by the 'sortBy' macro", ListType.create(TypeParamType.create("T")), - ListType.create(TypeParamType.create("T"))))); + ListType.create(TypeParamType.create("T")))), + CelFunctionBinding.from( + "list_sortByAssociatedKeys", + Collection.class, + CelListsExtensions::sortByAssociatedKeys)); private final CelFunctionDecl functionDecl; private final ImmutableSet functionBindings; @@ -147,7 +152,10 @@ String getFunction() { Function(CelFunctionDecl functionDecl, CelFunctionBinding... functionBindings) { this.functionDecl = functionDecl; - this.functionBindings = ImmutableSet.copyOf(functionBindings); + this.functionBindings = + functionBindings.length > 0 + ? CelFunctionBinding.fromOverloads(functionDecl.name(), functionBindings) + : ImmutableSet.of(); } } @@ -240,32 +248,13 @@ public void setRuntimeOptions(CelRuntimeBuilder runtimeBuilder) { @Override public void setRuntimeOptions( CelRuntimeBuilder runtimeBuilder, RuntimeEquality runtimeEquality, CelOptions celOptions) { - for (Function function : functions) { - runtimeBuilder.addFunctionBindings(function.functionBindings); - for (CelOverloadDecl overload : function.functionDecl.overloads()) { - switch (overload.overloadId()) { - case "list_distinct": - runtimeBuilder.addFunctionBindings( - CelFunctionBinding.from( - "list_distinct", Collection.class, (list) -> distinct(list, runtimeEquality))); - break; - case "list_sort": - runtimeBuilder.addFunctionBindings( - CelFunctionBinding.from( - "list_sort", Collection.class, (list) -> sort(list, celOptions))); - break; - case "list_sortByAssociatedKeys": - runtimeBuilder.addFunctionBindings( - CelFunctionBinding.from( - "list_sortByAssociatedKeys", - Collection.class, - (list) -> sortByAssociatedKeys(list, celOptions))); - break; - default: - // Nothing to add - } - } - } + functions.forEach(function -> runtimeBuilder.addFunctionBindings(function.functionBindings)); + + runtimeBuilder.addFunctionBindings( + CelFunctionBinding.fromOverloads( + "distinct", + CelFunctionBinding.from( + "list_distinct", Collection.class, (list) -> distinct(list, runtimeEquality)))); } private static ImmutableList slice(Collection list, long from, long to) { @@ -369,22 +358,18 @@ private static List reverse(Collection list) { } } - private static ImmutableList sort(Collection objects, CelOptions options) { - return ImmutableList.sortedCopyOf( - new CelObjectComparator(options.enableHeterogeneousNumericComparisons()), objects); + private static ImmutableList sort(Collection objects) { + return ImmutableList.sortedCopyOf(new CelObjectComparator(), objects); } private static class CelObjectComparator implements Comparator { - private final boolean enableHeterogeneousNumericComparisons; - CelObjectComparator(boolean enableHeterogeneousNumericComparisons) { - this.enableHeterogeneousNumericComparisons = enableHeterogeneousNumericComparisons; - } + CelObjectComparator() {} @SuppressWarnings({"unchecked"}) @Override public int compare(Object o1, Object o2) { - if (o1 instanceof Number && o2 instanceof Number && enableHeterogeneousNumericComparisons) { + if (o1 instanceof Number && o2 instanceof Number) { return ComparisonFunctions.numericCompare((Number) o1, (Number) o2); } @@ -444,12 +429,9 @@ private static Optional sortByMacro( @SuppressWarnings({"unchecked", "rawtypes"}) private static ImmutableList sortByAssociatedKeys( - Collection> keyValuePairs, CelOptions options) { + Collection> keyValuePairs) { List[] array = keyValuePairs.toArray(new List[0]); - Arrays.sort( - array, - new CelObjectByKeyComparator( - new CelObjectComparator(options.enableHeterogeneousNumericComparisons()))); + Arrays.sort(array, new CelObjectByKeyComparator(new CelObjectComparator())); ImmutableList.Builder builder = ImmutableList.builderWithExpectedSize(array.length); for (List pair : array) { builder.add(pair.get(1)); diff --git a/extensions/src/test/java/dev/cel/extensions/CelListsExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelListsExtensionsTest.java index c4739b18b..2083ccc42 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelListsExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelListsExtensionsTest.java @@ -19,41 +19,38 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSortedMultiset; import com.google.common.collect.ImmutableSortedSet; +import com.google.testing.junit.testparameterinjector.TestParameter; import com.google.testing.junit.testparameterinjector.TestParameterInjector; import com.google.testing.junit.testparameterinjector.TestParameters; import dev.cel.bundle.Cel; -import dev.cel.bundle.CelFactory; +import dev.cel.bundle.CelBuilder; +import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelContainer; -import dev.cel.common.CelOptions; import dev.cel.common.CelValidationException; import dev.cel.common.types.SimpleType; import dev.cel.expr.conformance.test.SimpleTest; import dev.cel.parser.CelStandardMacro; import dev.cel.runtime.CelEvaluationException; +import dev.cel.testing.CelRuntimeFlavor; +import java.util.Map; +import org.junit.Assume; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @RunWith(TestParameterInjector.class) public class CelListsExtensionsTest { - private static final Cel CEL = - CelFactory.standardCelBuilder() - .setStandardMacros(CelStandardMacro.STANDARD_MACROS) - .addCompilerLibraries(CelExtensions.lists()) - .addRuntimeLibraries(CelExtensions.lists()) - .setContainer(CelContainer.ofName("cel.expr.conformance.test")) - .addMessageTypes(SimpleTest.getDescriptor()) - .addVar("non_list", SimpleType.DYN) - .build(); - - private static final Cel CEL_WITH_HETEROGENEOUS_NUMERIC_COMPARISONS = - CelFactory.standardCelBuilder() - .setOptions(CelOptions.current().enableHeterogeneousNumericComparisons(true).build()) - .setStandardMacros(CelStandardMacro.STANDARD_MACROS) - .addCompilerLibraries(CelExtensions.lists()) - .addRuntimeLibraries(CelExtensions.lists()) - .setContainer(CelContainer.ofName("cel.expr.conformance.test")) - .addMessageTypes(SimpleTest.getDescriptor()) - .build(); + @TestParameter public CelRuntimeFlavor runtimeFlavor; + @TestParameter public boolean isParseOnly; + + private Cel cel; + + @Before + public void setUp() { + // Legacy runtime does not support parsed-only evaluation mode. + Assume.assumeFalse(runtimeFlavor.equals(CelRuntimeFlavor.LEGACY) && isParseOnly); + this.cel = setupEnv(runtimeFlavor.builder()); + } @Test public void functionList_byVersion() { @@ -89,10 +86,9 @@ public void macroList_byVersion() { @TestParameters("{expression: 'non_list.slice(1, 3)', expected: '[2, 3]'}") public void slice_success(String expression, String expected) throws Exception { Object result = - CEL.createProgram(CEL.compile(expression).getAst()) - .eval(ImmutableMap.of("non_list", ImmutableSortedSet.of(4L, 1L, 3L, 2L))); + eval(cel, expression, ImmutableMap.of("non_list", ImmutableSortedSet.of(4L, 1L, 3L, 2L))); - assertThat(result).isEqualTo(expectedResult(expected)); + assertThat(result).isEqualTo(eval(cel, expected)); } @Test @@ -107,10 +103,7 @@ public void slice_success(String expression, String expected) throws Exception { "{expression: '[1,2,3,4].slice(-5, -3)', " + "expectedError: 'Negative indexes not supported'}") public void slice_throws(String expression, String expectedError) throws Exception { - assertThat( - assertThrows( - CelEvaluationException.class, - () -> CEL.createProgram(CEL.compile(expression).getAst()).eval())) + assertThat(assertThrows(CelEvaluationException.class, () -> eval(cel, expression))) .hasCauseThat() .hasMessageThat() .contains(expectedError); @@ -127,7 +120,7 @@ public void slice_throws(String expression, String expectedError) throws Excepti @TestParameters("{expression: 'dyn([{1: 2}]).flatten() == [{1: 2}]'}") @TestParameters("{expression: 'dyn([1,2,3,4]).flatten() == [1,2,3,4]'}") public void flattenSingleLevel_success(String expression) throws Exception { - boolean result = (boolean) CEL.createProgram(CEL.compile(expression).getAst()).eval(); + boolean result = (boolean) eval(cel, expression); assertThat(result).isTrue(); } @@ -143,7 +136,7 @@ public void flattenSingleLevel_success(String expression) throws Exception { // The overload with the depth accepts and returns a List(dyn), so the following is permitted. @TestParameters("{expression: '[1].flatten(1) == [1]'}") public void flatten_withDepthValue_success(String expression) throws Exception { - boolean result = (boolean) CEL.createProgram(CEL.compile(expression).getAst()).eval(); + boolean result = (boolean) eval(cel, expression); assertThat(result).isTrue(); } @@ -151,13 +144,17 @@ public void flatten_withDepthValue_success(String expression) throws Exception { @Test public void flatten_negativeDepth_throws() { CelEvaluationException e = - assertThrows( - CelEvaluationException.class, - () -> CEL.createProgram(CEL.compile("[1,2,3,4].flatten(-1)").getAst()).eval()); - - assertThat(e) - .hasMessageThat() - .contains("evaluation error at :17: Function 'list_flatten_list_int' failed"); + assertThrows(CelEvaluationException.class, () -> eval(cel, "[1,2,3,4].flatten(-1)")); + + if (isParseOnly) { + assertThat(e) + .hasMessageThat() + .contains("evaluation error at :17: Function 'flatten' failed"); + } else { + assertThat(e) + .hasMessageThat() + .contains("evaluation error at :17: Function 'list_flatten_list_int' failed"); + } assertThat(e).hasCauseThat().hasMessageThat().isEqualTo("Level must be non-negative"); } @@ -166,9 +163,11 @@ public void flatten_negativeDepth_throws() { @TestParameters("{expression: '[{1: 2}].flatten()'}") @TestParameters("{expression: '[1,2,3,4].flatten()'}") public void flattenSingleLevel_listIsSingleLevel_throws(String expression) { + // This is a type-checking failure. + Assume.assumeFalse(isParseOnly); // Note: Java lacks the capability of conditionally disabling type guards // due to the lack of full-fledged dynamic dispatch. - assertThrows(CelValidationException.class, () -> CEL.compile(expression).getAst()); + assertThrows(CelValidationException.class, () -> cel.compile(expression).getAst()); } @Test @@ -176,7 +175,7 @@ public void flattenSingleLevel_listIsSingleLevel_throws(String expression) { @TestParameters("{expression: 'lists.range(0) == []'}") @TestParameters("{expression: 'lists.range(-1) == []'}") public void range_success(String expression) throws Exception { - boolean result = (boolean) CEL.createProgram(CEL.compile(expression).getAst()).eval(); + boolean result = (boolean) eval(cel, expression); assertThat(result).isTrue(); } @@ -204,12 +203,13 @@ public void range_success(String expression) throws Exception { @TestParameters("{expression: 'non_list.distinct()', expected: '[1, 2, 3, 4]'}") public void distinct_success(String expression, String expected) throws Exception { Object result = - CEL.createProgram(CEL.compile(expression).getAst()) - .eval( - ImmutableMap.of( - "non_list", ImmutableSortedMultiset.of(1L, 2L, 3L, 4L, 4L, 1L, 3L, 2L))); + eval( + cel, + expression, + ImmutableMap.of( + "non_list", ImmutableSortedMultiset.of(1L, 2L, 3L, 4L, 4L, 1L, 3L, 2L))); - assertThat(result).isEqualTo(expectedResult(expected)); + assertThat(result).isEqualTo(eval(cel, expected)); } @Test @@ -224,10 +224,9 @@ public void distinct_success(String expression, String expected) throws Exceptio @TestParameters("{expression: 'non_list.reverse()', expected: '[4, 3, 2, 1]'}") public void reverse_success(String expression, String expected) throws Exception { Object result = - CEL.createProgram(CEL.compile(expression).getAst()) - .eval(ImmutableMap.of("non_list", ImmutableSortedSet.of(4L, 1L, 3L, 2L))); + eval(cel, expression, ImmutableMap.of("non_list", ImmutableSortedSet.of(4L, 1L, 3L, 2L))); - assertThat(result).isEqualTo(expectedResult(expected)); + assertThat(result).isEqualTo(eval(cel, expected)); } @Test @@ -238,9 +237,9 @@ public void reverse_success(String expression, String expected) throws Exception "{expression: '[\"d\", \"a\", \"b\", \"c\"].sort()', " + "expected: '[\"a\", \"b\", \"c\", \"d\"]'}") public void sort_success(String expression, String expected) throws Exception { - Object result = CEL.createProgram(CEL.compile(expression).getAst()).eval(); + Object result = eval(cel, expression); - assertThat(result).isEqualTo(expectedResult(expected)); + assertThat(result).isEqualTo(eval(cel, expected)); } @Test @@ -248,29 +247,20 @@ public void sort_success(String expression, String expected) throws Exception { @TestParameters("{expression: '[4, 3, 2, 1].sort()', expected: '[1, 2, 3, 4]'}") public void sort_success_heterogeneousNumbers(String expression, String expected) throws Exception { - Object result = - CEL_WITH_HETEROGENEOUS_NUMERIC_COMPARISONS - .createProgram(CEL_WITH_HETEROGENEOUS_NUMERIC_COMPARISONS.compile(expression).getAst()) - .eval(); + Object result = eval(cel, expression); - assertThat(result).isEqualTo(expectedResult(expected)); + assertThat(result).isEqualTo(eval(cel, expected)); } @Test @TestParameters( "{expression: '[\"d\", 3, 2, \"c\"].sort()', " + "expectedError: 'List elements must have the same type'}") - @TestParameters( - "{expression: '[3.0, 2, 1u].sort()', " - + "expectedError: 'List elements must have the same type'}") @TestParameters( "{expression: '[SimpleTest{name: \"a\"}, SimpleTest{name: \"b\"}].sort()', " + "expectedError: 'List elements must be comparable'}") public void sort_throws(String expression, String expectedError) throws Exception { - assertThat( - assertThrows( - CelEvaluationException.class, - () -> CEL.createProgram(CEL.compile(expression).getAst()).eval())) + assertThat(assertThrows(CelEvaluationException.class, () -> eval(cel, expression))) .hasCauseThat() .hasMessageThat() .contains(expectedError); @@ -296,9 +286,9 @@ public void sort_throws(String expression, String expectedError) throws Exceptio + " SimpleTest{name: \"baz\"}," + " SimpleTest{name: \"foo\"}]'}") public void sortBy_success(String expression, String expected) throws Exception { - Object result = CEL.createProgram(CEL.compile(expression).getAst()).eval(); + Object result = eval(cel, expression); - assertThat(result).isEqualTo(expectedResult(expected)); + assertThat(result).isEqualTo(eval(cel, expected)); } @Test @@ -313,7 +303,7 @@ public void sortBy_throws_validationException(String expression, String expected assertThat( assertThrows( CelValidationException.class, - () -> CEL.createProgram(CEL.compile(expression).getAst()).eval())) + () -> cel.createProgram(cel.compile(expression).getAst()).eval())) .hasMessageThat() .contains(expectedError); } @@ -327,17 +317,31 @@ public void sortBy_throws_validationException(String expression, String expected + "expectedError: 'List elements must be comparable'}") public void sortBy_throws_evaluationException(String expression, String expectedError) throws Exception { - assertThat( - assertThrows( - CelEvaluationException.class, - () -> CEL.createProgram(CEL.compile(expression).getAst()).eval())) + assertThat(assertThrows(CelEvaluationException.class, () -> eval(cel, expression))) .hasCauseThat() .hasMessageThat() .contains(expectedError); } - private static Object expectedResult(String expression) - throws CelEvaluationException, CelValidationException { - return CEL.createProgram(CEL.compile(expression).getAst()).eval(); + private static Cel setupEnv(CelBuilder celBuilder) { + return celBuilder + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .addCompilerLibraries(CelExtensions.lists()) + .addRuntimeLibraries(CelExtensions.lists()) + .setContainer(CelContainer.ofName("cel.expr.conformance.test")) + .addMessageTypes(SimpleTest.getDescriptor()) + .addVar("non_list", SimpleType.DYN) + .build(); + } + + + + private Object eval(Cel cel, String expr) throws Exception { + return eval(cel, expr, ImmutableMap.of()); + } + + private Object eval(Cel cel, String expr, Map vars) throws Exception { + CelAbstractSyntaxTree ast = isParseOnly ? cel.parse(expr).getAst() : cel.compile(expr).getAst(); + return cel.createProgram(ast).eval(vars); } } From 4f9a3a8987fdb24297216b5817e4fcef7cb11b3a Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Wed, 15 Apr 2026 11:26:41 -0700 Subject: [PATCH 37/66] Add parsed-only evaluation coverage to Comprehensions Extensions Includes a fix to preserve first encountered error message for comprehensions PiperOrigin-RevId: 900264092 --- .../CelComprehensionsExtensionsTest.java | 150 +++++++++--------- .../extensions/CelListsExtensionsTest.java | 2 - .../java/dev/cel/runtime/planner/EvalAnd.java | 5 +- .../java/dev/cel/runtime/planner/EvalOr.java | 5 +- .../planner_unknownResultSet_errors.baseline | 4 +- 5 files changed, 89 insertions(+), 77 deletions(-) diff --git a/extensions/src/test/java/dev/cel/extensions/CelComprehensionsExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelComprehensionsExtensionsTest.java index 34696b688..fbe160cd3 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelComprehensionsExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelComprehensionsExtensionsTest.java @@ -15,11 +15,15 @@ package dev.cel.extensions; import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.Truth.assertWithMessage; import static org.junit.Assert.assertThrows; +import com.google.common.base.Throwables; +import com.google.common.collect.ImmutableMap; import com.google.testing.junit.testparameterinjector.TestParameter; import com.google.testing.junit.testparameterinjector.TestParameterInjector; import com.google.testing.junit.testparameterinjector.TestParameters; +import dev.cel.bundle.Cel; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelOptions; @@ -28,15 +32,15 @@ import dev.cel.common.exceptions.CelIndexOutOfBoundsException; import dev.cel.common.types.SimpleType; import dev.cel.common.types.TypeParamType; -import dev.cel.compiler.CelCompiler; -import dev.cel.compiler.CelCompilerFactory; import dev.cel.parser.CelMacro; import dev.cel.parser.CelStandardMacro; import dev.cel.parser.CelUnparser; import dev.cel.parser.CelUnparserFactory; import dev.cel.runtime.CelEvaluationException; -import dev.cel.runtime.CelRuntime; -import dev.cel.runtime.CelRuntimeFactory; +import dev.cel.testing.CelRuntimeFlavor; +import java.util.Map; +import org.junit.Assume; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -46,26 +50,35 @@ public class CelComprehensionsExtensionsTest { private static final CelOptions CEL_OPTIONS = CelOptions.current() + .enableHeterogeneousNumericComparisons(true) // Enable macro call population for unparsing .populateMacroCalls(true) .build(); - private static final CelCompiler CEL_COMPILER = - CelCompilerFactory.standardCelCompilerBuilder() - .setOptions(CEL_OPTIONS) - .setStandardMacros(CelStandardMacro.STANDARD_MACROS) - .addLibraries(CelExtensions.comprehensions()) - .addLibraries(CelExtensions.lists()) - .addLibraries(CelExtensions.strings()) - .addLibraries(CelOptionalLibrary.INSTANCE, CelExtensions.bindings()) - .build(); - private static final CelRuntime CEL_RUNTIME = - CelRuntimeFactory.standardCelRuntimeBuilder() - .addLibraries(CelOptionalLibrary.INSTANCE) - .addLibraries(CelExtensions.lists()) - .addLibraries(CelExtensions.strings()) - .addLibraries(CelExtensions.comprehensions()) - .build(); + @TestParameter public CelRuntimeFlavor runtimeFlavor; + @TestParameter public boolean isParseOnly; + + private Cel cel; + + @Before + public void setUp() { + // Legacy runtime does not support parsed-only evaluation mode. + Assume.assumeFalse(runtimeFlavor.equals(CelRuntimeFlavor.LEGACY) && isParseOnly); + this.cel = + runtimeFlavor + .builder() + .setOptions(CEL_OPTIONS) + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .addCompilerLibraries(CelExtensions.comprehensions()) + .addCompilerLibraries(CelExtensions.lists()) + .addCompilerLibraries(CelExtensions.strings()) + .addCompilerLibraries(CelOptionalLibrary.INSTANCE, CelExtensions.bindings()) + .addRuntimeLibraries(CelOptionalLibrary.INSTANCE) + .addRuntimeLibraries(CelExtensions.lists()) + .addRuntimeLibraries(CelExtensions.strings()) + .addRuntimeLibraries(CelExtensions.comprehensions()) + .build(); + } private static final CelUnparser UNPARSER = CelUnparserFactory.newUnparser(); @@ -101,11 +114,7 @@ public void allMacro_twoVarComprehension_success( }) String expr) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); - - Object result = CEL_RUNTIME.createProgram(ast).eval(); - - assertThat(result).isEqualTo(true); + assertThat(eval(expr)).isEqualTo(true); } @Test @@ -127,11 +136,7 @@ public void existsMacro_twoVarComprehension_success( }) String expr) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); - - Object result = CEL_RUNTIME.createProgram(ast).eval(); - - assertThat(result).isEqualTo(true); + assertThat(eval(expr)).isEqualTo(true); } @Test @@ -156,11 +161,7 @@ public void exists_oneMacro_twoVarComprehension_success( }) String expr) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); - - Object result = CEL_RUNTIME.createProgram(ast).eval(); - - assertThat(result).isEqualTo(true); + assertThat(eval(expr)).isEqualTo(true); } @Test @@ -182,11 +183,7 @@ public void transformListMacro_twoVarComprehension_success( }) String expr) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); - - Object result = CEL_RUNTIME.createProgram(ast).eval(); - - assertThat(result).isEqualTo(true); + assertThat(eval(expr)).isEqualTo(true); } @Test @@ -210,11 +207,7 @@ public void transformMapMacro_twoVarComprehension_success( }) String expr) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); - - Object result = CEL_RUNTIME.createProgram(ast).eval(); - - assertThat(result).isEqualTo(true); + assertThat(eval(expr)).isEqualTo(true); } @Test @@ -238,24 +231,22 @@ public void transformMapEntryMacro_twoVarComprehension_success( }) String expr) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); - - Object result = CEL_RUNTIME.createProgram(ast).eval(); - - assertThat(result).isEqualTo(true); + assertThat(eval(expr)).isEqualTo(true); } @Test public void comprehension_onTypeParam_success() throws Exception { - CelCompiler celCompiler = - CelCompilerFactory.standardCelCompilerBuilder() + Assume.assumeFalse(isParseOnly); + Cel customCel = + runtimeFlavor + .builder() .setOptions(CEL_OPTIONS) .setStandardMacros(CelStandardMacro.STANDARD_MACROS) - .addLibraries(CelExtensions.comprehensions()) + .addCompilerLibraries(CelExtensions.comprehensions()) .addVar("items", TypeParamType.create("T")) .build(); - CelAbstractSyntaxTree ast = celCompiler.compile("items.all(i, v, v > 0)").getAst(); + CelAbstractSyntaxTree ast = customCel.compile("items.all(i, v, v > 0)").getAst(); assertThat(ast.getResultType()).isEqualTo(SimpleType.BOOL); } @@ -275,7 +266,7 @@ public void unparseAST_twoVarComprehension( }) String expr) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); + CelAbstractSyntaxTree ast = isParseOnly ? cel.parse(expr).getAst() : cel.compile(expr).getAst(); String unparsed = UNPARSER.unparse(ast); assertThat(unparsed).isEqualTo(expr); } @@ -318,8 +309,9 @@ public void unparseAST_twoVarComprehension( "{expr: \"{'hello': 'world', 'greetings': 'tacocat'}.transformMapEntry(k, v, []) == {}\"," + " err: 'no matching overload'}") public void twoVarComprehension_compilerErrors(String expr, String err) throws Exception { + Assume.assumeFalse(isParseOnly); CelValidationException e = - assertThrows(CelValidationException.class, () -> CEL_COMPILER.compile(expr).getAst()); + assertThrows(CelValidationException.class, () -> cel.compile(expr).getAst()); assertThat(e).hasMessageThat().contains(err); } @@ -339,34 +331,50 @@ public void twoVarComprehension_compilerErrors(String expr, String err) throws E + " '2.0' already exists\"}") public void twoVarComprehension_keyCollision_runtimeError(String expr, String err) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); - - CelEvaluationException e = - assertThrows(CelEvaluationException.class, () -> CEL_RUNTIME.createProgram(ast).eval()); - - assertThat(e).hasCauseThat().isInstanceOf(IllegalArgumentException.class); - assertThat(e).hasCauseThat().hasMessageThat().contains(err); + // Planner does not allow decimals for map keys + Assume.assumeFalse(runtimeFlavor.equals(CelRuntimeFlavor.PLANNER) && expr.contains("2.0")); + + CelEvaluationException e = assertThrows(CelEvaluationException.class, () -> eval(expr)); + Throwable cause = + Throwables.getCausalChain(e).stream() + .filter(IllegalArgumentException.class::isInstance) + .filter(t -> t.getMessage() != null && t.getMessage().contains(err)) + .findFirst() + .orElse(null); + + assertWithMessage( + "Expected IllegalArgumentException with message containing '%s' in cause chain", err) + .that(cause) + .isNotNull(); } @Test public void twoVarComprehension_arithmeticException_runtimeError() throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile("[0].all(i, k, i/k < k)").getAst(); - CelEvaluationException e = - assertThrows(CelEvaluationException.class, () -> CEL_RUNTIME.createProgram(ast).eval()); - + assertThrows(CelEvaluationException.class, () -> eval("[0].all(i, k, i/k < k)")); assertThat(e).hasCauseThat().isInstanceOf(CelDivideByZeroException.class); assertThat(e).hasCauseThat().hasMessageThat().contains("/ by zero"); } @Test public void twoVarComprehension_outOfBounds_runtimeError() throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile("[1, 2].exists(i, v, [0][v] > 0)").getAst(); - CelEvaluationException e = - assertThrows(CelEvaluationException.class, () -> CEL_RUNTIME.createProgram(ast).eval()); - + assertThrows(CelEvaluationException.class, () -> eval("[1, 2].exists(i, v, [0][v] > 0)")); assertThat(e).hasCauseThat().isInstanceOf(CelIndexOutOfBoundsException.class); assertThat(e).hasCauseThat().hasMessageThat().contains("Index out of bounds: 1"); } + + private Object eval(String expression) throws Exception { + return eval(this.cel, expression, ImmutableMap.of()); + } + + private Object eval(Cel cel, String expression, Map variables) throws Exception { + CelAbstractSyntaxTree ast; + if (isParseOnly) { + ast = cel.parse(expression).getAst(); + } else { + ast = cel.compile(expression).getAst(); + } + return cel.createProgram(ast).eval(variables); + } } diff --git a/extensions/src/test/java/dev/cel/extensions/CelListsExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelListsExtensionsTest.java index 2083ccc42..b36e0e92e 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelListsExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelListsExtensionsTest.java @@ -334,8 +334,6 @@ private static Cel setupEnv(CelBuilder celBuilder) { .build(); } - - private Object eval(Cel cel, String expr) throws Exception { return eval(cel, expr, ImmutableMap.of()); } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalAnd.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalAnd.java index eb7406071..91f5b2ff4 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalAnd.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalAnd.java @@ -38,7 +38,10 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) { return false; } } else if (argVal instanceof ErrorValue) { - errorValue = (ErrorValue) argVal; + // Preserve the first encountered error instead of overwriting it with subsequent errors. + if (errorValue == null) { + errorValue = (ErrorValue) argVal; + } } else if (argVal instanceof AccumulatedUnknowns) { unknowns = AccumulatedUnknowns.maybeMerge(unknowns, argVal); } else { diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalOr.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalOr.java index bc19ed81a..62e617d9d 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalOr.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalOr.java @@ -38,7 +38,10 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) { return true; } } else if (argVal instanceof ErrorValue) { - errorValue = (ErrorValue) argVal; + // Preserve the first encountered error instead of overwriting it with subsequent errors. + if (errorValue == null) { + errorValue = (ErrorValue) argVal; + } } else if (argVal instanceof AccumulatedUnknowns) { unknowns = AccumulatedUnknowns.maybeMerge(unknowns, argVal); } else { diff --git a/runtime/src/test/resources/planner_unknownResultSet_errors.baseline b/runtime/src/test/resources/planner_unknownResultSet_errors.baseline index 812067ddf..7885e9da1 100644 --- a/runtime/src/test/resources/planner_unknownResultSet_errors.baseline +++ b/runtime/src/test/resources/planner_unknownResultSet_errors.baseline @@ -32,7 +32,7 @@ single_timestamp { seconds: 15 } , unknown_attributes=[x.single_int32]} -error: evaluation error at test_location:89: Text 'another bad timestamp string' could not be parsed at index 0 +error: evaluation error at test_location:31: Text 'bad timestamp string' could not be parsed at index 0 error_code: BAD_FORMAT Source: x.single_int32 == 1 || x.single_timestamp <= timestamp("bad timestamp string") @@ -69,7 +69,7 @@ single_timestamp { seconds: 15 } , unknown_attributes=[x.single_int32]} -error: evaluation error at test_location:89: Text 'another bad timestamp string' could not be parsed at index 0 +error: evaluation error at test_location:31: Text 'bad timestamp string' could not be parsed at index 0 error_code: BAD_FORMAT Source: x From 5667bd825ed42bacaee4193c42f58b71786cc435 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Fri, 17 Apr 2026 13:04:41 -0700 Subject: [PATCH 38/66] Implement mutable map support for planner PiperOrigin-RevId: 901433988 --- .../java/dev/cel/common/values/BUILD.bazel | 32 ++++ .../cel/common/values/MutableMapValue.java | 146 ++++++++++++++++++ common/values/BUILD.bazel | 12 ++ .../main/java/dev/cel/extensions/BUILD.bazel | 1 + .../CelComprehensionsExtensions.java | 45 +++--- .../test/java/dev/cel/extensions/BUILD.bazel | 1 + .../CelComprehensionsExtensionsTest.java | 11 ++ .../java/dev/cel/runtime/planner/BUILD.bazel | 1 + .../dev/cel/runtime/planner/EvalFold.java | 15 +- 9 files changed, 241 insertions(+), 23 deletions(-) create mode 100644 common/src/main/java/dev/cel/common/values/MutableMapValue.java diff --git a/common/src/main/java/dev/cel/common/values/BUILD.bazel b/common/src/main/java/dev/cel/common/values/BUILD.bazel index 53ffdda3d..0d1d5431f 100644 --- a/common/src/main/java/dev/cel/common/values/BUILD.bazel +++ b/common/src/main/java/dev/cel/common/values/BUILD.bazel @@ -118,6 +118,38 @@ java_library( ], ) +java_library( + name = "mutable_map_value", + srcs = ["MutableMapValue.java"], + tags = [ + ], + deps = [ + "//common/annotations", + "//common/exceptions:attribute_not_found", + "//common/types", + "//common/types:type_providers", + "//common/values", + "//common/values:cel_value", + "@maven//:com_google_errorprone_error_prone_annotations", + ], +) + +cel_android_library( + name = "mutable_map_value_android", + srcs = ["MutableMapValue.java"], + tags = [ + ], + deps = [ + ":cel_value_android", + "//common/annotations", + "//common/exceptions:attribute_not_found", + "//common/types:type_providers_android", + "//common/types:types_android", + "//common/values:values_android", + "@maven//:com_google_errorprone_error_prone_annotations", + ], +) + cel_android_library( name = "values_android", srcs = CEL_VALUES_SOURCES, diff --git a/common/src/main/java/dev/cel/common/values/MutableMapValue.java b/common/src/main/java/dev/cel/common/values/MutableMapValue.java new file mode 100644 index 000000000..706436b2e --- /dev/null +++ b/common/src/main/java/dev/cel/common/values/MutableMapValue.java @@ -0,0 +1,146 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.common.values; + +import com.google.errorprone.annotations.Immutable; +import dev.cel.common.annotations.Internal; +import dev.cel.common.exceptions.CelAttributeNotFoundException; +import dev.cel.common.types.CelType; +import dev.cel.common.types.MapType; +import dev.cel.common.types.SimpleType; +import java.util.Collection; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +/** + * A custom CelValue implementation that allows O(1) insertions for maps during comprehension. + * + *

CEL Library Internals. Do Not Use. + */ +@Internal +@Immutable +@SuppressWarnings("Immutable") // Intentionally mutable for performance reasons +public final class MutableMapValue extends CelValue + implements SelectableValue, Map { + private final Map internalMap; + private final CelType celType; + + public static MutableMapValue create(Map map) { + return new MutableMapValue(map); + } + + @Override + public int size() { + return internalMap.size(); + } + + @Override + public boolean isEmpty() { + return internalMap.isEmpty(); + } + + @Override + public boolean containsKey(Object key) { + return internalMap.containsKey(key); + } + + @Override + public boolean containsValue(Object value) { + return internalMap.containsValue(value); + } + + @Override + public Object get(Object key) { + return internalMap.get(key); + } + + @Override + public Object put(Object key, Object value) { + return internalMap.put(key, value); + } + + @Override + public Object remove(Object key) { + return internalMap.remove(key); + } + + @Override + public void putAll(Map m) { + internalMap.putAll(m); + } + + @Override + public void clear() { + internalMap.clear(); + } + + @Override + public Set keySet() { + return internalMap.keySet(); + } + + @Override + public Collection values() { + return internalMap.values(); + } + + @Override + public Set> entrySet() { + return internalMap.entrySet(); + } + + @Override + public Object select(Object field) { + Object val = internalMap.get(field); + if (val != null) { + return val; + } + if (!internalMap.containsKey(field)) { + throw CelAttributeNotFoundException.forMissingMapKey(field.toString()); + } + throw CelAttributeNotFoundException.of( + String.format("Map value cannot be null for key: %s", field)); + } + + @Override + public Optional find(Object field) { + if (internalMap.containsKey(field)) { + return Optional.ofNullable(internalMap.get(field)); + } + return Optional.empty(); + } + + @Override + public Object value() { + return this; + } + + @Override + public boolean isZeroValue() { + return internalMap.isEmpty(); + } + + @Override + public CelType celType() { + return celType; + } + + private MutableMapValue(Map map) { + this.internalMap = new LinkedHashMap<>(map); + this.celType = MapType.create(SimpleType.DYN, SimpleType.DYN); + } +} diff --git a/common/values/BUILD.bazel b/common/values/BUILD.bazel index f1fa107b6..74bfa9e0f 100644 --- a/common/values/BUILD.bazel +++ b/common/values/BUILD.bazel @@ -47,6 +47,18 @@ cel_android_library( exports = ["//common/src/main/java/dev/cel/common/values:values_android"], ) +java_library( + name = "mutable_map_value", + visibility = ["//:internal"], + exports = ["//common/src/main/java/dev/cel/common/values:mutable_map_value"], +) + +cel_android_library( + name = "mutable_map_value_android", + visibility = ["//:internal"], + exports = ["//common/src/main/java/dev/cel/common/values:mutable_map_value_android"], +) + java_library( name = "base_proto_cel_value_converter", exports = ["//common/src/main/java/dev/cel/common/values:base_proto_cel_value_converter"], diff --git a/extensions/src/main/java/dev/cel/extensions/BUILD.bazel b/extensions/src/main/java/dev/cel/extensions/BUILD.bazel index 77663f2fa..2eb26846f 100644 --- a/extensions/src/main/java/dev/cel/extensions/BUILD.bazel +++ b/extensions/src/main/java/dev/cel/extensions/BUILD.bazel @@ -307,6 +307,7 @@ java_library( "//common:options", "//common/ast", "//common/types", + "//common/values:mutable_map_value", "//compiler:compiler_builder", "//extensions:extension_library", "//parser:macro", diff --git a/extensions/src/main/java/dev/cel/extensions/CelComprehensionsExtensions.java b/extensions/src/main/java/dev/cel/extensions/CelComprehensionsExtensions.java index 7c298a773..3bf47c4a6 100644 --- a/extensions/src/main/java/dev/cel/extensions/CelComprehensionsExtensions.java +++ b/extensions/src/main/java/dev/cel/extensions/CelComprehensionsExtensions.java @@ -29,6 +29,7 @@ import dev.cel.common.ast.CelExpr; import dev.cel.common.types.MapType; import dev.cel.common.types.TypeParamType; +import dev.cel.common.values.MutableMapValue; import dev.cel.compiler.CelCompilerLibrary; import dev.cel.parser.CelMacro; import dev.cel.parser.CelMacroExprFactory; @@ -171,38 +172,46 @@ public void setParserOptions(CelParserBuilder parserBuilder) { parserBuilder.addMacros(macros()); } - // TODO: Implement a more efficient map insertion based on mutability once mutable - // maps are supported in Java stack. - private static ImmutableMap mapInsertMap( + private static Map mapInsertMap( Map targetMap, Map mapToMerge, RuntimeEquality equality) { - ImmutableMap.Builder resultBuilder = - ImmutableMap.builderWithExpectedSize(targetMap.size() + mapToMerge.size()); - - for (Map.Entry entry : mapToMerge.entrySet()) { - if (equality.findInMap(targetMap, entry.getKey()).isPresent()) { + for (Object key : mapToMerge.keySet()) { + if (equality.findInMap(targetMap, key).isPresent()) { throw new IllegalArgumentException( - String.format("insert failed: key '%s' already exists", entry.getKey())); - } else { - resultBuilder.put(entry.getKey(), entry.getValue()); + String.format("insert failed: key '%s' already exists", key)); } } - return resultBuilder.putAll(targetMap).buildOrThrow(); + + if (targetMap instanceof MutableMapValue) { + MutableMapValue wrapper = (MutableMapValue) targetMap; + wrapper.putAll(mapToMerge); + return wrapper; + } + + return ImmutableMap.builderWithExpectedSize(targetMap.size() + mapToMerge.size()) + .putAll(targetMap) + .putAll(mapToMerge) + .buildOrThrow(); } - private static ImmutableMap mapInsertKeyValue( - Object[] args, RuntimeEquality equality) { - Map map = (Map) args[0]; + private static Map mapInsertKeyValue(Object[] args, RuntimeEquality equality) { + Map mapArg = (Map) args[0]; Object key = args[1]; Object value = args[2]; - if (equality.findInMap(map, key).isPresent()) { + if (equality.findInMap(mapArg, key).isPresent()) { throw new IllegalArgumentException( String.format("insert failed: key '%s' already exists", key)); } + if (mapArg instanceof MutableMapValue) { + MutableMapValue mutableMap = (MutableMapValue) mapArg; + mutableMap.put(key, value); + return mutableMap; + } + ImmutableMap.Builder builder = - ImmutableMap.builderWithExpectedSize(map.size() + 1); - return builder.put(key, value).putAll(map).buildOrThrow(); + ImmutableMap.builderWithExpectedSize(mapArg.size() + 1); + return builder.put(key, value).putAll(mapArg).buildOrThrow(); } private static Optional expandAllMacro( diff --git a/extensions/src/test/java/dev/cel/extensions/BUILD.bazel b/extensions/src/test/java/dev/cel/extensions/BUILD.bazel index 0b6502410..19fd3657e 100644 --- a/extensions/src/test/java/dev/cel/extensions/BUILD.bazel +++ b/extensions/src/test/java/dev/cel/extensions/BUILD.bazel @@ -15,6 +15,7 @@ java_library( "//common:compiler_common", "//common:container", "//common:options", + "//common/exceptions:attribute_not_found", "//common/exceptions:divide_by_zero", "//common/exceptions:index_out_of_bounds", "//common/types", diff --git a/extensions/src/test/java/dev/cel/extensions/CelComprehensionsExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelComprehensionsExtensionsTest.java index fbe160cd3..374178540 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelComprehensionsExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelComprehensionsExtensionsTest.java @@ -28,6 +28,7 @@ import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelOptions; import dev.cel.common.CelValidationException; +import dev.cel.common.exceptions.CelAttributeNotFoundException; import dev.cel.common.exceptions.CelDivideByZeroException; import dev.cel.common.exceptions.CelIndexOutOfBoundsException; import dev.cel.common.types.SimpleType; @@ -222,6 +223,7 @@ public void transformMapEntryMacro_twoVarComprehension_success( + " 'key2': 'value2'}", // map.transformMapEntry() "{'hello': 'world', 'greetings': 'tacocat'}.transformMapEntry(k, v, {}) == {}", + "{'a': 1, 'b': 2}.transformMapEntry(k, v, {k: v}) == {'a': 1, 'b': 2}", "{'a': 1, 'b': 2}.transformMapEntry(k, v, {k + '_new': v * 2}) == {'a_new': 2," + " 'b_new': 4}", "{'a': 1, 'b': 2, 'c': 3}.transformMapEntry(k, v, v % 2 == 1, {k: v * 10}) == {'a': 10," @@ -364,6 +366,15 @@ public void twoVarComprehension_outOfBounds_runtimeError() throws Exception { assertThat(e).hasCauseThat().hasMessageThat().contains("Index out of bounds: 1"); } + @Test + public void mutableMapValue_select_missingKeyException() throws Exception { + CelEvaluationException e = + assertThrows( + CelEvaluationException.class, () -> eval("cel.bind(my_map, {'a': 1}, my_map.b)")); + assertThat(e).hasCauseThat().isInstanceOf(CelAttributeNotFoundException.class); + assertThat(e).hasCauseThat().hasMessageThat().contains("key 'b' is not present in map."); + } + private Object eval(String expression) throws Exception { return eval(this.cel, expression, ImmutableMap.of()); } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel b/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel index 824c918d8..96382b9a9 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel +++ b/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel @@ -365,6 +365,7 @@ java_library( ":activation_wrapper", ":planned_interpretable", "//common/exceptions:runtime_exception", + "//common/values:mutable_map_value", "//runtime:accumulated_unknowns", "//runtime:concatenated_list_view", "//runtime:evaluation_exception", diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalFold.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalFold.java index 2eb30671e..090a8bfae 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalFold.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalFold.java @@ -15,8 +15,10 @@ package dev.cel.runtime.planner; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.errorprone.annotations.Immutable; import dev.cel.common.exceptions.CelRuntimeException; +import dev.cel.common.values.MutableMapValue; import dev.cel.runtime.AccumulatedUnknowns; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.ConcatenatedListView; @@ -131,7 +133,7 @@ private Object evalList(Collection iterRange, Folder folder, ExecutionFrame f boolean cond = (boolean) condition.eval(folder, frame); if (!cond) { folder.computeResult = true; - return result.eval(folder, frame); + return maybeUnwrapAccumulator(result.eval(folder, frame)); } folder.accuVal = loopStep.eval(folder, frame); @@ -139,14 +141,16 @@ private Object evalList(Collection iterRange, Folder folder, ExecutionFrame f index++; } folder.computeResult = true; - return result.eval(folder, frame); + return maybeUnwrapAccumulator(result.eval(folder, frame)); } private static Object maybeWrapAccumulator(Object val) { if (val instanceof Collection) { return new ConcatenatedListView<>((Collection) val); } - // TODO: Introduce mutable map support (for comp v2) + if (val instanceof Map) { + return MutableMapValue.create((Map) val); + } return val; } @@ -154,8 +158,9 @@ private static Object maybeUnwrapAccumulator(Object val) { if (val instanceof ConcatenatedListView) { return ImmutableList.copyOf((ConcatenatedListView) val); } - - // TODO: Introduce mutable map support (for comp v2) + if (val instanceof MutableMapValue) { + return ImmutableMap.copyOf((MutableMapValue) val); + } return val; } From 64bd51af154a5187d581e741958183cad4fe79b0 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Fri, 17 Apr 2026 16:07:22 -0700 Subject: [PATCH 39/66] Change string.split to return an immutable list. Add parsed-only evaluation coverage to CelStringExtensions PiperOrigin-RevId: 901514168 --- .../cel/extensions/CelStringExtensions.java | 28 +- .../main/java/dev/cel/extensions/README.md | 4 +- .../CelComprehensionsExtensionsTest.java | 5 +- .../extensions/CelListsExtensionsTest.java | 7 +- .../extensions/CelStringExtensionsTest.java | 509 ++++++++---------- 5 files changed, 238 insertions(+), 315 deletions(-) diff --git a/extensions/src/main/java/dev/cel/extensions/CelStringExtensions.java b/extensions/src/main/java/dev/cel/extensions/CelStringExtensions.java index 37c8270cc..2bb477b82 100644 --- a/extensions/src/main/java/dev/cel/extensions/CelStringExtensions.java +++ b/extensions/src/main/java/dev/cel/extensions/CelStringExtensions.java @@ -23,7 +23,6 @@ import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Lists; import com.google.errorprone.annotations.Immutable; import dev.cel.checker.CelCheckerBuilder; import dev.cel.common.CelFunctionDecl; @@ -37,7 +36,6 @@ import dev.cel.runtime.CelFunctionBinding; import dev.cel.runtime.CelRuntimeBuilder; import dev.cel.runtime.CelRuntimeLibrary; -import java.util.ArrayList; import java.util.List; import java.util.Set; @@ -475,7 +473,7 @@ private static String quote(String s) { sb.append('"'); for (int i = 0; i < s.length(); ) { int codePoint = s.codePointAt(i); - if (isMalformedUtf16(s, i, codePoint)) { + if (isMalformedUtf16(s, i)) { sb.append('\uFFFD'); i++; continue; @@ -518,7 +516,7 @@ private static String quote(String s) { return sb.toString(); } - private static boolean isMalformedUtf16(String s, int index, int codePoint) { + private static boolean isMalformedUtf16(String s, int index) { char currentChar = s.charAt(index); if (Character.isLowSurrogate(currentChar)) { return true; @@ -587,14 +585,14 @@ private static String reverse(String s) { return new StringBuilder(s).reverse().toString(); } - private static List split(String str, String separator) { + private static ImmutableList split(String str, String separator) { return split(str, separator, Integer.MAX_VALUE); } /** * @param args Object array with indices of: [0: string], [1: separator], [2: limit] */ - private static List split(Object[] args) throws CelEvaluationException { + private static ImmutableList split(Object[] args) throws CelEvaluationException { long limitInLong = (Long) args[2]; int limit; try { @@ -609,16 +607,14 @@ private static List split(Object[] args) throws CelEvaluationException { return split((String) args[0], (String) args[1], limit); } - /** Returns a **mutable** list of strings split on the separator */ - private static List split(String str, String separator, int limit) { + /** Returns an immutable list of strings split on the separator */ + private static ImmutableList split(String str, String separator, int limit) { if (limit == 0) { - return new ArrayList<>(); + return ImmutableList.of(); } if (limit == 1) { - List singleElementList = new ArrayList<>(); - singleElementList.add(str); - return singleElementList; + return ImmutableList.of(str); } if (limit < 0) { @@ -630,7 +626,7 @@ private static List split(String str, String separator, int limit) { } Iterable splitString = Splitter.on(separator).limit(limit).split(str); - return Lists.newArrayList(splitString); + return ImmutableList.copyOf(splitString); } /** @@ -643,8 +639,8 @@ private static List split(String str, String separator, int limit) { *

This exists because neither the built-in String.split nor Guava's splitter is able to deal * with separating single printable characters. */ - private static List explode(String str, int limit) { - List exploded = new ArrayList<>(); + private static ImmutableList explode(String str, int limit) { + ImmutableList.Builder exploded = ImmutableList.builder(); CelCodePointArray codePointArray = CelCodePointArray.fromString(str); if (limit > 0) { limit -= 1; @@ -656,7 +652,7 @@ private static List explode(String str, int limit) { if (codePointArray.length() > limit) { exploded.add(codePointArray.slice(limit, codePointArray.length()).toString()); } - return exploded; + return exploded.build(); } private static Object substring(String s, long i) throws CelEvaluationException { diff --git a/extensions/src/main/java/dev/cel/extensions/README.md b/extensions/src/main/java/dev/cel/extensions/README.md index c3fbf8c54..b1d3611b4 100644 --- a/extensions/src/main/java/dev/cel/extensions/README.md +++ b/extensions/src/main/java/dev/cel/extensions/README.md @@ -522,7 +522,7 @@ Examples: ### Split -Returns a mutable list of strings split from the input by the given separator. The +Returns a list of strings split from the input by the given separator. The function accepts an optional argument specifying a limit on the number of substrings produced by the split. @@ -1069,4 +1069,4 @@ Examples: {valueVar: indexVar}) // returns {1:0, 2:1, 3:2} {'greeting': 'aloha', 'farewell': 'aloha'} - .transformMapEntry(k, v, {v: k}) // error, duplicate key \ No newline at end of file + .transformMapEntry(k, v, {v: k}) // error, duplicate key diff --git a/extensions/src/test/java/dev/cel/extensions/CelComprehensionsExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelComprehensionsExtensionsTest.java index 374178540..207178cfe 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelComprehensionsExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelComprehensionsExtensionsTest.java @@ -28,6 +28,7 @@ import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelOptions; import dev.cel.common.CelValidationException; +import dev.cel.common.CelValidationResult; import dev.cel.common.exceptions.CelAttributeNotFoundException; import dev.cel.common.exceptions.CelDivideByZeroException; import dev.cel.common.exceptions.CelIndexOutOfBoundsException; @@ -312,8 +313,8 @@ public void unparseAST_twoVarComprehension( + " err: 'no matching overload'}") public void twoVarComprehension_compilerErrors(String expr, String err) throws Exception { Assume.assumeFalse(isParseOnly); - CelValidationException e = - assertThrows(CelValidationException.class, () -> cel.compile(expr).getAst()); + CelValidationResult result = cel.compile(expr); + CelValidationException e = assertThrows(CelValidationException.class, () -> result.getAst()); assertThat(e).hasMessageThat().contains(err); } diff --git a/extensions/src/test/java/dev/cel/extensions/CelListsExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelListsExtensionsTest.java index b36e0e92e..f36d90e2d 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelListsExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelListsExtensionsTest.java @@ -27,6 +27,7 @@ import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelContainer; import dev.cel.common.CelValidationException; +import dev.cel.common.CelValidationResult; import dev.cel.common.types.SimpleType; import dev.cel.expr.conformance.test.SimpleTest; import dev.cel.parser.CelStandardMacro; @@ -300,10 +301,8 @@ public void sortBy_success(String expression, String expected) throws Exception + "expectedError: 'variable name must be a simple identifier'}") public void sortBy_throws_validationException(String expression, String expectedError) throws Exception { - assertThat( - assertThrows( - CelValidationException.class, - () -> cel.createProgram(cel.compile(expression).getAst()).eval())) + CelValidationResult result = cel.compile(expression); + assertThat(assertThrows(CelValidationException.class, () -> result.getAst())) .hasMessageThat() .contains(expectedError); } diff --git a/extensions/src/test/java/dev/cel/extensions/CelStringExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelStringExtensionsTest.java index 3624e0902..e7542b7b7 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelStringExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelStringExtensionsTest.java @@ -18,43 +18,56 @@ import static org.junit.Assert.assertThrows; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Iterables; import com.google.testing.junit.testparameterinjector.TestParameter; import com.google.testing.junit.testparameterinjector.TestParameterInjector; import com.google.testing.junit.testparameterinjector.TestParameters; +import dev.cel.bundle.Cel; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelOptions; import dev.cel.common.CelValidationException; +import dev.cel.common.CelValidationResult; import dev.cel.common.types.SimpleType; import dev.cel.compiler.CelCompiler; import dev.cel.compiler.CelCompilerFactory; import dev.cel.extensions.CelStringExtensions.Function; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.CelRuntime; -import dev.cel.runtime.CelRuntimeFactory; +import dev.cel.testing.CelRuntimeFlavor; import java.util.List; +import java.util.Map; +import org.junit.Assume; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @RunWith(TestParameterInjector.class) public final class CelStringExtensionsTest { - private static final CelCompiler COMPILER = - CelCompilerFactory.standardCelCompilerBuilder() - .addLibraries(CelExtensions.strings()) - .addVar("s", SimpleType.STRING) - .addVar("separator", SimpleType.STRING) - .addVar("index", SimpleType.INT) - .addVar("offset", SimpleType.INT) - .addVar("indexOfParam", SimpleType.STRING) - .addVar("beginIndex", SimpleType.INT) - .addVar("endIndex", SimpleType.INT) - .addVar("limit", SimpleType.INT) - .build(); - - private static final CelRuntime RUNTIME = - CelRuntimeFactory.standardCelRuntimeBuilder().addLibraries(CelExtensions.strings()).build(); + @TestParameter public CelRuntimeFlavor runtimeFlavor; + @TestParameter public boolean isParseOnly; + + private Cel cel; + + @Before + public void setUp() { + // Legacy runtime does not support parsed-only evaluation mode. + Assume.assumeFalse(runtimeFlavor.equals(CelRuntimeFlavor.LEGACY) && isParseOnly); + this.cel = + runtimeFlavor + .builder() + .addCompilerLibraries(CelExtensions.strings()) + .addRuntimeLibraries(CelExtensions.strings()) + .addVar("s", SimpleType.STRING) + .addVar("separator", SimpleType.STRING) + .addVar("index", SimpleType.INT) + .addVar("offset", SimpleType.INT) + .addVar("indexOfParam", SimpleType.STRING) + .addVar("beginIndex", SimpleType.INT) + .addVar("endIndex", SimpleType.INT) + .addVar("limit", SimpleType.INT) + .build(); + } @Test public void library() { @@ -92,10 +105,8 @@ public void library() { @TestParameters("{string: '😁😑😦', beginIndex: 3, expectedResult: ''}") public void substring_beginIndex_success(String string, int beginIndex, String expectedResult) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("s.substring(beginIndex)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object evaluatedResult = program.eval(ImmutableMap.of("s", string, "beginIndex", beginIndex)); + Object evaluatedResult = + eval("s.substring(beginIndex)", ImmutableMap.of("s", string, "beginIndex", beginIndex)); assertThat(evaluatedResult).isEqualTo(expectedResult); } @@ -108,10 +119,7 @@ public void substring_beginIndex_success(String string, int beginIndex, String e @TestParameters( "{string: 'A!@#$%^&*()-_+=?/<>.,;:''\"\\', expectedResult: 'a!@#$%^&*()-_+=?/<>.,;:''\"\\'}") public void lowerAscii_success(String string, String expectedResult) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("s.lowerAscii()").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object evaluatedResult = program.eval(ImmutableMap.of("s", string)); + Object evaluatedResult = eval("s.lowerAscii()", ImmutableMap.of("s", string)); assertThat(evaluatedResult).isEqualTo(expectedResult); } @@ -127,10 +135,7 @@ public void lowerAscii_success(String string, String expectedResult) throws Exce @TestParameters("{string: 'A😁B 😑C가😦D', expectedResult: 'a😁b 😑c가😦d'}") public void lowerAscii_outsideAscii_success(String string, String expectedResult) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("s.lowerAscii()").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object evaluatedResult = program.eval(ImmutableMap.of("s", string)); + Object evaluatedResult = eval("s.lowerAscii()", ImmutableMap.of("s", string)); assertThat(evaluatedResult).isEqualTo(expectedResult); } @@ -161,10 +166,8 @@ public void lowerAscii_outsideAscii_success(String string, String expectedResult + " ['The quick brown ', ' jumps over the lazy dog']}") public void split_ascii_success(String string, String separator, List expectedResult) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("s.split(separator)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object evaluatedResult = program.eval(ImmutableMap.of("s", string, "separator", separator)); + Object evaluatedResult = + eval("s.split(separator)", ImmutableMap.of("s", string, "separator", separator)); assertThat(evaluatedResult).isEqualTo(expectedResult); } @@ -182,34 +185,30 @@ public void split_ascii_success(String string, String separator, List ex @TestParameters("{string: '😁a😦나😑 😦', separator: '😁a😦나😑 😦', expectedResult: ['','']}") public void split_unicode_success(String string, String separator, List expectedResult) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("s.split(separator)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object evaluatedResult = program.eval(ImmutableMap.of("s", string, "separator", separator)); + Object evaluatedResult = + eval("s.split(separator)", ImmutableMap.of("s", string, "separator", separator)); assertThat(evaluatedResult).isEqualTo(expectedResult); } @Test @SuppressWarnings("unchecked") // Test only, need List cast to test mutability - public void split_collectionIsMutable() throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("'test'.split('')").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); + public void split_collectionIsImmutable() throws Exception { + CelAbstractSyntaxTree ast = cel.compile("'test'.split('')").getAst(); + CelRuntime.Program program = cel.createProgram(ast); List evaluatedResult = (List) program.eval(); - evaluatedResult.add("a"); - evaluatedResult.add("b"); - evaluatedResult.add("c"); - evaluatedResult.remove("c"); - assertThat(evaluatedResult).containsExactly("t", "e", "s", "t", "a", "b").inOrder(); + assertThrows(UnsupportedOperationException.class, () -> evaluatedResult.add("a")); } @Test public void split_separatorIsNonString_throwsException() { + // This is a type-check failure. + Assume.assumeFalse(isParseOnly); + CelValidationResult result = cel.compile("'12'.split(2)"); CelValidationException exception = - assertThrows( - CelValidationException.class, () -> COMPILER.compile("'12'.split(2)").getAst()); + assertThrows(CelValidationException.class, () -> result.getAst()); assertThat(exception).hasMessageThat().contains("found no matching overload for 'split'"); } @@ -295,11 +294,10 @@ public void split_separatorIsNonString_throwsException() { + " expectedResult: ['The quick brown ', ' jumps over the lazy dog']}") public void split_asciiWithLimit_success( String string, String separator, int limit, List expectedResult) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("s.split(separator, limit)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - Object evaluatedResult = - program.eval(ImmutableMap.of("s", string, "separator", separator, "limit", limit)); + eval( + "s.split(separator, limit)", + ImmutableMap.of("s", string, "separator", separator, "limit", limit)); assertThat(evaluatedResult).isEqualTo(expectedResult); } @@ -351,11 +349,10 @@ public void split_asciiWithLimit_success( "{string: '😁a😦나😑 😦', separator: '😁a😦나😑 😦', limit: -1, expectedResult: ['','']}") public void split_unicodeWithLimit_success( String string, String separator, int limit, List expectedResult) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("s.split(separator, limit)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - Object evaluatedResult = - program.eval(ImmutableMap.of("s", string, "separator", separator, "limit", limit)); + eval( + "s.split(separator, limit)", + ImmutableMap.of("s", string, "separator", separator, "limit", limit)); assertThat(evaluatedResult).isEqualTo(expectedResult); } @@ -368,35 +365,36 @@ public void split_unicodeWithLimit_success( @TestParameters("{separator: 'te', limit: 1}") @TestParameters("{separator: 'te', limit: 2}") @SuppressWarnings("unchecked") // Test only, need List cast to test mutability - public void split_withLimit_collectionIsMutable(String separator, int limit) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("'test'.split(separator, limit)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - + public void split_withLimit_collectionIsImmutable(String separator, int limit) throws Exception { List evaluatedResult = - (List) program.eval(ImmutableMap.of("separator", separator, "limit", limit)); - evaluatedResult.add("a"); + (List) + eval( + "'test'.split(separator, limit)", + ImmutableMap.of("separator", separator, "limit", limit)); - assertThat(Iterables.getLast(evaluatedResult)).isEqualTo("a"); + assertThrows(UnsupportedOperationException.class, () -> evaluatedResult.add("a")); } @Test public void split_withLimit_separatorIsNonString_throwsException() { + // This is a type-check failure. + Assume.assumeFalse(isParseOnly); + CelValidationResult result = cel.compile("'12'.split(2, 3)"); CelValidationException exception = - assertThrows( - CelValidationException.class, () -> COMPILER.compile("'12'.split(2, 3)").getAst()); + assertThrows(CelValidationException.class, () -> result.getAst()); assertThat(exception).hasMessageThat().contains("found no matching overload for 'split'"); } @Test public void split_withLimitOverflow_throwsException() throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("'test'.split('', limit)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - CelEvaluationException exception = assertThrows( CelEvaluationException.class, - () -> program.eval(ImmutableMap.of("limit", 2147483648L))); // INT_MAX + 1 + () -> + eval( + "'test'.split('', limit)", + ImmutableMap.of("limit", 2147483648L))); // INT_MAX + 1 assertThat(exception) .hasMessageThat() @@ -416,11 +414,10 @@ public void split_withLimitOverflow_throwsException() throws Exception { @TestParameters("{string: '', beginIndex: 0, endIndex: 0, expectedResult: ''}") public void substring_beginAndEndIndex_ascii_success( String string, int beginIndex, int endIndex, String expectedResult) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("s.substring(beginIndex, endIndex)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - Object evaluatedResult = - program.eval(ImmutableMap.of("s", string, "beginIndex", beginIndex, "endIndex", endIndex)); + eval( + "s.substring(beginIndex, endIndex)", + ImmutableMap.of("s", string, "beginIndex", beginIndex, "endIndex", endIndex)); assertThat(evaluatedResult).isEqualTo(expectedResult); } @@ -444,11 +441,10 @@ public void substring_beginAndEndIndex_ascii_success( @TestParameters("{string: 'a😁나', beginIndex: 3, endIndex: 3, expectedResult: ''}") public void substring_beginAndEndIndex_unicode_success( String string, int beginIndex, int endIndex, String expectedResult) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("s.substring(beginIndex, endIndex)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - Object evaluatedResult = - program.eval(ImmutableMap.of("s", string, "beginIndex", beginIndex, "endIndex", endIndex)); + eval( + "s.substring(beginIndex, endIndex)", + ImmutableMap.of("s", string, "beginIndex", beginIndex, "endIndex", endIndex)); assertThat(evaluatedResult).isEqualTo(expectedResult); } @@ -458,13 +454,13 @@ public void substring_beginAndEndIndex_unicode_success( @TestParameters("{string: '', beginIndex: 2}") public void substring_beginIndexOutOfRange_ascii_throwsException(String string, int beginIndex) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("s.substring(beginIndex)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - CelEvaluationException exception = assertThrows( CelEvaluationException.class, - () -> program.eval(ImmutableMap.of("s", string, "beginIndex", beginIndex))); + () -> + eval( + "s.substring(beginIndex)", + ImmutableMap.of("s", string, "beginIndex", beginIndex))); String exceptionMessage = String.format( @@ -482,13 +478,13 @@ public void substring_beginIndexOutOfRange_ascii_throwsException(String string, @TestParameters("{string: '😁가나', beginIndex: 4, uniqueCharCount: 3}") public void substring_beginIndexOutOfRange_unicode_throwsException( String string, int beginIndex, int uniqueCharCount) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("s.substring(beginIndex)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - CelEvaluationException exception = assertThrows( CelEvaluationException.class, - () -> program.eval(ImmutableMap.of("s", string, "beginIndex", beginIndex))); + () -> + eval( + "s.substring(beginIndex)", + ImmutableMap.of("s", string, "beginIndex", beginIndex))); String exceptionMessage = String.format( @@ -505,14 +501,12 @@ public void substring_beginIndexOutOfRange_unicode_throwsException( @TestParameters("{string: '😁😑😦', beginIndex: 2, endIndex: 1}") public void substring_beginAndEndIndexOutOfRange_throwsException( String string, int beginIndex, int endIndex) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("s.substring(beginIndex, endIndex)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - CelEvaluationException exception = assertThrows( CelEvaluationException.class, () -> - program.eval( + eval( + "s.substring(beginIndex, endIndex)", ImmutableMap.of("s", string, "beginIndex", beginIndex, "endIndex", endIndex))); String exceptionMessage = @@ -522,13 +516,13 @@ public void substring_beginAndEndIndexOutOfRange_throwsException( @Test public void substring_beginIndexOverflow_throwsException() throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("'abcd'.substring(beginIndex)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - CelEvaluationException exception = assertThrows( CelEvaluationException.class, - () -> program.eval(ImmutableMap.of("beginIndex", 2147483648L))); // INT_MAX + 1 + () -> + eval( + "'abcd'.substring(beginIndex)", + ImmutableMap.of("beginIndex", 2147483648L))); // INT_MAX + 1 assertThat(exception) .hasMessageThat() @@ -540,13 +534,13 @@ public void substring_beginIndexOverflow_throwsException() throws Exception { @TestParameters("{beginIndex: 2147483648, endIndex: 2147483648}") public void substring_beginOrEndIndexOverflow_throwsException(long beginIndex, long endIndex) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("'abcd'.substring(beginIndex, endIndex)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - CelEvaluationException exception = assertThrows( CelEvaluationException.class, - () -> program.eval(ImmutableMap.of("beginIndex", beginIndex, "endIndex", endIndex))); + () -> + eval( + "'abcd'.substring(beginIndex, endIndex)", + ImmutableMap.of("beginIndex", beginIndex, "endIndex", endIndex))); assertThat(exception) .hasMessageThat() @@ -563,10 +557,7 @@ public void substring_beginOrEndIndexOverflow_throwsException(long beginIndex, l @TestParameters("{string: 'world', index: 5, expectedResult: ''}") public void charAt_ascii_success(String string, long index, String expectedResult) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("s.charAt(index)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object evaluatedResult = program.eval(ImmutableMap.of("s", string, "index", index)); + Object evaluatedResult = eval("s.charAt(index)", ImmutableMap.of("s", string, "index", index)); assertThat(evaluatedResult).isEqualTo(expectedResult); } @@ -588,10 +579,7 @@ public void charAt_ascii_success(String string, long index, String expectedResul @TestParameters("{string: 'a😁나', index: 3, expectedResult: ''}") public void charAt_unicode_success(String string, long index, String expectedResult) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("s.charAt(index)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object evaluatedResult = program.eval(ImmutableMap.of("s", string, "index", index)); + Object evaluatedResult = eval("s.charAt(index)", ImmutableMap.of("s", string, "index", index)); assertThat(evaluatedResult).isEqualTo(expectedResult); } @@ -602,26 +590,21 @@ public void charAt_unicode_success(String string, long index, String expectedRes @TestParameters("{string: '😁😑😦', index: -1}") @TestParameters("{string: '😁😑😦', index: 4}") public void charAt_outOfBounds_throwsException(String string, long index) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("s.charAt(index)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - CelEvaluationException exception = assertThrows( CelEvaluationException.class, - () -> program.eval(ImmutableMap.of("s", string, "index", index))); + () -> eval("s.charAt(index)", ImmutableMap.of("s", string, "index", index))); assertThat(exception).hasMessageThat().contains("charAt failure: Index out of range"); } @Test public void charAt_indexOverflow_throwsException() throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("'test'.charAt(index)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - CelEvaluationException exception = assertThrows( CelEvaluationException.class, - () -> program.eval(ImmutableMap.of("index", 2147483648L))); // INT_MAX + 1 + () -> + eval("'test'.charAt(index)", ImmutableMap.of("index", 2147483648L))); // INT_MAX + 1 assertThat(exception) .hasMessageThat() @@ -650,10 +633,8 @@ public void charAt_indexOverflow_throwsException() throws Exception { @TestParameters("{string: 'hello mellow', indexOf: ' ', expectedResult: -1}") public void indexOf_ascii_success(String string, String indexOf, int expectedResult) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("s.indexOf(indexOfParam)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object evaluatedResult = program.eval(ImmutableMap.of("s", string, "indexOfParam", indexOf)); + Object evaluatedResult = + eval("s.indexOf(indexOfParam)", ImmutableMap.of("s", string, "indexOfParam", indexOf)); assertThat(evaluatedResult).isEqualTo(expectedResult); } @@ -682,10 +663,8 @@ public void indexOf_ascii_success(String string, String indexOf, int expectedRes @TestParameters("{string: 'a😁😑 나😦😁😑다', indexOf: 'a😁😑 나😦😁😑다😁', expectedResult: -1}") public void indexOf_unicode_success(String string, String indexOf, int expectedResult) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("s.indexOf(indexOfParam)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object evaluatedResult = program.eval(ImmutableMap.of("s", string, "indexOfParam", indexOf)); + Object evaluatedResult = + eval("s.indexOf(indexOfParam)", ImmutableMap.of("s", string, "indexOfParam", indexOf)); assertThat(evaluatedResult).isEqualTo(expectedResult); } @@ -697,13 +676,10 @@ public void indexOf_unicode_success(String string, String indexOf, int expectedR @TestParameters("{indexOf: '나'}") @TestParameters("{indexOf: '😁'}") public void indexOf_onEmptyString_throwsException(String indexOf) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("''.indexOf(indexOfParam)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - CelEvaluationException exception = assertThrows( CelEvaluationException.class, - () -> program.eval(ImmutableMap.of("indexOfParam", indexOf))); + () -> eval("''.indexOf(indexOfParam)", ImmutableMap.of("indexOfParam", indexOf))); assertThat(exception).hasMessageThat().contains("indexOf failure: Offset out of range"); } @@ -728,11 +704,10 @@ public void indexOf_onEmptyString_throwsException(String indexOf) throws Excepti @TestParameters("{string: 'hello mellow', indexOf: 'l', offset: 10, expectedResult: -1}") public void indexOf_asciiWithOffset_success( String string, String indexOf, int offset, int expectedResult) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("s.indexOf(indexOfParam, offset)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - Object evaluatedResult = - program.eval(ImmutableMap.of("s", string, "indexOfParam", indexOf, "offset", offset)); + eval( + "s.indexOf(indexOfParam, offset)", + ImmutableMap.of("s", string, "indexOfParam", indexOf, "offset", offset)); assertThat(evaluatedResult).isEqualTo(expectedResult); } @@ -779,11 +754,10 @@ public void indexOf_asciiWithOffset_success( "{string: 'a😁😑 나😦😁😑다', indexOf: 'a😁😑 나😦😁😑다😁', offset: 0, expectedResult: -1}") public void indexOf_unicodeWithOffset_success( String string, String indexOf, int offset, int expectedResult) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("s.indexOf(indexOfParam, offset)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - Object evaluatedResult = - program.eval(ImmutableMap.of("s", string, "indexOfParam", indexOf, "offset", offset)); + eval( + "s.indexOf(indexOfParam, offset)", + ImmutableMap.of("s", string, "indexOfParam", indexOf, "offset", offset)); assertThat(evaluatedResult).isEqualTo(expectedResult); } @@ -797,14 +771,12 @@ public void indexOf_unicodeWithOffset_success( @TestParameters("{string: '😁😑 😦', indexOf: '😦', offset: 4}") public void indexOf_withOffsetOutOfBounds_throwsException( String string, String indexOf, int offset) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("s.indexOf(indexOfParam, offset)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - CelEvaluationException exception = assertThrows( CelEvaluationException.class, () -> - program.eval( + eval( + "s.indexOf(indexOfParam, offset)", ImmutableMap.of("s", string, "indexOfParam", indexOf, "offset", offset))); assertThat(exception).hasMessageThat().contains("indexOf failure: Offset out of range"); @@ -812,13 +784,13 @@ public void indexOf_withOffsetOutOfBounds_throwsException( @Test public void indexOf_offsetOverflow_throwsException() throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("'test'.indexOf('t', offset)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - CelEvaluationException exception = assertThrows( CelEvaluationException.class, - () -> program.eval(ImmutableMap.of("offset", 2147483648L))); // INT_MAX + 1 + () -> + eval( + "'test'.indexOf('t', offset)", + ImmutableMap.of("offset", 2147483648L))); // INT_MAX + 1 assertThat(exception) .hasMessageThat() @@ -835,10 +807,7 @@ public void indexOf_offsetOverflow_throwsException() throws Exception { @TestParameters("{list: '[''x'', '' '', '' y '', ''z '']', expectedResult: 'x y z '}") @TestParameters("{list: '[''hello '', ''world'']', expectedResult: 'hello world'}") public void join_ascii_success(String list, String expectedResult) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile(String.format("%s.join()", list)).getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - String result = (String) program.eval(); + String result = (String) eval(String.format("%s.join()", list)); assertThat(result).isEqualTo(expectedResult); } @@ -847,10 +816,7 @@ public void join_ascii_success(String list, String expectedResult) throws Except @TestParameters("{list: '[''가'', ''😁'']', expectedResult: '가😁'}") @TestParameters("{list: '[''😁😦😑 😦'', ''나'']', expectedResult: '😁😦😑 😦나'}") public void join_unicode_success(String list, String expectedResult) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile(String.format("%s.join()", list)).getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - String result = (String) program.eval(); + String result = (String) eval(String.format("%s.join()", list)); assertThat(result).isEqualTo(expectedResult); } @@ -874,11 +840,7 @@ public void join_unicode_success(String list, String expectedResult) throws Exce "{list: '[''hello '', ''world'']', separator: '/', expectedResult: 'hello /world'}") public void join_asciiWithSeparator_success(String list, String separator, String expectedResult) throws Exception { - CelAbstractSyntaxTree ast = - COMPILER.compile(String.format("%s.join('%s')", list, separator)).getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - String result = (String) program.eval(); + String result = (String) eval(String.format("%s.join('%s')", list, separator)); assertThat(result).isEqualTo(expectedResult); } @@ -893,20 +855,17 @@ public void join_asciiWithSeparator_success(String list, String separator, Strin + " -😑-나'}") public void join_unicodeWithSeparator_success( String list, String separator, String expectedResult) throws Exception { - CelAbstractSyntaxTree ast = - COMPILER.compile(String.format("%s.join('%s')", list, separator)).getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - String result = (String) program.eval(); + String result = (String) eval(String.format("%s.join('%s')", list, separator)); assertThat(result).isEqualTo(expectedResult); } @Test public void join_separatorIsNonString_throwsException() { + // This is a type-check failure. + Assume.assumeFalse(isParseOnly); CelValidationException exception = - assertThrows( - CelValidationException.class, () -> COMPILER.compile("['x','y'].join(2)").getAst()); + assertThrows(CelValidationException.class, () -> cel.compile("['x','y'].join(2)").getAst()); assertThat(exception).hasMessageThat().contains("found no matching overload for 'join'"); } @@ -935,11 +894,10 @@ public void join_separatorIsNonString_throwsException() { @TestParameters("{string: 'hello mellow', lastIndexOf: ' ', expectedResult: -1}") public void lastIndexOf_ascii_success(String string, String lastIndexOf, int expectedResult) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("s.lastIndexOf(indexOfParam)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - Object evaluatedResult = - program.eval(ImmutableMap.of("s", string, "indexOfParam", lastIndexOf)); + eval( + "s.lastIndexOf(indexOfParam)", + ImmutableMap.of("s", string, "indexOfParam", lastIndexOf)); assertThat(evaluatedResult).isEqualTo(expectedResult); } @@ -969,11 +927,10 @@ public void lastIndexOf_ascii_success(String string, String lastIndexOf, int exp @TestParameters("{string: 'a😁😑 나😦😁😑다', lastIndexOf: 'a😁😑 나😦😁😑다😁', expectedResult: -1}") public void lastIndexOf_unicode_success(String string, String lastIndexOf, int expectedResult) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("s.lastIndexOf(indexOfParam)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - Object evaluatedResult = - program.eval(ImmutableMap.of("s", string, "indexOfParam", lastIndexOf)); + eval( + "s.lastIndexOf(indexOfParam)", + ImmutableMap.of("s", string, "indexOfParam", lastIndexOf)); assertThat(evaluatedResult).isEqualTo(expectedResult); } @@ -987,10 +944,8 @@ public void lastIndexOf_unicode_success(String string, String lastIndexOf, int e @TestParameters("{lastIndexOf: '😁'}") public void lastIndexOf_strLengthLessThanSubstrLength_returnsMinusOne(String lastIndexOf) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("''.lastIndexOf(indexOfParam)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object evaluatedResult = program.eval(ImmutableMap.of("s", "", "indexOfParam", lastIndexOf)); + Object evaluatedResult = + eval("''.lastIndexOf(indexOfParam)", ImmutableMap.of("s", "", "indexOfParam", lastIndexOf)); assertThat(evaluatedResult).isEqualTo(-1); } @@ -1022,11 +977,10 @@ public void lastIndexOf_strLengthLessThanSubstrLength_returnsMinusOne(String las "{string: 'hello mellow', lastIndexOf: 'hello mellowwww ', offset: 11, expectedResult: -1}") public void lastIndexOf_asciiWithOffset_success( String string, String lastIndexOf, int offset, int expectedResult) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("s.lastIndexOf(indexOfParam, offset)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - Object evaluatedResult = - program.eval(ImmutableMap.of("s", string, "indexOfParam", lastIndexOf, "offset", offset)); + eval( + "s.lastIndexOf(indexOfParam, offset)", + ImmutableMap.of("s", string, "indexOfParam", lastIndexOf, "offset", offset)); assertThat(evaluatedResult).isEqualTo(expectedResult); } @@ -1097,11 +1051,10 @@ public void lastIndexOf_asciiWithOffset_success( "{string: 'a😁😑 나😦😁😑다', lastIndexOf: 'a😁😑 나😦😁😑다😁', offset: 8, expectedResult: -1}") public void lastIndexOf_unicodeWithOffset_success( String string, String lastIndexOf, int offset, int expectedResult) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("s.lastIndexOf(indexOfParam, offset)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - Object evaluatedResult = - program.eval(ImmutableMap.of("s", string, "indexOfParam", lastIndexOf, "offset", offset)); + eval( + "s.lastIndexOf(indexOfParam, offset)", + ImmutableMap.of("s", string, "indexOfParam", lastIndexOf, "offset", offset)); assertThat(evaluatedResult).isEqualTo(expectedResult); } @@ -1115,14 +1068,12 @@ public void lastIndexOf_unicodeWithOffset_success( @TestParameters("{string: '😁😑 😦', lastIndexOf: '😦', offset: 4}") public void lastIndexOf_withOffsetOutOfBounds_throwsException( String string, String lastIndexOf, int offset) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("s.lastIndexOf(indexOfParam, offset)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - CelEvaluationException exception = assertThrows( CelEvaluationException.class, () -> - program.eval( + eval( + "s.lastIndexOf(indexOfParam, offset)", ImmutableMap.of("s", string, "indexOfParam", lastIndexOf, "offset", offset))); assertThat(exception).hasMessageThat().contains("lastIndexOf failure: Offset out of range"); @@ -1130,13 +1081,13 @@ public void lastIndexOf_withOffsetOutOfBounds_throwsException( @Test public void lastIndexOf_offsetOverflow_throwsException() throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("'test'.lastIndexOf('t', offset)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - CelEvaluationException exception = assertThrows( CelEvaluationException.class, - () -> program.eval(ImmutableMap.of("offset", 2147483648L))); // INT_MAX + 1 + () -> + eval( + "'test'.lastIndexOf('t', offset)", + ImmutableMap.of("offset", 2147483648L))); // INT_MAX + 1 assertThat(exception) .hasMessageThat() @@ -1163,13 +1114,8 @@ public void lastIndexOf_offsetOverflow_throwsException() throws Exception { public void replace_ascii_success( String string, String searchString, String replacement, String expectedResult) throws Exception { - CelAbstractSyntaxTree ast = - COMPILER - .compile(String.format("'%s'.replace('%s', '%s')", string, searchString, replacement)) - .getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object evaluatedResult = program.eval(); + Object evaluatedResult = + eval(String.format("'%s'.replace('%s', '%s')", string, searchString, replacement)); assertThat(evaluatedResult).isEqualTo(expectedResult); } @@ -1188,13 +1134,8 @@ public void replace_ascii_success( public void replace_unicode_success( String string, String searchString, String replacement, String expectedResult) throws Exception { - CelAbstractSyntaxTree ast = - COMPILER - .compile(String.format("'%s'.replace('%s', '%s')", string, searchString, replacement)) - .getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object evaluatedResult = program.eval(); + Object evaluatedResult = + eval(String.format("'%s'.replace('%s', '%s')", string, searchString, replacement)); assertThat(evaluatedResult).isEqualTo(expectedResult); } @@ -1273,15 +1214,10 @@ public void replace_unicode_success( public void replace_ascii_withLimit_success( String string, String searchString, String replacement, int limit, String expectedResult) throws Exception { - CelAbstractSyntaxTree ast = - COMPILER - .compile( - String.format( - "'%s'.replace('%s', '%s', %d)", string, searchString, replacement, limit)) - .getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object evaluatedResult = program.eval(); + Object evaluatedResult = + eval( + String.format( + "'%s'.replace('%s', '%s', %d)", string, searchString, replacement, limit)); assertThat(evaluatedResult).isEqualTo(expectedResult); } @@ -1334,28 +1270,23 @@ public void replace_ascii_withLimit_success( public void replace_unicode_withLimit_success( String string, String searchString, String replacement, int limit, String expectedResult) throws Exception { - CelAbstractSyntaxTree ast = - COMPILER - .compile( - String.format( - "'%s'.replace('%s', '%s', %d)", string, searchString, replacement, limit)) - .getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object evaluatedResult = program.eval(); + Object evaluatedResult = + eval( + String.format( + "'%s'.replace('%s', '%s', %d)", string, searchString, replacement, limit)); assertThat(evaluatedResult).isEqualTo(expectedResult); } @Test public void replace_limitOverflow_throwsException() throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("'test'.replace('','',index)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - CelEvaluationException exception = assertThrows( CelEvaluationException.class, - () -> program.eval(ImmutableMap.of("index", 2147483648L))); // INT_MAX + 1 + () -> + eval( + "'test'.replace('','',index)", + ImmutableMap.of("index", 2147483648L))); // INT_MAX + 1 assertThat(exception) .hasMessageThat() @@ -1406,10 +1337,7 @@ private enum TrimTestCase { @Test public void trim_success(@TestParameter TrimTestCase testCase) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("s.trim()").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object evaluatedResult = program.eval(ImmutableMap.of("s", testCase.text)); + Object evaluatedResult = eval("s.trim()", ImmutableMap.of("s", testCase.text)); assertThat(evaluatedResult).isEqualTo(testCase.expectedResult); } @@ -1422,10 +1350,7 @@ public void trim_success(@TestParameter TrimTestCase testCase) throws Exception @TestParameters( "{string: 'a!@#$%^&*()-_+=?/<>.,;:''\"\\', expectedResult: 'A!@#$%^&*()-_+=?/<>.,;:''\"\\'}") public void upperAscii_success(String string, String expectedResult) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("s.upperAscii()").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object evaluatedResult = program.eval(ImmutableMap.of("s", string)); + Object evaluatedResult = eval("s.upperAscii()", ImmutableMap.of("s", string)); assertThat(evaluatedResult).isEqualTo(expectedResult); } @@ -1441,30 +1366,25 @@ public void upperAscii_success(String string, String expectedResult) throws Exce @TestParameters("{string: 'a😁b 😑c가😦d', expectedResult: 'A😁B 😑C가😦D'}") public void upperAscii_outsideAscii_success(String string, String expectedResult) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("s.upperAscii()").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object evaluatedResult = program.eval(ImmutableMap.of("s", string)); + Object evaluatedResult = eval("s.upperAscii()", ImmutableMap.of("s", string)); assertThat(evaluatedResult).isEqualTo(expectedResult); } @Test public void stringExtension_functionSubset_success() throws Exception { - CelStringExtensions stringExtensions = - CelExtensions.strings(Function.CHAR_AT, Function.SUBSTRING); - CelCompiler celCompiler = - CelCompilerFactory.standardCelCompilerBuilder().addLibraries(stringExtensions).build(); - CelRuntime celRuntime = - CelRuntimeFactory.standardCelRuntimeBuilder().addLibraries(stringExtensions).build(); + Cel customCel = + runtimeFlavor + .builder() + .addCompilerLibraries(CelExtensions.strings(Function.CHAR_AT, Function.SUBSTRING)) + .addRuntimeLibraries(CelExtensions.strings(Function.CHAR_AT, Function.SUBSTRING)) + .build(); Object evaluatedResult = - celRuntime - .createProgram( - celCompiler - .compile("'test'.substring(2) == 'st' && 'hello'.charAt(1) == 'e'") - .getAst()) - .eval(); + eval( + customCel, + "'test'.substring(2) == 'st' && 'hello'.charAt(1) == 'e'", + ImmutableMap.of()); assertThat(evaluatedResult).isEqualTo(true); } @@ -1476,10 +1396,7 @@ public void stringExtension_functionSubset_success() throws Exception { @TestParameters("{string: 'hello world', expectedResult: 'dlrow olleh'}") @TestParameters("{string: 'ab가cd', expectedResult: 'dc가ba'}") public void reverse_success(String string, String expectedResult) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("s.reverse()").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object evaluatedResult = program.eval(ImmutableMap.of("s", string)); + Object evaluatedResult = eval("s.reverse()", ImmutableMap.of("s", string)); assertThat(evaluatedResult).isEqualTo(expectedResult); } @@ -1490,10 +1407,7 @@ public void reverse_success(String string, String expectedResult) throws Excepti "{string: '\u180e\u200b\u200c\u200d\u2060\ufeff', expectedResult:" + " '\ufeff\u2060\u200d\u200c\u200b\u180e'}") public void reverse_unicode(String string, String expectedResult) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("s.reverse()").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object evaluatedResult = program.eval(ImmutableMap.of("s", string)); + Object evaluatedResult = eval("s.reverse()", ImmutableMap.of("s", string)); assertThat(evaluatedResult).isEqualTo(expectedResult); } @@ -1502,14 +1416,14 @@ public void reverse_unicode(String string, String expectedResult) throws Excepti @TestParameters("{string: 'hello', expectedResult: '\"hello\"'}") @TestParameters("{string: '', expectedResult: '\"\"'}") @TestParameters( - "{string: 'contains \\\"quotes\\\"', expectedResult: '\"contains \\\\\\\"quotes\\\\\\\"\"'}") - @TestParameters("{string: 'ends with \\\\', expectedResult: '\"ends with \\\\\\\\\"'}") - @TestParameters("{string: '\\\\ starts with', expectedResult: '\"\\\\\\\\ starts with\"'}") + "{string: 'contains \\\\\\\"quotes\\\\\\\"', expectedResult: '\"contains" + + " \\\\\\\\\\\\\\\"quotes\\\\\\\\\\\\\\\"\"'}") + @TestParameters( + "{string: 'ends with \\\\\\\\', expectedResult: '\"ends with \\\\\\\\\\\\\\\\\"'}") + @TestParameters( + "{string: '\\\\\\\\ starts with', expectedResult: '\"\\\\\\\\\\\\\\\\ starts with\"'}") public void quote_success(String string, String expectedResult) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("strings.quote(s)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object evaluatedResult = program.eval(ImmutableMap.of("s", string)); + Object evaluatedResult = eval("strings.quote(s)", ImmutableMap.of("s", string)); assertThat(evaluatedResult).isEqualTo(expectedResult); } @@ -1518,21 +1432,16 @@ public void quote_success(String string, String expectedResult) throws Exception public void quote_singleWithDoubleQuotes() throws Exception { String expr = "strings.quote('single-quote with \"double quote\"')"; String expected = "\"\\\"single-quote with \\\\\\\"double quote\\\\\\\"\\\"\""; - CelAbstractSyntaxTree ast = COMPILER.compile(expr + " == " + expected).getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object evaluatedResult = program.eval(); + Object evaluatedResult = eval(expr + " == " + expected); assertThat(evaluatedResult).isEqualTo(true); } @Test public void quote_escapesSpecialCharacters() throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("strings.quote(s)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - Object evaluatedResult = - program.eval( + eval( + "strings.quote(s)", ImmutableMap.of("s", "\u0007bell\u000Bvtab\bback\ffeed\rret\nline\ttab\\slash 가 😁")); assertThat(evaluatedResult) @@ -1541,25 +1450,19 @@ public void quote_escapesSpecialCharacters() throws Exception { @Test public void quote_escapesMalformed_endWithHighSurrogate() throws Exception { - CelRuntime.Program program = - RUNTIME.createProgram(COMPILER.compile("strings.quote(s)").getAst()); - assertThat(program.eval(ImmutableMap.of("s", "end with high surrogate \uD83D"))) + assertThat(eval("strings.quote(s)", ImmutableMap.of("s", "end with high surrogate \uD83D"))) .isEqualTo("\"end with high surrogate \uFFFD\""); } @Test public void quote_escapesMalformed_unpairedHighSurrogate() throws Exception { - CelRuntime.Program program = - RUNTIME.createProgram(COMPILER.compile("strings.quote(s)").getAst()); - assertThat(program.eval(ImmutableMap.of("s", "bad pair \uD83DA"))) + assertThat(eval("strings.quote(s)", ImmutableMap.of("s", "bad pair \uD83DA"))) .isEqualTo("\"bad pair \uFFFDA\""); } @Test public void quote_escapesMalformed_unpairedLowSurrogate() throws Exception { - CelRuntime.Program program = - RUNTIME.createProgram(COMPILER.compile("strings.quote(s)").getAst()); - assertThat(program.eval(ImmutableMap.of("s", "bad pair \uDC00A"))) + assertThat(eval("strings.quote(s)", ImmutableMap.of("s", "bad pair \uDC00A"))) .isEqualTo("\"bad pair \uFFFDA\""); } @@ -1570,23 +1473,47 @@ public void stringExtension_compileUnallowedFunction_throws() { .addLibraries(CelExtensions.strings(Function.REPLACE)) .build(); - assertThrows( - CelValidationException.class, - () -> celCompiler.compile("'test'.substring(2) == 'st'").getAst()); + // This is a type-check failure. + Assume.assumeFalse(isParseOnly); + CelValidationResult result = celCompiler.compile("'test'.substring(2) == 'st'"); + assertThrows(CelValidationException.class, () -> result.getAst()); } @Test public void stringExtension_evaluateUnallowedFunction_throws() throws Exception { - CelCompiler celCompiler = - CelCompilerFactory.standardCelCompilerBuilder() - .addLibraries(CelExtensions.strings(Function.SUBSTRING)) + Cel customCompilerCel = + runtimeFlavor + .builder() + .addCompilerLibraries(CelExtensions.strings(Function.SUBSTRING)) .build(); - CelRuntime celRuntime = - CelRuntimeFactory.standardCelRuntimeBuilder() - .addLibraries(CelExtensions.strings(Function.REPLACE)) + Cel customRuntimeCel = + runtimeFlavor + .builder() + .addRuntimeLibraries(CelExtensions.strings(Function.REPLACE)) .build(); - CelAbstractSyntaxTree ast = celCompiler.compile("'test'.substring(2) == 'st'").getAst(); + CelAbstractSyntaxTree ast = + isParseOnly + ? customCompilerCel.parse("'test'.substring(2) == 'st'").getAst() + : customCompilerCel.compile("'test'.substring(2) == 'st'").getAst(); + + assertThrows(CelEvaluationException.class, () -> customRuntimeCel.createProgram(ast).eval()); + } + + private Object eval(Cel cel, String expression, Map variables) throws Exception { + CelAbstractSyntaxTree ast; + if (isParseOnly) { + ast = cel.parse(expression).getAst(); + } else { + ast = cel.compile(expression).getAst(); + } + return cel.createProgram(ast).eval(variables); + } + + private Object eval(String expression) throws Exception { + return eval(this.cel, expression, ImmutableMap.of()); + } - assertThrows(CelEvaluationException.class, () -> celRuntime.createProgram(ast).eval()); + private Object eval(String expression, Map variables) throws Exception { + return eval(this.cel, expression, variables); } } From b029be3cceeeb0dc957d55f8eb0518451b92b75c Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Mon, 20 Apr 2026 14:46:48 -0700 Subject: [PATCH 40/66] Support parsed-only evaluation to encoders extension PiperOrigin-RevId: 902832943 --- .../cel/extensions/CelEncoderExtensions.java | 8 +- .../extensions/CelEncoderExtensionsTest.java | 95 +++++++++---------- 2 files changed, 51 insertions(+), 52 deletions(-) diff --git a/extensions/src/main/java/dev/cel/extensions/CelEncoderExtensions.java b/extensions/src/main/java/dev/cel/extensions/CelEncoderExtensions.java index a98f9db41..498b8555e 100644 --- a/extensions/src/main/java/dev/cel/extensions/CelEncoderExtensions.java +++ b/extensions/src/main/java/dev/cel/extensions/CelEncoderExtensions.java @@ -135,9 +135,13 @@ public void setRuntimeOptions(CelRuntimeBuilder runtimeBuilder) { functions.forEach( function -> { if (celOptions.evaluateCanonicalTypesToNativeValues()) { - runtimeBuilder.addFunctionBindings(function.nativeBytesFunctionBinding); + runtimeBuilder.addFunctionBindings( + CelFunctionBinding.fromOverloads( + function.getFunction(), function.nativeBytesFunctionBinding)); } else { - runtimeBuilder.addFunctionBindings(function.protoBytesFunctionBinding); + runtimeBuilder.addFunctionBindings( + CelFunctionBinding.fromOverloads( + function.getFunction(), function.protoBytesFunctionBinding)); } }); } diff --git a/extensions/src/test/java/dev/cel/extensions/CelEncoderExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelEncoderExtensionsTest.java index 7eed3dd5a..b0a501ddb 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelEncoderExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelEncoderExtensionsTest.java @@ -19,36 +19,45 @@ import static org.junit.Assert.assertThrows; import com.google.common.collect.ImmutableMap; +import com.google.testing.junit.testparameterinjector.TestParameter; import com.google.testing.junit.testparameterinjector.TestParameterInjector; +import dev.cel.bundle.Cel; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelOptions; import dev.cel.common.CelValidationException; import dev.cel.common.types.SimpleType; import dev.cel.common.values.CelByteString; -import dev.cel.compiler.CelCompiler; -import dev.cel.compiler.CelCompilerFactory; import dev.cel.runtime.CelEvaluationException; -import dev.cel.runtime.CelRuntime; -import dev.cel.runtime.CelRuntimeFactory; +import dev.cel.testing.CelRuntimeFlavor; +import org.junit.Assume; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @RunWith(TestParameterInjector.class) public class CelEncoderExtensionsTest { private static final CelOptions CEL_OPTIONS = - CelOptions.current().build(); - - private static final CelCompiler CEL_COMPILER = - CelCompilerFactory.standardCelCompilerBuilder() - .addVar("stringVar", SimpleType.STRING) - .addLibraries(CelExtensions.encoders(CEL_OPTIONS)) - .build(); - private static final CelRuntime CEL_RUNTIME = - CelRuntimeFactory.standardCelRuntimeBuilder() - .setOptions(CEL_OPTIONS) - .addLibraries(CelExtensions.encoders(CEL_OPTIONS)) - .build(); + CelOptions.current().enableHeterogeneousNumericComparisons(true).build(); + + @TestParameter public CelRuntimeFlavor runtimeFlavor; + @TestParameter public boolean isParseOnly; + + private Cel cel; + + @Before + public void setUp() { + // Legacy runtime does not support parsed-only evaluation mode. + Assume.assumeFalse(runtimeFlavor.equals(CelRuntimeFlavor.LEGACY) && isParseOnly); + this.cel = + runtimeFlavor + .builder() + .setOptions(CEL_OPTIONS) + .addCompilerLibraries(CelExtensions.encoders(CEL_OPTIONS)) + .addRuntimeLibraries(CelExtensions.encoders(CEL_OPTIONS)) + .addVar("stringVar", SimpleType.STRING) + .build(); + } @Test public void library() { @@ -63,22 +72,14 @@ public void library() { @Test public void encode_success() throws Exception { - String encodedBytes = - (String) - CEL_RUNTIME - .createProgram(CEL_COMPILER.compile("base64.encode(b'hello')").getAst()) - .eval(); + String encodedBytes = (String) eval("base64.encode(b'hello')"); assertThat(encodedBytes).isEqualTo("aGVsbG8="); } @Test public void decode_success() throws Exception { - CelByteString decodedBytes = - (CelByteString) - CEL_RUNTIME - .createProgram(CEL_COMPILER.compile("base64.decode('aGVsbG8=')").getAst()) - .eval(); + CelByteString decodedBytes = (CelByteString) eval("base64.decode('aGVsbG8=')"); assertThat(decodedBytes.size()).isEqualTo(5); assertThat(new String(decodedBytes.toByteArray(), ISO_8859_1)).isEqualTo("hello"); @@ -86,12 +87,7 @@ public void decode_success() throws Exception { @Test public void decode_withoutPadding_success() throws Exception { - CelByteString decodedBytes = - (CelByteString) - CEL_RUNTIME - // RFC2045 6.8, padding can be ignored. - .createProgram(CEL_COMPILER.compile("base64.decode('aGVsbG8')").getAst()) - .eval(); + CelByteString decodedBytes = (CelByteString) eval("base64.decode('aGVsbG8')"); assertThat(decodedBytes.size()).isEqualTo(5); assertThat(new String(decodedBytes.toByteArray(), ISO_8859_1)).isEqualTo("hello"); @@ -99,50 +95,49 @@ public void decode_withoutPadding_success() throws Exception { @Test public void roundTrip_success() throws Exception { - String encodedString = - (String) - CEL_RUNTIME - .createProgram(CEL_COMPILER.compile("base64.encode(b'Hello World!')").getAst()) - .eval(); + String encodedString = (String) eval("base64.encode(b'Hello World!')"); CelByteString decodedBytes = (CelByteString) - CEL_RUNTIME - .createProgram(CEL_COMPILER.compile("base64.decode(stringVar)").getAst()) - .eval(ImmutableMap.of("stringVar", encodedString)); + eval("base64.decode(stringVar)", ImmutableMap.of("stringVar", encodedString)); assertThat(new String(decodedBytes.toByteArray(), ISO_8859_1)).isEqualTo("Hello World!"); } @Test public void encode_invalidParam_throwsCompilationException() { + Assume.assumeFalse(isParseOnly); CelValidationException e = assertThrows( - CelValidationException.class, - () -> CEL_COMPILER.compile("base64.encode('hello')").getAst()); + CelValidationException.class, () -> cel.compile("base64.encode('hello')").getAst()); assertThat(e).hasMessageThat().contains("found no matching overload for 'base64.encode'"); } @Test public void decode_invalidParam_throwsCompilationException() { + Assume.assumeFalse(isParseOnly); CelValidationException e = assertThrows( - CelValidationException.class, - () -> CEL_COMPILER.compile("base64.decode(b'aGVsbG8=')").getAst()); + CelValidationException.class, () -> cel.compile("base64.decode(b'aGVsbG8=')").getAst()); assertThat(e).hasMessageThat().contains("found no matching overload for 'base64.decode'"); } @Test public void decode_malformedBase64Char_throwsEvaluationException() throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile("base64.decode('z!')").getAst(); - CelEvaluationException e = - assertThrows(CelEvaluationException.class, () -> CEL_RUNTIME.createProgram(ast).eval()); + assertThrows(CelEvaluationException.class, () -> eval("base64.decode('z!')")); - assertThat(e) - .hasMessageThat() - .contains("Function 'base64_decode_string' failed with arg(s) 'z!'"); + assertThat(e).hasMessageThat().contains("failed with arg(s) 'z!'"); assertThat(e).hasCauseThat().hasMessageThat().contains("Illegal base64 character"); } + + private Object eval(String expr) throws Exception { + return eval(expr, ImmutableMap.of()); + } + + private Object eval(String expr, ImmutableMap vars) throws Exception { + CelAbstractSyntaxTree ast = isParseOnly ? cel.parse(expr).getAst() : cel.compile(expr).getAst(); + return cel.createProgram(ast).eval(vars); + } } From 5470c95da6fda1e418e485fd7f048e0e0573d8bf Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Mon, 20 Apr 2026 15:02:26 -0700 Subject: [PATCH 41/66] Restructure Java test runner to not require JVM classinfo inspection for loading descriptors PiperOrigin-RevId: 902839390 --- common/internal/BUILD.bazel | 5 -- .../java/dev/cel/common/internal/BUILD.bazel | 13 ----- .../DefaultInstanceMessageFactory.java | 5 +- .../internal/ProtoJavaQualifiedNames.java | 52 ------------------- .../main/java/dev/cel/protobuf/BUILD.bazel | 2 - .../protobuf/CelLiteDescriptorGenerator.java | 4 +- .../protobuf/ProtoDescriptorCollector.java | 5 +- testing/BUILD.bazel | 5 ++ .../dev/cel/testing/testrunner/BUILD.bazel | 13 +++-- .../testing/testrunner/CelTestContext.java | 26 ++++++---- .../CelTestSuiteTextProtoParser.java | 8 ++- .../testrunner/DefaultResultMatcher.java | 16 +++++- .../cel/testing/testrunner/RegistryUtils.java | 27 +++------- .../cel/testing/testrunner/TestExecutor.java | 2 + .../testing/testrunner/TestRunnerLibrary.java | 31 ++++------- .../java/dev/cel/testing/utils/BUILD.bazel | 9 +--- .../cel/testing/utils/ClassLoaderUtils.java | 32 ------------ .../dev/cel/testing/utils/ExprValueUtils.java | 39 ++++---------- .../testing/utils/ProtoDescriptorUtils.java | 34 +++--------- .../dev/cel/testing/testrunner/BUILD.bazel | 5 +- .../CustomVariableBindingUserTest.java | 32 ++++++++---- 21 files changed, 115 insertions(+), 250 deletions(-) delete mode 100644 common/src/main/java/dev/cel/common/internal/ProtoJavaQualifiedNames.java diff --git a/common/internal/BUILD.bazel b/common/internal/BUILD.bazel index 0a07e0d63..781566713 100644 --- a/common/internal/BUILD.bazel +++ b/common/internal/BUILD.bazel @@ -128,11 +128,6 @@ cel_android_library( exports = ["//common/src/main/java/dev/cel/common/internal:internal_android"], ) -java_library( - name = "proto_java_qualified_names", - exports = ["//common/src/main/java/dev/cel/common/internal:proto_java_qualified_names"], -) - java_library( name = "proto_time_utils", exports = ["//common/src/main/java/dev/cel/common/internal:proto_time_utils"], diff --git a/common/src/main/java/dev/cel/common/internal/BUILD.bazel b/common/src/main/java/dev/cel/common/internal/BUILD.bazel index 912b4de4b..6b470d98c 100644 --- a/common/src/main/java/dev/cel/common/internal/BUILD.bazel +++ b/common/src/main/java/dev/cel/common/internal/BUILD.bazel @@ -153,7 +153,6 @@ java_library( tags = [ ], deps = [ - ":proto_java_qualified_names", ":reflection_util", "//common/annotations", "@maven//:com_google_guava_guava", @@ -396,18 +395,6 @@ java_library( ], ) -java_library( - name = "proto_java_qualified_names", - srcs = ["ProtoJavaQualifiedNames.java"], - tags = [ - ], - deps = [ - "//common/annotations", - "@maven//:com_google_guava_guava", - "@maven//:com_google_protobuf_protobuf_java", - ], -) - java_library( name = "reflection_util", srcs = ["ReflectionUtil.java"], diff --git a/common/src/main/java/dev/cel/common/internal/DefaultInstanceMessageFactory.java b/common/src/main/java/dev/cel/common/internal/DefaultInstanceMessageFactory.java index fcb0e7056..163d0273e 100644 --- a/common/src/main/java/dev/cel/common/internal/DefaultInstanceMessageFactory.java +++ b/common/src/main/java/dev/cel/common/internal/DefaultInstanceMessageFactory.java @@ -15,6 +15,7 @@ package dev.cel.common.internal; import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.GeneratorNames; import com.google.protobuf.Message; import com.google.protobuf.MessageLite; import dev.cel.common.annotations.Internal; @@ -45,9 +46,7 @@ public static DefaultInstanceMessageFactory getInstance() { public Optional getPrototype(Descriptor descriptor) { MessageLite defaultInstance = DefaultInstanceMessageLiteFactory.getInstance() - .getPrototype( - descriptor.getFullName(), - ProtoJavaQualifiedNames.getFullyQualifiedJavaClassName(descriptor)) + .getPrototype(descriptor.getFullName(), GeneratorNames.getBytecodeClassName(descriptor)) .orElse(null); if (defaultInstance == null) { return Optional.empty(); diff --git a/common/src/main/java/dev/cel/common/internal/ProtoJavaQualifiedNames.java b/common/src/main/java/dev/cel/common/internal/ProtoJavaQualifiedNames.java deleted file mode 100644 index f27181a50..000000000 --- a/common/src/main/java/dev/cel/common/internal/ProtoJavaQualifiedNames.java +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2025 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package dev.cel.common.internal; - -import com.google.protobuf.Descriptors.Descriptor; -import com.google.protobuf.Descriptors.FileDescriptor; -import com.google.protobuf.GeneratorNames; -import dev.cel.common.annotations.Internal; - -/** - * Helper class for constructing a fully qualified Java class name from a protobuf descriptor. - * - *

CEL Library Internals. Do Not Use. - */ -@Internal -public final class ProtoJavaQualifiedNames { - /** - * Retrieves the full Java class name from the given descriptor - * - * @return fully qualified class name. - *

Example 1: dev.cel.expr.Value - *

Example 2: com.google.rpc.context.AttributeContext$Resource (Nested classes) - *

Example 3: com.google.api.expr.cel.internal.testdata$SingleFileProto$SingleFile$Path - * (Nested class with java multiple files disabled) - */ - public static String getFullyQualifiedJavaClassName(Descriptor descriptor) { - return GeneratorNames.getBytecodeClassName(descriptor); - } - - /** - * Gets the java package name from the descriptor. See - * https://developers.google.com/protocol-buffers/docs/reference/java-generated#package for rules - * on package name generation - */ - public static String getJavaPackageName(FileDescriptor fileDescriptor) { - return GeneratorNames.getFileJavaPackage(fileDescriptor.toProto()); - } - - private ProtoJavaQualifiedNames() {} -} diff --git a/protobuf/src/main/java/dev/cel/protobuf/BUILD.bazel b/protobuf/src/main/java/dev/cel/protobuf/BUILD.bazel index 6e7b473eb..b2dac98e7 100644 --- a/protobuf/src/main/java/dev/cel/protobuf/BUILD.bazel +++ b/protobuf/src/main/java/dev/cel/protobuf/BUILD.bazel @@ -21,7 +21,6 @@ java_binary( ":java_file_generator", ":proto_descriptor_collector", "//common:cel_descriptor_util", - "//common/internal:proto_java_qualified_names", "@maven//:com_google_guava_guava", "@maven//:com_google_protobuf_protobuf_java", "@maven//:info_picocli_picocli", @@ -50,7 +49,6 @@ java_library( ":cel_lite_descriptor", ":debug_printer", ":lite_descriptor_codegen_metadata", - "//common/internal:proto_java_qualified_names", "//common/internal:well_known_proto", "@maven//:com_google_guava_guava", "@maven//:com_google_protobuf_protobuf_java", diff --git a/protobuf/src/main/java/dev/cel/protobuf/CelLiteDescriptorGenerator.java b/protobuf/src/main/java/dev/cel/protobuf/CelLiteDescriptorGenerator.java index 276dd7f91..8c4eaea1c 100644 --- a/protobuf/src/main/java/dev/cel/protobuf/CelLiteDescriptorGenerator.java +++ b/protobuf/src/main/java/dev/cel/protobuf/CelLiteDescriptorGenerator.java @@ -23,8 +23,8 @@ import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.Descriptors.FileDescriptor; import com.google.protobuf.ExtensionRegistry; +import com.google.protobuf.GeneratorNames; import dev.cel.common.CelDescriptorUtil; -import dev.cel.common.internal.ProtoJavaQualifiedNames; import dev.cel.protobuf.JavaFileGenerator.GeneratedClass; import dev.cel.protobuf.JavaFileGenerator.JavaFileGeneratorOption; import java.io.File; @@ -117,7 +117,7 @@ public Integer call() throws Exception { private ImmutableList codegenCelLiteDescriptors( FileDescriptor targetFileDescriptor) throws Exception { - String javaPackageName = ProtoJavaQualifiedNames.getJavaPackageName(targetFileDescriptor); + String javaPackageName = GeneratorNames.getFileJavaPackage(targetFileDescriptor.toProto()); String javaClassName; List descriptors = targetFileDescriptor.getMessageTypes(); diff --git a/protobuf/src/main/java/dev/cel/protobuf/ProtoDescriptorCollector.java b/protobuf/src/main/java/dev/cel/protobuf/ProtoDescriptorCollector.java index 0031fe6a6..c2fe20557 100644 --- a/protobuf/src/main/java/dev/cel/protobuf/ProtoDescriptorCollector.java +++ b/protobuf/src/main/java/dev/cel/protobuf/ProtoDescriptorCollector.java @@ -22,7 +22,7 @@ import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.Descriptors.FieldDescriptor.JavaType; import com.google.protobuf.Descriptors.FileDescriptor; -import dev.cel.common.internal.ProtoJavaQualifiedNames; +import com.google.protobuf.GeneratorNames; import dev.cel.common.internal.WellKnownProto; import dev.cel.protobuf.CelLiteDescriptor.FieldLiteDescriptor; import dev.cel.protobuf.CelLiteDescriptor.FieldLiteDescriptor.EncodingType; @@ -93,8 +93,7 @@ ImmutableList collectCodegenMetadata(Descriptor d // Maps are resolved as an actual Java map, and doesn't have a MessageLite.Builder associated. if (!messageDescriptor.getOptions().getMapEntry()) { String sanitizedJavaClassName = - ProtoJavaQualifiedNames.getFullyQualifiedJavaClassName(messageDescriptor) - .replace('$', '.'); + GeneratorNames.getBytecodeClassName(messageDescriptor).replace('$', '.'); descriptorCodegenBuilder.setJavaClassName(sanitizedJavaClassName); } diff --git a/testing/BUILD.bazel b/testing/BUILD.bazel index b9e68f003..cc389fed1 100644 --- a/testing/BUILD.bazel +++ b/testing/BUILD.bazel @@ -45,3 +45,8 @@ java_library( name = "expr_value_utils", exports = ["//testing/src/main/java/dev/cel/testing/utils:expr_value_utils"], ) + +java_library( + name = "proto_descriptor_utils", + exports = ["//testing/src/main/java/dev/cel/testing/utils:proto_descriptor_utils"], +) diff --git a/testing/src/main/java/dev/cel/testing/testrunner/BUILD.bazel b/testing/src/main/java/dev/cel/testing/testrunner/BUILD.bazel index 5af0665f9..d0fed9bea 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/BUILD.bazel +++ b/testing/src/main/java/dev/cel/testing/testrunner/BUILD.bazel @@ -93,10 +93,10 @@ java_library( "//bundle:environment_yaml_parser", "//common:cel_ast", "//common:cel_descriptor_util", + "//common:cel_descriptors", "//common:compiler_common", "//common:options", "//common:proto_ast", - "//common/internal:default_instance_message_factory", "//policy", "//policy:compiler_factory", "//policy:parser", @@ -104,7 +104,6 @@ java_library( "//policy:validation_exception", "//runtime", "//testing:expr_value_utils", - "//testing/testrunner:proto_descriptor_utils", "@cel_spec//proto/cel/expr:expr_java_proto", "@maven//:com_google_guava_guava", "@maven//:com_google_protobuf_protobuf_java", @@ -166,10 +165,11 @@ java_library( "//:auto_value", "//bundle:cel", "//common:cel_descriptor_util", + "//common:cel_descriptors", "//common:options", "//policy:parser", "//runtime", - "//testing/testrunner:proto_descriptor_utils", + "//testing:proto_descriptor_utils", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", "@maven//:com_google_protobuf_protobuf_java", @@ -182,8 +182,7 @@ java_library( tags = [ ], deps = [ - "//common/internal:default_instance_message_factory", - "//testing/testrunner:proto_descriptor_utils", + "//common:cel_descriptors", "@maven//:com_google_protobuf_protobuf_java", ], ) @@ -212,8 +211,10 @@ java_library( "//:java_truth", "//bundle:cel", "//common:cel_ast", + "//common:cel_descriptors", "//runtime", "//testing:expr_value_utils", + "//testing:proto_descriptor_utils", "@cel_spec//proto/cel/expr:expr_java_proto", "@maven//:com_google_protobuf_protobuf_java", "@maven//:com_google_truth_extensions_truth_proto_extension", @@ -229,7 +230,9 @@ java_library( ":cel_test_suite", ":cel_test_suite_exception", ":registry_utils", + "//common:cel_descriptors", "//common/annotations", + "//testing:proto_descriptor_utils", "@cel_spec//proto/cel/expr:expr_java_proto", "@cel_spec//proto/cel/expr/conformance/test:suite_java_proto", "@maven//:com_google_guava_guava", diff --git a/testing/src/main/java/dev/cel/testing/testrunner/CelTestContext.java b/testing/src/main/java/dev/cel/testing/testrunner/CelTestContext.java index 5635b6152..1be0bab25 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/CelTestContext.java +++ b/testing/src/main/java/dev/cel/testing/testrunner/CelTestContext.java @@ -25,6 +25,7 @@ import dev.cel.bundle.Cel; import dev.cel.bundle.CelFactory; import dev.cel.common.CelDescriptorUtil; +import dev.cel.common.CelDescriptors; import dev.cel.common.CelOptions; import dev.cel.policy.CelPolicyParser; import dev.cel.runtime.CelLateFunctionBindings; @@ -125,6 +126,20 @@ public interface BindingTransformer { abstract ImmutableSet fileTypes(); + @Memoized + public Optional celDescriptors() { + if (fileDescriptorSetPath().isPresent()) { + try { + return Optional.of( + ProtoDescriptorUtils.getDescriptorsFromFile(fileDescriptorSetPath().get())); + } catch (IOException e) { + throw new IllegalStateException( + "Failed to load descriptors from path: " + fileDescriptorSetPath().get(), e); + } + } + return Optional.empty(); + } + @Memoized public Optional typeRegistry() { if (fileTypes().isEmpty() && !fileDescriptorSetPath().isPresent()) { @@ -136,15 +151,8 @@ public Optional typeRegistry() { CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(fileTypes()) .messageTypeDescriptors()); } - if (fileDescriptorSetPath().isPresent()) { - try { - builder.add( - ProtoDescriptorUtils.getAllDescriptorsFromJvm(fileDescriptorSetPath().get()) - .messageTypeDescriptors()); - } catch (IOException e) { - throw new IllegalStateException( - "Failed to load descriptors from path: " + fileDescriptorSetPath().get(), e); - } + if (celDescriptors().isPresent()) { + builder.add(celDescriptors().get().messageTypeDescriptors()); } return Optional.of(builder.build()); } diff --git a/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuiteTextProtoParser.java b/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuiteTextProtoParser.java index 5e7e62498..9c0ab4720 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuiteTextProtoParser.java +++ b/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuiteTextProtoParser.java @@ -22,6 +22,7 @@ import com.google.protobuf.TextFormat; import com.google.protobuf.TextFormat.ParseException; import com.google.protobuf.TypeRegistry; +import dev.cel.common.CelDescriptors; import dev.cel.common.annotations.Internal; import dev.cel.expr.conformance.test.InputValue; import dev.cel.expr.conformance.test.TestCase; @@ -30,6 +31,7 @@ import dev.cel.testing.testrunner.CelTestSuite.CelTestSection; import dev.cel.testing.testrunner.CelTestSuite.CelTestSection.CelTestCase; import dev.cel.testing.testrunner.CelTestSuite.CelTestSection.CelTestCase.Input.Binding; +import dev.cel.testing.utils.ProtoDescriptorUtils; import java.io.IOException; import java.util.Map; @@ -71,8 +73,10 @@ private TestSuite parseTestSuite( TypeRegistry typeRegistry = customTypeRegistry; ExtensionRegistry extensionRegistry = customExtensionRegistry; if (fileDescriptorSetPath != null) { - extensionRegistry = RegistryUtils.getExtensionRegistry(fileDescriptorSetPath); - typeRegistry = RegistryUtils.getTypeRegistry(fileDescriptorSetPath); + CelDescriptors descriptors = + ProtoDescriptorUtils.getDescriptorsFromFile(fileDescriptorSetPath); + extensionRegistry = RegistryUtils.getExtensionRegistry(descriptors); + typeRegistry = RegistryUtils.getTypeRegistry(descriptors); } TextFormat.Parser parser = TextFormat.Parser.newBuilder().setTypeRegistry(typeRegistry).build(); TestSuite.Builder builder = TestSuite.newBuilder(); diff --git a/testing/src/main/java/dev/cel/testing/testrunner/DefaultResultMatcher.java b/testing/src/main/java/dev/cel/testing/testrunner/DefaultResultMatcher.java index 2d33253af..279d591a2 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/DefaultResultMatcher.java +++ b/testing/src/main/java/dev/cel/testing/testrunner/DefaultResultMatcher.java @@ -22,11 +22,13 @@ import dev.cel.expr.MapValue; import dev.cel.bundle.Cel; import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.CelDescriptors; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.CelRuntime.Program; import dev.cel.testing.testrunner.CelTestSuite.CelTestSection.CelTestCase.Output; import dev.cel.testing.testrunner.ResultMatcher.ResultMatcherParams; import dev.cel.testing.testrunner.ResultMatcher.ResultMatcherParams.ComputedOutput; +import dev.cel.testing.utils.ProtoDescriptorUtils; import java.io.IOException; final class DefaultResultMatcher implements ResultMatcher { @@ -41,6 +43,10 @@ public void match(ResultMatcherParams params, Cel cel) throws Exception { "Error: " + params.computedOutput().error().getMessage(), params.computedOutput().error()); } + if (params.computedOutput().kind().equals(ComputedOutput.Kind.UNKNOWN_SET)) { + throw new AssertionError( + "Expected value but got UnknownSet: " + params.computedOutput().unknownSet()); + } CelAbstractSyntaxTree exprAst = cel.compile(result.resultExpr()).getAst(); Program exprProgram = cel.createProgram(exprAst); Object evaluationResult = null; @@ -59,6 +65,10 @@ public void match(ResultMatcherParams params, Cel cel) throws Exception { "Error: " + params.computedOutput().error().getMessage(), params.computedOutput().error()); } + if (params.computedOutput().kind().equals(ComputedOutput.Kind.UNKNOWN_SET)) { + throw new AssertionError( + "Expected value but got UnknownSet: " + params.computedOutput().unknownSet()); + } assertExprValue( params.computedOutput().exprValue(), toExprValue(result.resultValue(), params.resultType())); @@ -85,12 +95,14 @@ private static void assertExprValue(ExprValue exprValue, ExprValue expectedExprV throws IOException { String fileDescriptorSetPath = System.getProperty("file_descriptor_set_path"); if (fileDescriptorSetPath != null) { + CelDescriptors descriptors = + ProtoDescriptorUtils.getDescriptorsFromFile(fileDescriptorSetPath); assertThat(exprValue) .ignoringRepeatedFieldOrderOfFieldDescriptors( MapValue.getDescriptor().findFieldByName("entries")) .unpackingAnyUsing( - RegistryUtils.getTypeRegistry(fileDescriptorSetPath), - RegistryUtils.getExtensionRegistry(fileDescriptorSetPath)) + RegistryUtils.getTypeRegistry(descriptors), + RegistryUtils.getExtensionRegistry(descriptors)) .isEqualTo(expectedExprValue); } else { assertThat(exprValue) diff --git a/testing/src/main/java/dev/cel/testing/testrunner/RegistryUtils.java b/testing/src/main/java/dev/cel/testing/testrunner/RegistryUtils.java index b2f195606..a10904abb 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/RegistryUtils.java +++ b/testing/src/main/java/dev/cel/testing/testrunner/RegistryUtils.java @@ -13,44 +13,33 @@ // limitations under the License. package dev.cel.testing.testrunner; -import static dev.cel.testing.utils.ProtoDescriptorUtils.getAllDescriptorsFromJvm; + import com.google.protobuf.Descriptors.FieldDescriptor; +import com.google.protobuf.DynamicMessage; import com.google.protobuf.ExtensionRegistry; import com.google.protobuf.Message; import com.google.protobuf.TypeRegistry; -import dev.cel.common.internal.DefaultInstanceMessageFactory; -import java.io.IOException; -import java.util.NoSuchElementException; +import dev.cel.common.CelDescriptors; /** Utility class for creating registries from a file descriptor set. */ public final class RegistryUtils { /** Returns the {@link TypeRegistry} for the given file descriptor set. */ - public static TypeRegistry getTypeRegistry(String fileDescriptorSetPath) throws IOException { - return TypeRegistry.newBuilder() - .add(getAllDescriptorsFromJvm(fileDescriptorSetPath).messageTypeDescriptors()) - .build(); + public static TypeRegistry getTypeRegistry(CelDescriptors descriptors) { + return TypeRegistry.newBuilder().add(descriptors.messageTypeDescriptors()).build(); } /** Returns the {@link ExtensionRegistry} for the given file descriptor set. */ - public static ExtensionRegistry getExtensionRegistry(String fileDescriptorSetPath) - throws IOException { + public static ExtensionRegistry getExtensionRegistry(CelDescriptors descriptors) { ExtensionRegistry extensionRegistry = ExtensionRegistry.newInstance(); - getAllDescriptorsFromJvm(fileDescriptorSetPath) + descriptors .extensionDescriptors() .forEach( (descriptorName, descriptor) -> { if (descriptor.getType().equals(FieldDescriptor.Type.MESSAGE)) { - Message output = - DefaultInstanceMessageFactory.getInstance() - .getPrototype(descriptor.getMessageType()) - .orElseThrow( - () -> - new NoSuchElementException( - "Could not find a default message for: " - + descriptor.getFullName())); + Message output = DynamicMessage.getDefaultInstance(descriptor.getMessageType()); extensionRegistry.add(descriptor, output); } else { extensionRegistry.add(descriptor); diff --git a/testing/src/main/java/dev/cel/testing/testrunner/TestExecutor.java b/testing/src/main/java/dev/cel/testing/testrunner/TestExecutor.java index 6f6dff3c1..181d99c6f 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/TestExecutor.java +++ b/testing/src/main/java/dev/cel/testing/testrunner/TestExecutor.java @@ -236,6 +236,8 @@ public String describe() { testResult.setStatus(JUnitXmlReporter.TestResult.FAILURE); testResult.setThrowable(result.getFailures().get(0).getException()); testReporter.onTestFailure(testResult); + System.err.println("Test failed: " + testName); + result.getFailures().forEach(failure -> failure.getException().printStackTrace()); } } } diff --git a/testing/src/main/java/dev/cel/testing/testrunner/TestRunnerLibrary.java b/testing/src/main/java/dev/cel/testing/testrunner/TestRunnerLibrary.java index 2465d330e..69c365972 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/TestRunnerLibrary.java +++ b/testing/src/main/java/dev/cel/testing/testrunner/TestRunnerLibrary.java @@ -28,6 +28,7 @@ import com.google.common.collect.ImmutableMap; import com.google.protobuf.Any; import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.DynamicMessage; import com.google.protobuf.ExtensionRegistry; import com.google.protobuf.Message; import com.google.protobuf.TextFormat; @@ -38,10 +39,10 @@ import dev.cel.bundle.CelEnvironmentYamlParser; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelDescriptorUtil; +import dev.cel.common.CelDescriptors; import dev.cel.common.CelOptions; import dev.cel.common.CelProtoAbstractSyntaxTree; import dev.cel.common.CelValidationException; -import dev.cel.common.internal.DefaultInstanceMessageFactory; import dev.cel.policy.CelPolicy; import dev.cel.policy.CelPolicyCompilerFactory; import dev.cel.policy.CelPolicyParser; @@ -52,12 +53,10 @@ import dev.cel.testing.testrunner.CelTestSuite.CelTestSection.CelTestCase; import dev.cel.testing.testrunner.CelTestSuite.CelTestSection.CelTestCase.Input.Binding; import dev.cel.testing.testrunner.ResultMatcher.ResultMatcherParams; -import dev.cel.testing.utils.ProtoDescriptorUtils; import java.io.File; import java.io.IOException; import java.nio.file.Paths; import java.util.Map; -import java.util.NoSuchElementException; import java.util.Optional; import java.util.logging.Logger; import org.jspecify.annotations.Nullable; @@ -201,16 +200,13 @@ private static Cel extendCel(CelTestContext celTestContext, CelOptions celOption // // Note: This needs to be added first because the config file may contain type information // regarding proto messages that need to be added to the cel object. - if (celTestContext.fileDescriptorSetPath().isPresent()) { + CelDescriptors descriptors = celTestContext.celDescriptors().orElse(null); + if (descriptors != null) { extendedCel = extendedCel .toCelBuilder() - .addMessageTypes( - ProtoDescriptorUtils.getAllDescriptorsFromJvm( - celTestContext.fileDescriptorSetPath().get()) - .messageTypeDescriptors()) - .setExtensionRegistry( - RegistryUtils.getExtensionRegistry(celTestContext.fileDescriptorSetPath().get())) + .addMessageTypes(descriptors.messageTypeDescriptors()) + .setExtensionRegistry(RegistryUtils.getExtensionRegistry(descriptors)) .build(); } @@ -369,22 +365,13 @@ private static Message unpackAny(Any any, CelTestContext celTestContext) throws "Proto descriptors are required for unpacking Any messages."); } Descriptor descriptor = - RegistryUtils.getTypeRegistry(celTestContext.fileDescriptorSetPath().get()) + RegistryUtils.getTypeRegistry(celTestContext.celDescriptors().get()) .getDescriptorForTypeUrl(any.getTypeUrl()); - return getDefaultInstance(descriptor) + return DynamicMessage.getDefaultInstance(descriptor) .getParserForType() .parseFrom( any.getValue(), - RegistryUtils.getExtensionRegistry(celTestContext.fileDescriptorSetPath().get())); - } - - private static Message getDefaultInstance(Descriptor descriptor) throws IOException { - return DefaultInstanceMessageFactory.getInstance() - .getPrototype(descriptor) - .orElseThrow( - () -> - new NoSuchElementException( - "Could not find a default message for: " + descriptor.getFullName())); + RegistryUtils.getExtensionRegistry(celTestContext.celDescriptors().get())); } private static Message getEvaluatedContextExpr( diff --git a/testing/src/main/java/dev/cel/testing/utils/BUILD.bazel b/testing/src/main/java/dev/cel/testing/utils/BUILD.bazel index 2947709e5..eea56752d 100644 --- a/testing/src/main/java/dev/cel/testing/utils/BUILD.bazel +++ b/testing/src/main/java/dev/cel/testing/utils/BUILD.bazel @@ -15,19 +15,16 @@ java_library( tags = [ ], deps = [ - "//common:cel_descriptor_util", "//common:cel_descriptors", - "//common/internal:default_instance_message_factory", "//common/internal:proto_time_utils", "//common/types", "//common/types:type_providers", "//common/values", "//common/values:cel_byte_string", "//runtime:unknown_attributes", + "//testing:proto_descriptor_utils", "//testing/testrunner:registry_utils", "@cel_spec//proto/cel/expr:expr_java_proto", - "@cel_spec//proto/cel/expr/conformance/proto2:test_all_types_java_proto", - "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto", "@maven//:com_google_guava_guava", "@maven//:com_google_protobuf_protobuf_java", "@maven_android//:com_google_protobuf_protobuf_javalite", @@ -41,7 +38,6 @@ java_library( ], deps = [ "@maven//:com_google_guava_guava", - "@maven//:com_google_protobuf_protobuf_java", "@maven//:io_github_classgraph_classgraph", ], ) @@ -49,12 +45,9 @@ java_library( java_library( name = "proto_descriptor_utils", srcs = ["ProtoDescriptorUtils.java"], - tags = [ - ], deps = [ "//common:cel_descriptor_util", "//common:cel_descriptors", - "//testing/testrunner:class_loader_utils", "@maven//:com_google_guava_guava", "@maven//:com_google_protobuf_protobuf_java", ], diff --git a/testing/src/main/java/dev/cel/testing/utils/ClassLoaderUtils.java b/testing/src/main/java/dev/cel/testing/utils/ClassLoaderUtils.java index 652ec85c6..31f45d48f 100644 --- a/testing/src/main/java/dev/cel/testing/utils/ClassLoaderUtils.java +++ b/testing/src/main/java/dev/cel/testing/utils/ClassLoaderUtils.java @@ -15,51 +15,19 @@ import com.google.common.base.Supplier; import com.google.common.base.Suppliers; -import com.google.common.collect.ImmutableList; -import com.google.protobuf.Descriptors.Descriptor; import io.github.classgraph.ClassGraph; import io.github.classgraph.ClassInfo; import io.github.classgraph.ClassInfoList; import io.github.classgraph.ScanResult; -import java.io.IOException; -import java.lang.reflect.InvocationTargetException; -import java.util.logging.Logger; /** Utility class for loading classes using {@link ClassGraph}. */ public final class ClassLoaderUtils { - private static final Logger logger = Logger.getLogger(ClassLoaderUtils.class.getName()); - // Using `enableAllInfo()` to scan all class files upfront. This avoids repeated parsing // of class files by individual methods, improving efficiency. private static final Supplier CLASS_SCAN_RESULT = Suppliers.memoize(() -> new ClassGraph().enableAllInfo().scan()); - /** - * Loads all descriptor type classes from the JVM. - * - * @return A list of {@link Descriptor} objects representing the descriptors loaded from the JVM. - * @throws IOException If there is an error during the loading process. - */ - public static ImmutableList loadDescriptors() throws IOException { - ClassInfoList classInfoList = CLASS_SCAN_RESULT.get().getAllStandardClasses(); - ImmutableList.Builder compileTimeLoadedDescriptors = ImmutableList.builder(); - - for (ClassInfo classInfo : classInfoList) { - try { - Class classInfoClass = classInfo.loadClass(); - Descriptor descriptor = (Descriptor) classInfoClass.getMethod("getDescriptor").invoke(null); - compileTimeLoadedDescriptors.add(descriptor); - } catch (InvocationTargetException e) { - logger.severe( - "Failed to load descriptor: " + classInfo.getName() + " with error: " + e); - } catch (Exception e) { - // Ignore classes that do not have a getDescriptor method. - } - } - return compileTimeLoadedDescriptors.build(); - } - /** * Loads all subclasses of the given class from the JVM. * diff --git a/testing/src/main/java/dev/cel/testing/utils/ExprValueUtils.java b/testing/src/main/java/dev/cel/testing/utils/ExprValueUtils.java index 9bccecc95..10ab52786 100644 --- a/testing/src/main/java/dev/cel/testing/utils/ExprValueUtils.java +++ b/testing/src/main/java/dev/cel/testing/utils/ExprValueUtils.java @@ -24,11 +24,12 @@ import com.google.protobuf.Any; import com.google.protobuf.ByteString; import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.DynamicMessage; import com.google.protobuf.ExtensionRegistry; import com.google.protobuf.Message; import com.google.protobuf.NullValue; import com.google.protobuf.TypeRegistry; -import dev.cel.common.internal.DefaultInstanceMessageFactory; +import dev.cel.common.CelDescriptors; import dev.cel.common.internal.ProtoTimeUtils; import dev.cel.common.types.CelType; import dev.cel.common.types.ListType; @@ -44,7 +45,6 @@ import java.time.Instant; import java.util.List; import java.util.Map; -import java.util.NoSuchElementException; import java.util.Optional; /** Utility class for ExprValue and Value type conversions during test execution. */ @@ -53,29 +53,23 @@ public final class ExprValueUtils { private ExprValueUtils() {} - /** * Converts a {@link Value} to a Java native object using the given file descriptor set to parse * `Any` messages. * * @param value The {@link Value} to convert. - * @param fileDescriptorSetPath The path to the file descriptor set. * @return The converted Java object. * @throws IOException If there's an error during conversion. */ - public static Object fromValue(Value value, String fileDescriptorSetPath) throws IOException { - TypeRegistry typeRegistry = RegistryUtils.getTypeRegistry(fileDescriptorSetPath); - ExtensionRegistry extensionRegistry = RegistryUtils.getExtensionRegistry(fileDescriptorSetPath); + public static Object fromValue(Value value, CelDescriptors descriptors) throws IOException { + TypeRegistry typeRegistry = RegistryUtils.getTypeRegistry(descriptors); + ExtensionRegistry extensionRegistry = RegistryUtils.getExtensionRegistry(descriptors); return fromValue(value, typeRegistry, extensionRegistry); } - /** - * Converts a {@link Value} to a Java native object. - * - * @param value The {@link Value} to convert. - * @return The converted Java object. - * @throws IOException If there's an error during conversion. - */ + public static Object fromValue(Value value, String fileDescriptorSetPath) throws IOException { + return fromValue(value, ProtoDescriptorUtils.getDescriptorsFromFile(fileDescriptorSetPath)); + } /** * Converts a {@link Value} to a Java native object using custom registries. @@ -97,7 +91,7 @@ public static Object fromValue( "Unknown type, descriptor was not found in registry: " + value.getObjectValue().getTypeUrl()); } - Message prototype = getDefaultInstance(descriptor); + Message prototype = DynamicMessage.getDefaultInstance(descriptor); return prototype .getParserForType() .parseFrom(value.getObjectValue().getValue(), extensionRegistry); @@ -197,7 +191,8 @@ public static Value toValue(Object object, CelType type) throws Exception { if (object instanceof dev.cel.expr.Value) { object = Value.parseFrom( - ((dev.cel.expr.Value) object).toByteArray(), ExtensionRegistry.getEmptyRegistry()); + ((dev.cel.expr.Value) object).toByteArray(), + ExtensionRegistry.getEmptyRegistry()); } if (object instanceof Value) { return (Value) object; @@ -302,16 +297,4 @@ public static Value toValue(Object object, CelType type) throws Exception { throw new IllegalArgumentException( String.format("Unexpected result type: %s", object.getClass())); } - - private static Message getDefaultInstance(Descriptor descriptor) { - return DefaultInstanceMessageFactory.getInstance() - .getPrototype(descriptor) - .orElseThrow( - () -> - new NoSuchElementException( - "Could not find a default message for: " + descriptor.getFullName())); - } - - - } diff --git a/testing/src/main/java/dev/cel/testing/utils/ProtoDescriptorUtils.java b/testing/src/main/java/dev/cel/testing/utils/ProtoDescriptorUtils.java index b6fa2e64b..880c03e12 100644 --- a/testing/src/main/java/dev/cel/testing/utils/ProtoDescriptorUtils.java +++ b/testing/src/main/java/dev/cel/testing/utils/ProtoDescriptorUtils.java @@ -1,4 +1,4 @@ -// Copyright 2025 Google LLC +// Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -11,18 +11,11 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -package dev.cel.testing.utils; -import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.collect.ImmutableSet.toImmutableSet; -import static dev.cel.testing.utils.ClassLoaderUtils.loadDescriptors; +package dev.cel.testing.utils; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; import com.google.common.io.Files; import com.google.protobuf.DescriptorProtos.FileDescriptorSet; -import com.google.protobuf.Descriptors.Descriptor; -import com.google.protobuf.Descriptors.FileDescriptor; import com.google.protobuf.ExtensionRegistry; import dev.cel.common.CelDescriptorUtil; import dev.cel.common.CelDescriptors; @@ -33,30 +26,15 @@ public final class ProtoDescriptorUtils { /** - * Returns all the descriptors from the JVM. + * Returns all the descriptors from the file descriptor set file. * * @return The {@link CelDescriptors} object containing all the descriptors. */ - public static CelDescriptors getAllDescriptorsFromJvm(String fileDescriptorSetPath) + public static CelDescriptors getDescriptorsFromFile(String fileDescriptorSetPath) throws IOException { - ImmutableList compileTimeLoadedDescriptors = loadDescriptors(); FileDescriptorSet fileDescriptorSet = getFileDescriptorSet(fileDescriptorSetPath); - ImmutableSet runtimeFileDescriptorNames = - CelDescriptorUtil.getFileDescriptorsFromFileDescriptorSet(fileDescriptorSet).stream() - .map(FileDescriptor::getFullName) - .collect(toImmutableSet()); - - // Get all the file descriptors from the descriptors which are loaded from the JVM and use the - // ones which match the ones provided by the user in the file descriptor set. - ImmutableList userProvidedFileDescriptors = - CelDescriptorUtil.getFileDescriptorsForDescriptors(compileTimeLoadedDescriptors).stream() - .filter( - fileDescriptor -> runtimeFileDescriptorNames.contains(fileDescriptor.getFullName())) - .collect(toImmutableList()); - - // Get all the descriptors from the file descriptors above which include nested, extension and - // message type descriptors as well. - return CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(userProvidedFileDescriptors); + return CelDescriptorUtil.getAllDescriptorsFromFileDescriptor( + CelDescriptorUtil.getFileDescriptorsFromFileDescriptorSet(fileDescriptorSet)); } private static FileDescriptorSet getFileDescriptorSet(String fileDescriptorSetPath) diff --git a/testing/src/test/java/dev/cel/testing/testrunner/BUILD.bazel b/testing/src/test/java/dev/cel/testing/testrunner/BUILD.bazel index 755ef732d..a12654d2c 100644 --- a/testing/src/test/java/dev/cel/testing/testrunner/BUILD.bazel +++ b/testing/src/test/java/dev/cel/testing/testrunner/BUILD.bazel @@ -51,6 +51,7 @@ java_library( name = "custom_variable_binding_user_test", srcs = ["CustomVariableBindingUserTest.java"], deps = [ + "//bundle:cel", "//testing/testrunner:cel_test_context", "//testing/testrunner:cel_user_test_template", "@cel_spec//proto/cel/expr/conformance/proto2:test_all_types_java_proto", @@ -186,10 +187,6 @@ cel_java_test( name = "custom_variable_binding_test_runner_sample", cel_expr = "custom_variable_bindings/policy.yaml", config = "custom_variable_bindings/config.yaml", - proto_deps = [ - "@cel_spec//proto/cel/expr/conformance/proto2:test_all_types_proto", - "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_proto", - ], test_data_path = "//testing/src/test/resources/policy", test_src = ":custom_variable_binding_user_test", test_suite = "custom_variable_bindings/tests.yaml", diff --git a/testing/src/test/java/dev/cel/testing/testrunner/CustomVariableBindingUserTest.java b/testing/src/test/java/dev/cel/testing/testrunner/CustomVariableBindingUserTest.java index 707b5eef9..37382052c 100644 --- a/testing/src/test/java/dev/cel/testing/testrunner/CustomVariableBindingUserTest.java +++ b/testing/src/test/java/dev/cel/testing/testrunner/CustomVariableBindingUserTest.java @@ -15,7 +15,8 @@ package dev.cel.testing.testrunner; import com.google.common.collect.ImmutableMap; -import com.google.protobuf.Any; +import com.google.protobuf.ExtensionRegistry; +import dev.cel.bundle.CelFactory; import dev.cel.expr.conformance.proto2.TestAllTypes; import dev.cel.expr.conformance.proto2.TestAllTypesExtensions; import org.junit.runner.RunWith; @@ -29,15 +30,24 @@ public class CustomVariableBindingUserTest extends CelUserTestTemplate { public CustomVariableBindingUserTest() { - super( - CelTestContext.newBuilder() - .setVariableBindings( - ImmutableMap.of( - "spec", - Any.pack( - TestAllTypes.newBuilder() - .setExtension(TestAllTypesExtensions.int32Ext, 1) - .build()))) - .build()); + super(newTestContext()); + } + + private static CelTestContext newTestContext() { + ExtensionRegistry registry = ExtensionRegistry.newInstance(); + registry.add(TestAllTypesExtensions.int32Ext); + + return CelTestContext.newBuilder() + .setCel( + CelFactory.standardCelBuilder() + .addMessageTypes(TestAllTypes.getDescriptor()) + .addFileTypes(TestAllTypesExtensions.getDescriptor()) + .setExtensionRegistry(registry) + .build()) + .setVariableBindings( + ImmutableMap.of( + "spec", + TestAllTypes.newBuilder().setExtension(TestAllTypesExtensions.int32Ext, 1).build())) + .build(); } } From 54060a2198d674faae14a96bb5dc4744df26de1d Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Mon, 20 Apr 2026 16:39:13 -0700 Subject: [PATCH 42/66] Support parsed-only evaluation to sets extension PiperOrigin-RevId: 902882306 --- .../extensions/SetsExtensionsRuntimeImpl.java | 42 ++- .../cel/extensions/CelSetsExtensionsTest.java | 313 ++++++++---------- .../runtime/CelLiteRuntimeAndroidTest.java | 3 +- 3 files changed, 164 insertions(+), 194 deletions(-) diff --git a/extensions/src/main/java/dev/cel/extensions/SetsExtensionsRuntimeImpl.java b/extensions/src/main/java/dev/cel/extensions/SetsExtensionsRuntimeImpl.java index a42fba189..a02fdba8a 100644 --- a/extensions/src/main/java/dev/cel/extensions/SetsExtensionsRuntimeImpl.java +++ b/extensions/src/main/java/dev/cel/extensions/SetsExtensionsRuntimeImpl.java @@ -45,28 +45,34 @@ ImmutableSet newFunctionBindings() { for (SetsFunction function : functions) { switch (function) { case CONTAINS: - bindingBuilder.add( - CelFunctionBinding.from( - "list_sets_contains_list", - Collection.class, - Collection.class, - this::containsAll)); + bindingBuilder.addAll( + CelFunctionBinding.fromOverloads( + function.getFunction(), + CelFunctionBinding.from( + "list_sets_contains_list", + Collection.class, + Collection.class, + this::containsAll))); break; case EQUIVALENT: - bindingBuilder.add( - CelFunctionBinding.from( - "list_sets_equivalent_list", - Collection.class, - Collection.class, - (listA, listB) -> containsAll(listA, listB) && containsAll(listB, listA))); + bindingBuilder.addAll( + CelFunctionBinding.fromOverloads( + function.getFunction(), + CelFunctionBinding.from( + "list_sets_equivalent_list", + Collection.class, + Collection.class, + (listA, listB) -> containsAll(listA, listB) && containsAll(listB, listA)))); break; case INTERSECTS: - bindingBuilder.add( - CelFunctionBinding.from( - "list_sets_intersects_list", - Collection.class, - Collection.class, - this::setIntersects)); + bindingBuilder.addAll( + CelFunctionBinding.fromOverloads( + function.getFunction(), + CelFunctionBinding.from( + "list_sets_intersects_list", + Collection.class, + Collection.class, + this::setIntersects))); break; } } diff --git a/extensions/src/test/java/dev/cel/extensions/CelSetsExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelSetsExtensionsTest.java index 1aac5a023..9007bba2e 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelSetsExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelSetsExtensionsTest.java @@ -19,8 +19,11 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.testing.junit.testparameterinjector.TestParameter; import com.google.testing.junit.testparameterinjector.TestParameterInjector; import com.google.testing.junit.testparameterinjector.TestParameters; +import dev.cel.bundle.Cel; +import dev.cel.bundle.CelBuilder; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelContainer; import dev.cel.common.CelFunctionDecl; @@ -30,47 +33,34 @@ import dev.cel.common.CelValidationResult; import dev.cel.common.types.ListType; import dev.cel.common.types.SimpleType; -import dev.cel.compiler.CelCompiler; -import dev.cel.compiler.CelCompilerFactory; import dev.cel.expr.conformance.proto3.TestAllTypes; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.CelFunctionBinding; import dev.cel.runtime.CelRuntime; -import dev.cel.runtime.CelRuntimeFactory; +import dev.cel.testing.CelRuntimeFlavor; import java.util.List; +import java.util.Map; +import org.junit.Assume; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @RunWith(TestParameterInjector.class) public final class CelSetsExtensionsTest { - private static final CelOptions CEL_OPTIONS = CelOptions.current().build(); - private static final CelCompiler COMPILER = - CelCompilerFactory.standardCelCompilerBuilder() - .addMessageTypes(TestAllTypes.getDescriptor()) - .setOptions(CEL_OPTIONS) - .setContainer(CelContainer.ofName("cel.expr.conformance.proto3")) - .addLibraries(CelExtensions.sets(CEL_OPTIONS)) - .addVar("list", ListType.create(SimpleType.INT)) - .addVar("subList", ListType.create(SimpleType.INT)) - .addFunctionDeclarations( - CelFunctionDecl.newFunctionDeclaration( - "new_int", - CelOverloadDecl.newGlobalOverload( - "new_int_int64", SimpleType.INT, SimpleType.INT))) - .build(); - - private static final CelRuntime RUNTIME = - CelRuntimeFactory.standardCelRuntimeBuilder() - .addMessageTypes(TestAllTypes.getDescriptor()) - .addLibraries(CelExtensions.sets(CEL_OPTIONS)) - .setOptions(CEL_OPTIONS) - .addFunctionBindings( - CelFunctionBinding.from( - "new_int_int64", - Long.class, - // Intentionally return java.lang.Integer to test primitive type adaptation - Math::toIntExact)) - .build(); + private static final CelOptions CEL_OPTIONS = + CelOptions.current().enableHeterogeneousNumericComparisons(true).build(); + + @TestParameter public CelRuntimeFlavor runtimeFlavor; + @TestParameter public boolean isParseOnly; + + private Cel cel; + + @Before + public void setUp() { + // Legacy runtime does not support parsed-only evaluation mode. + Assume.assumeFalse(runtimeFlavor.equals(CelRuntimeFlavor.LEGACY) && isParseOnly); + this.cel = setupEnv(runtimeFlavor.builder()); + } @Test public void library() { @@ -87,22 +77,14 @@ public void library() { public void contains_integerListWithSameValue_succeeds() throws Exception { ImmutableList list = ImmutableList.of(1, 2, 3, 4); ImmutableList subList = ImmutableList.of(1, 2, 3, 4); - CelAbstractSyntaxTree ast = COMPILER.compile("sets.contains(list, subList)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object result = program.eval(ImmutableMap.of("list", list, "subList", subList)); - - assertThat(result).isEqualTo(true); + assertThat( + eval("sets.contains(list, subList)", ImmutableMap.of("list", list, "subList", subList))) + .isEqualTo(true); } @Test public void contains_integerListAsExpression_succeeds() throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("sets.contains([1, 1], [1])").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object result = program.eval(); - - assertThat(result).isEqualTo(true); + assertThat(eval("sets.contains([1, 1], [1])")).isEqualTo(true); } @Test @@ -119,12 +101,7 @@ public void contains_integerListAsExpression_succeeds() throws Exception { + " [TestAllTypes{single_int64: 2, single_uint64: 3u}])', expected: false}") public void contains_withProtoMessage_succeeds(String expression, boolean expected) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile(expression).getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - boolean result = (boolean) program.eval(); - - assertThat(result).isEqualTo(expected); + assertThat(eval(expression)).isEqualTo(expected); } @Test @@ -133,12 +110,7 @@ public void contains_withProtoMessage_succeeds(String expression, boolean expect @TestParameters("{expression: 'sets.contains([new_int(2)], [1])', expected: false}") public void contains_withFunctionReturningInteger_succeeds(String expression, boolean expected) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile(expression).getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - boolean result = (boolean) program.eval(); - - assertThat(result).isEqualTo(expected); + assertThat(eval(expression)).isEqualTo(expected); } @Test @@ -157,12 +129,9 @@ public void contains_withFunctionReturningInteger_succeeds(String expression, bo @TestParameters("{list: [1], subList: [1, 2], expected: false}") public void contains_withIntTypes_succeeds( List list, List subList, boolean expected) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("sets.contains(list, subList)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object result = program.eval(ImmutableMap.of("list", list, "subList", subList)); - - assertThat(result).isEqualTo(expected); + assertThat( + eval("sets.contains(list, subList)", ImmutableMap.of("list", list, "subList", subList))) + .isEqualTo(expected); } @Test @@ -177,12 +146,9 @@ public void contains_withIntTypes_succeeds( @TestParameters("{list: [2, 3.0], subList: [2, 3], expected: true}") public void contains_withDoubleTypes_succeeds( List list, List subList, boolean expected) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("sets.contains(list, subList)").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object result = program.eval(ImmutableMap.of("list", list, "subList", subList)); - - assertThat(result).isEqualTo(expected); + assertThat( + eval("sets.contains(list, subList)", ImmutableMap.of("list", list, "subList", subList))) + .isEqualTo(expected); } @Test @@ -193,12 +159,7 @@ public void contains_withDoubleTypes_succeeds( @TestParameters("{expression: 'sets.contains([[1], [2, 3.0]], [[2, 3]])', expected: true}") public void contains_withNestedLists_succeeds(String expression, boolean expected) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile(expression).getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object result = program.eval(); - - assertThat(result).isEqualTo(expected); + assertThat(eval(expression)).isEqualTo(expected); } @Test @@ -206,19 +167,16 @@ public void contains_withNestedLists_succeeds(String expression, boolean expecte @TestParameters("{expression: 'sets.contains([1], [1, \"1\"])', expected: false}") public void contains_withMixingIntAndString_succeeds(String expression, boolean expected) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile(expression).getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object result = program.eval(); - - assertThat(result).isEqualTo(expected); + assertThat(eval(expression)).isEqualTo(expected); } @Test - @TestParameters("{expression: 'sets.contains([1], [\"1\"])'}") - @TestParameters("{expression: 'sets.contains([\"1\"], [1])'}") - public void contains_withMixingIntAndString_throwsException(String expression) throws Exception { - CelValidationResult invalidData = COMPILER.compile(expression); + public void contains_withMixingIntAndString_throwsException( + @TestParameter({"sets.contains([1], [\"1\"])", "sets.contains([\"1\"], [1])"}) + String expression) + throws Exception { + Assume.assumeFalse(isParseOnly); + CelValidationResult invalidData = cel.compile(expression); assertThat(invalidData.getErrors()).hasSize(1); assertThat(invalidData.getErrors().get(0).getMessage()) @@ -227,12 +185,7 @@ public void contains_withMixingIntAndString_throwsException(String expression) t @Test public void contains_withMixedValues_succeeds() throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile("sets.contains([1, 2], [2u, 2.0])").getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object result = program.eval(); - - assertThat(result).isEqualTo(true); + assertThat(eval("sets.contains([1, 2], [2u, 2.0])")).isEqualTo(true); } @Test @@ -249,12 +202,7 @@ public void contains_withMixedValues_succeeds() throws Exception { "{expression: 'sets.contains([[[[[[5]]]]]], [[1], [2, 3.0], [[[[[5]]]]]])', expected: false}") public void contains_withMultiLevelNestedList_succeeds(String expression, boolean expected) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile(expression).getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object result = program.eval(); - - assertThat(result).isEqualTo(expected); + assertThat(eval(expression)).isEqualTo(expected); } @Test @@ -269,12 +217,7 @@ public void contains_withMultiLevelNestedList_succeeds(String expression, boolea + " false}") public void contains_withMapValues_succeeds(String expression, boolean expected) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile(expression).getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - boolean result = (boolean) program.eval(); - - assertThat(result).isEqualTo(expected); + assertThat(eval(expression)).isEqualTo(expected); } @Test @@ -289,12 +232,7 @@ public void contains_withMapValues_succeeds(String expression, boolean expected) @TestParameters("{expression: 'sets.equivalent([1, 2], [2, 2, 2])', expected: false}") public void equivalent_withIntTypes_succeeds(String expression, boolean expected) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile(expression).getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object result = program.eval(); - - assertThat(result).isEqualTo(expected); + assertThat(eval(expression)).isEqualTo(expected); } @Test @@ -308,12 +246,7 @@ public void equivalent_withIntTypes_succeeds(String expression, boolean expected @TestParameters("{expression: 'sets.equivalent([1, 2], [1u, 2, 2.3])', expected: false}") public void equivalent_withMixedTypes_succeeds(String expression, boolean expected) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile(expression).getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object result = program.eval(); - - assertThat(result).isEqualTo(expected); + assertThat(eval(expression)).isEqualTo(expected); } @Test @@ -338,12 +271,7 @@ public void equivalent_withMixedTypes_succeeds(String expression, boolean expect + " [TestAllTypes{single_int64: 2, single_uint64: 3u}])', expected: false}") public void equivalent_withProtoMessage_succeeds(String expression, boolean expected) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile(expression).getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - boolean result = (boolean) program.eval(); - - assertThat(result).isEqualTo(expected); + assertThat(eval(expression)).isEqualTo(expected); } @Test @@ -361,12 +289,7 @@ public void equivalent_withProtoMessage_succeeds(String expression, boolean expe + " expected: false}") public void equivalent_withMapValues_succeeds(String expression, boolean expected) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile(expression).getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - boolean result = (boolean) program.eval(); - - assertThat(result).isEqualTo(expected); + assertThat(eval(expression)).isEqualTo(expected); } @Test @@ -391,12 +314,7 @@ public void equivalent_withMapValues_succeeds(String expression, boolean expecte @TestParameters("{expression: 'sets.intersects([1], [1.1, 2u])', expected: false}") public void intersects_withMixedTypes_succeeds(String expression, boolean expected) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile(expression).getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - Object result = program.eval(); - - assertThat(result).isEqualTo(expected); + assertThat(eval(expression)).isEqualTo(expected); } @Test @@ -414,12 +332,11 @@ public void intersects_withMixedTypes_succeeds(String expression, boolean expect @TestParameters("{expression: 'sets.intersects([{2: 1}], [{1: 1}])', expected: false}") public void intersects_withMapValues_succeeds(String expression, boolean expected) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile(expression).getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - boolean result = (boolean) program.eval(); + // The LEGACY runtime is not spec compliant, because decimal keys are not allowed for maps. + Assume.assumeFalse( + runtimeFlavor.equals(CelRuntimeFlavor.PLANNER) && expression.contains("1.0:")); - assertThat(result).isEqualTo(expected); + assertThat(eval(expression)).isEqualTo(expected); } @Test @@ -444,25 +361,21 @@ public void intersects_withMapValues_succeeds(String expression, boolean expecte + " [TestAllTypes{single_int64: 2, single_uint64: 3u}])', expected: false}") public void intersects_withProtoMessage_succeeds(String expression, boolean expected) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile(expression).getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - - boolean result = (boolean) program.eval(); - - assertThat(result).isEqualTo(expected); + assertThat(eval(expression)).isEqualTo(expected); } @Test public void setsExtension_containsFunctionSubset_succeeds() throws Exception { CelSetsExtensions setsExtensions = CelExtensions.sets(CelOptions.DEFAULT, SetsFunction.CONTAINS); - CelCompiler celCompiler = - CelCompilerFactory.standardCelCompilerBuilder().addLibraries(setsExtensions).build(); - CelRuntime celRuntime = - CelRuntimeFactory.standardCelRuntimeBuilder().addLibraries(setsExtensions).build(); + Cel cel = + runtimeFlavor + .builder() + .addCompilerLibraries(setsExtensions) + .addRuntimeLibraries(setsExtensions) + .build(); - Object evaluatedResult = - celRuntime.createProgram(celCompiler.compile("sets.contains([1, 2], [2])").getAst()).eval(); + Object evaluatedResult = eval(cel, "sets.contains([1, 2], [2])", ImmutableMap.of()); assertThat(evaluatedResult).isEqualTo(true); } @@ -471,15 +384,14 @@ public void setsExtension_containsFunctionSubset_succeeds() throws Exception { public void setsExtension_equivalentFunctionSubset_succeeds() throws Exception { CelSetsExtensions setsExtensions = CelExtensions.sets(CelOptions.DEFAULT, SetsFunction.EQUIVALENT); - CelCompiler celCompiler = - CelCompilerFactory.standardCelCompilerBuilder().addLibraries(setsExtensions).build(); - CelRuntime celRuntime = - CelRuntimeFactory.standardCelRuntimeBuilder().addLibraries(setsExtensions).build(); + Cel cel = + runtimeFlavor + .builder() + .addCompilerLibraries(setsExtensions) + .addRuntimeLibraries(setsExtensions) + .build(); - Object evaluatedResult = - celRuntime - .createProgram(celCompiler.compile("sets.equivalent([1, 1], [1])").getAst()) - .eval(); + Object evaluatedResult = eval(cel, "sets.equivalent([1, 1], [1])", ImmutableMap.of()); assertThat(evaluatedResult).isEqualTo(true); } @@ -488,44 +400,95 @@ public void setsExtension_equivalentFunctionSubset_succeeds() throws Exception { public void setsExtension_intersectsFunctionSubset_succeeds() throws Exception { CelSetsExtensions setsExtensions = CelExtensions.sets(CelOptions.DEFAULT, SetsFunction.INTERSECTS); - CelCompiler celCompiler = - CelCompilerFactory.standardCelCompilerBuilder().addLibraries(setsExtensions).build(); - CelRuntime celRuntime = - CelRuntimeFactory.standardCelRuntimeBuilder().addLibraries(setsExtensions).build(); + Cel cel = + runtimeFlavor + .builder() + .addCompilerLibraries(setsExtensions) + .addRuntimeLibraries(setsExtensions) + .build(); - Object evaluatedResult = - celRuntime - .createProgram(celCompiler.compile("sets.intersects([1, 1], [1])").getAst()) - .eval(); + Object evaluatedResult = eval(cel, "sets.intersects([1, 1], [1])", ImmutableMap.of()); assertThat(evaluatedResult).isEqualTo(true); } @Test public void setsExtension_compileUnallowedFunction_throws() { + Assume.assumeFalse(isParseOnly); CelSetsExtensions setsExtensions = CelExtensions.sets(CelOptions.DEFAULT, SetsFunction.EQUIVALENT); - CelCompiler celCompiler = - CelCompilerFactory.standardCelCompilerBuilder().addLibraries(setsExtensions).build(); + Cel cel = runtimeFlavor.builder().addCompilerLibraries(setsExtensions).build(); assertThrows( - CelValidationException.class, - () -> celCompiler.compile("sets.contains([1, 2], [2])").getAst()); + CelValidationException.class, () -> cel.compile("sets.contains([1, 2], [2])").getAst()); } @Test public void setsExtension_evaluateUnallowedFunction_throws() throws Exception { CelSetsExtensions setsExtensions = CelExtensions.sets(CelOptions.DEFAULT, SetsFunction.CONTAINS, SetsFunction.EQUIVALENT); - CelCompiler celCompiler = - CelCompilerFactory.standardCelCompilerBuilder().addLibraries(setsExtensions).build(); - CelRuntime celRuntime = - CelRuntimeFactory.standardCelRuntimeBuilder() - .addLibraries(CelExtensions.sets(CelOptions.DEFAULT, SetsFunction.EQUIVALENT)) + CelSetsExtensions runtimeLibrary = + CelExtensions.sets(CelOptions.DEFAULT, SetsFunction.EQUIVALENT); + Cel cel = + runtimeFlavor + .builder() + .addCompilerLibraries(setsExtensions) + .addRuntimeLibraries(runtimeLibrary) .build(); - CelAbstractSyntaxTree ast = celCompiler.compile("sets.contains([1, 2], [2])").getAst(); + CelAbstractSyntaxTree ast = + isParseOnly + ? cel.parse("sets.contains([1, 2], [2])").getAst() + : cel.compile("sets.contains([1, 2], [2])").getAst(); + + if (runtimeFlavor.equals(CelRuntimeFlavor.PLANNER) && !isParseOnly) { + // Fails at plan time + assertThrows(CelEvaluationException.class, () -> cel.createProgram(ast)); + } else { + CelRuntime.Program program = cel.createProgram(ast); + assertThrows(CelEvaluationException.class, () -> program.eval()); + } + } + + private Object eval(Cel cel, String expression, Map variables) throws Exception { + CelAbstractSyntaxTree ast; + if (isParseOnly) { + ast = cel.parse(expression).getAst(); + } else { + ast = cel.compile(expression).getAst(); + } + return cel.createProgram(ast).eval(variables); + } + + private Object eval(String expression) throws Exception { + return eval(this.cel, expression, ImmutableMap.of()); + } + + private Object eval(String expression, Map variables) throws Exception { + return eval(this.cel, expression, variables); + } - assertThrows(CelEvaluationException.class, () -> celRuntime.createProgram(ast).eval()); + private static Cel setupEnv(CelBuilder celBuilder) { + return celBuilder + .addMessageTypes(TestAllTypes.getDescriptor()) + .setOptions(CEL_OPTIONS) + .setContainer(CelContainer.ofName("cel.expr.conformance.proto3")) + .addCompilerLibraries(CelExtensions.sets(CEL_OPTIONS)) + .addRuntimeLibraries(CelExtensions.sets(CEL_OPTIONS)) + .addVar("list", ListType.create(SimpleType.INT)) + .addVar("subList", ListType.create(SimpleType.INT)) + .addFunctionDeclarations( + CelFunctionDecl.newFunctionDeclaration( + "new_int", + CelOverloadDecl.newGlobalOverload("new_int_int64", SimpleType.INT, SimpleType.INT))) + .addFunctionBindings( + CelFunctionBinding.fromOverloads( + "new_int", + CelFunctionBinding.from( + "new_int_int64", + Long.class, + // Intentionally return java.lang.Integer to test primitive type adaptation + Math::toIntExact))) + .build(); } } diff --git a/runtime/src/test/java/dev/cel/runtime/CelLiteRuntimeAndroidTest.java b/runtime/src/test/java/dev/cel/runtime/CelLiteRuntimeAndroidTest.java index 54ce24417..73492d126 100644 --- a/runtime/src/test/java/dev/cel/runtime/CelLiteRuntimeAndroidTest.java +++ b/runtime/src/test/java/dev/cel/runtime/CelLiteRuntimeAndroidTest.java @@ -149,9 +149,10 @@ public void toRuntimeBuilder_propertiesCopied() { assertThat(newRuntimeBuilder.standardFunctionBuilder.build()) .containsExactly(intFunction, equalsOperator) .inOrder(); - assertThat(newRuntimeBuilder.customFunctionBindings).hasSize(2); + assertThat(newRuntimeBuilder.customFunctionBindings).hasSize(3); assertThat(newRuntimeBuilder.customFunctionBindings).containsKey("string_isEmpty"); assertThat(newRuntimeBuilder.customFunctionBindings).containsKey("list_sets_intersects_list"); + assertThat(newRuntimeBuilder.customFunctionBindings).containsKey("sets.intersects"); } @Test From d12fb7e6b7bec515a0214f245feda2a1348a5a55 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Wed, 22 Apr 2026 13:42:38 -0700 Subject: [PATCH 43/66] Support parsed only evaluation to math extension. Remove signed long support for `uint` PiperOrigin-RevId: 904020157 --- .../main/java/dev/cel/extensions/BUILD.bazel | 2 +- .../dev/cel/extensions/CelExtensions.java | 69 ++- .../dev/cel/extensions/CelMathExtensions.java | 263 ++++++----- .../cel/extensions/CelMathExtensionsTest.java | 417 ++++++++---------- 4 files changed, 353 insertions(+), 398 deletions(-) diff --git a/extensions/src/main/java/dev/cel/extensions/BUILD.bazel b/extensions/src/main/java/dev/cel/extensions/BUILD.bazel index 2eb26846f..f8e4bfc8c 100644 --- a/extensions/src/main/java/dev/cel/extensions/BUILD.bazel +++ b/extensions/src/main/java/dev/cel/extensions/BUILD.bazel @@ -42,6 +42,7 @@ java_library( ":strings", "//common:options", "//extensions:extension_library", + "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", ], ) @@ -121,7 +122,6 @@ java_library( ":extension_library", "//checker:checker_builder", "//common:compiler_common", - "//common:options", "//common/ast", "//common/exceptions:numeric_overflow", "//common/internal:comparison_functions", diff --git a/extensions/src/main/java/dev/cel/extensions/CelExtensions.java b/extensions/src/main/java/dev/cel/extensions/CelExtensions.java index 2d14ed118..8f1770f3f 100644 --- a/extensions/src/main/java/dev/cel/extensions/CelExtensions.java +++ b/extensions/src/main/java/dev/cel/extensions/CelExtensions.java @@ -19,7 +19,9 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Streams; +import com.google.errorprone.annotations.InlineMe; import dev.cel.common.CelOptions; +import dev.cel.extensions.CelMathExtensions.Function; import java.util.Set; /** @@ -121,12 +123,9 @@ public static CelProtoExtensions protos() { *

This will include all functions denoted in {@link CelMathExtensions.Function}, including any * future additions. To expose only a subset of these, use {@link #math(CelOptions, * CelMathExtensions.Function...)} or {@link #math(CelOptions,int)} instead. - * - * @param celOptions CelOptions to configure CelMathExtension with. This should be the same - * options object used to configure the compilation/runtime environments. */ - public static CelMathExtensions math(CelOptions celOptions) { - return CelMathExtensions.library(celOptions).latest(); + public static CelMathExtensions math() { + return CelMathExtensions.library().latest(); } /** @@ -134,8 +133,8 @@ public static CelMathExtensions math(CelOptions celOptions) { * *

Refer to README.md for functions available in each version. */ - public static CelMathExtensions math(CelOptions celOptions, int version) { - return CelMathExtensions.library(celOptions).version(version); + public static CelMathExtensions math(int version) { + return CelMathExtensions.library().version(version); } /** @@ -150,13 +149,9 @@ public static CelMathExtensions math(CelOptions celOptions, int version) { * collision. * *

This will include only the specific functions denoted by {@link CelMathExtensions.Function}. - * - * @param celOptions CelOptions to configure CelMathExtension with. This should be the same - * options object used to configure the compilation/runtime environments. */ - public static CelMathExtensions math( - CelOptions celOptions, CelMathExtensions.Function... functions) { - return math(celOptions, ImmutableSet.copyOf(functions)); + public static CelMathExtensions math(CelMathExtensions.Function... functions) { + return math(ImmutableSet.copyOf(functions)); } /** @@ -171,13 +166,49 @@ public static CelMathExtensions math( * collision. * *

This will include only the specific functions denoted by {@link CelMathExtensions.Function}. - * - * @param celOptions CelOptions to configure CelMathExtension with. This should be the same - * options object used to configure the compilation/runtime environments. */ + public static CelMathExtensions math(Set functions) { + return new CelMathExtensions(functions); + } + + /** + * @deprecated Use {@link #math()} instead. + */ + @Deprecated + @InlineMe(replacement = "CelExtensions.math()", imports = "dev.cel.extensions.CelExtensions") + public static CelMathExtensions math(CelOptions unused) { + return math(); + } + + /** + * @deprecated Use {@link #math(int)} instead. + */ + @Deprecated + @InlineMe( + replacement = "CelExtensions.math(version)", + imports = "dev.cel.extensions.CelExtensions") + public static CelMathExtensions math(CelOptions unused, int version) { + return math(version); + } + + /** + * @deprecated Use {@link #math(Function...)} instead. + */ + @Deprecated + public static CelMathExtensions math(CelOptions unused, CelMathExtensions.Function... functions) { + return math(ImmutableSet.copyOf(functions)); + } + + /** + * @deprecated Use {@link #math(Set)} instead. + */ + @Deprecated + @InlineMe( + replacement = "CelExtensions.math(functions)", + imports = "dev.cel.extensions.CelExtensions") public static CelMathExtensions math( - CelOptions celOptions, Set functions) { - return new CelMathExtensions(celOptions, functions); + CelOptions unused, Set functions) { + return math(functions); } /** @@ -354,7 +385,7 @@ public static CelExtensionLibrary getE case "lists": return CelListsExtensions.library(); case "math": - return CelMathExtensions.library(options); + return CelMathExtensions.library(); case "optional": return CelOptionalLibrary.library(); case "protos": diff --git a/extensions/src/main/java/dev/cel/extensions/CelMathExtensions.java b/extensions/src/main/java/dev/cel/extensions/CelMathExtensions.java index 22336eb22..78a0fd51c 100644 --- a/extensions/src/main/java/dev/cel/extensions/CelMathExtensions.java +++ b/extensions/src/main/java/dev/cel/extensions/CelMathExtensions.java @@ -27,7 +27,6 @@ import dev.cel.checker.CelCheckerBuilder; import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelIssue; -import dev.cel.common.CelOptions; import dev.cel.common.CelOverloadDecl; import dev.cel.common.ast.CelConstant; import dev.cel.common.ast.CelExpr; @@ -136,7 +135,8 @@ public final class CelMathExtensions return builder.buildOrThrow(); } - enum Function { + /** Enumeration of functions for Math extension. */ + public enum Function { MAX( CelFunctionDecl.newFunctionDeclaration( MATH_MAX_FUNCTION, @@ -206,51 +206,59 @@ enum Function { MATH_MAX_OVERLOAD_DOC, SimpleType.DYN, ListType.create(SimpleType.DYN))), - ImmutableSet.of( - CelFunctionBinding.from("math_@max_double", Double.class, x -> x), - CelFunctionBinding.from("math_@max_int", Long.class, x -> x), - CelFunctionBinding.from( - "math_@max_double_double", Double.class, Double.class, CelMathExtensions::maxPair), - CelFunctionBinding.from( - "math_@max_int_int", Long.class, Long.class, CelMathExtensions::maxPair), - CelFunctionBinding.from( - "math_@max_int_double", Long.class, Double.class, CelMathExtensions::maxPair), - CelFunctionBinding.from( - "math_@max_double_int", Double.class, Long.class, CelMathExtensions::maxPair), - CelFunctionBinding.from("math_@max_list_dyn", List.class, CelMathExtensions::maxList)), - ImmutableSet.of( - CelFunctionBinding.from("math_@max_uint", Long.class, x -> x), - CelFunctionBinding.from( - "math_@max_uint_uint", Long.class, Long.class, CelMathExtensions::maxPair), - CelFunctionBinding.from( - "math_@max_double_uint", Double.class, Long.class, CelMathExtensions::maxPair), - CelFunctionBinding.from( - "math_@max_uint_int", Long.class, Long.class, CelMathExtensions::maxPair), - CelFunctionBinding.from( - "math_@max_uint_double", Long.class, Double.class, CelMathExtensions::maxPair), - CelFunctionBinding.from( - "math_@max_int_uint", Long.class, Long.class, CelMathExtensions::maxPair)), - ImmutableSet.of( - CelFunctionBinding.from("math_@max_uint", UnsignedLong.class, x -> x), - CelFunctionBinding.from( - "math_@max_uint_uint", - UnsignedLong.class, - UnsignedLong.class, - CelMathExtensions::maxPair), - CelFunctionBinding.from( - "math_@max_double_uint", - Double.class, - UnsignedLong.class, - CelMathExtensions::maxPair), - CelFunctionBinding.from( - "math_@max_uint_int", UnsignedLong.class, Long.class, CelMathExtensions::maxPair), - CelFunctionBinding.from( - "math_@max_uint_double", - UnsignedLong.class, - Double.class, - CelMathExtensions::maxPair), - CelFunctionBinding.from( - "math_@max_int_uint", Long.class, UnsignedLong.class, CelMathExtensions::maxPair))), + ImmutableSet.builder() + .add(CelFunctionBinding.from("math_@max_double", Double.class, x -> x)) + .add(CelFunctionBinding.from("math_@max_int", Long.class, x -> x)) + .add( + CelFunctionBinding.from( + "math_@max_double_double", + Double.class, + Double.class, + CelMathExtensions::maxPair)) + .add( + CelFunctionBinding.from( + "math_@max_int_int", Long.class, Long.class, CelMathExtensions::maxPair)) + .add( + CelFunctionBinding.from( + "math_@max_int_double", Long.class, Double.class, CelMathExtensions::maxPair)) + .add( + CelFunctionBinding.from( + "math_@max_double_int", Double.class, Long.class, CelMathExtensions::maxPair)) + .add( + CelFunctionBinding.from( + "math_@max_list_dyn", List.class, CelMathExtensions::maxList)) + .add(CelFunctionBinding.from("math_@max_uint", UnsignedLong.class, x -> x)) + .add( + CelFunctionBinding.from( + "math_@max_uint_uint", + UnsignedLong.class, + UnsignedLong.class, + CelMathExtensions::maxPair)) + .add( + CelFunctionBinding.from( + "math_@max_double_uint", + Double.class, + UnsignedLong.class, + CelMathExtensions::maxPair)) + .add( + CelFunctionBinding.from( + "math_@max_uint_int", + UnsignedLong.class, + Long.class, + CelMathExtensions::maxPair)) + .add( + CelFunctionBinding.from( + "math_@max_uint_double", + UnsignedLong.class, + Double.class, + CelMathExtensions::maxPair)) + .add( + CelFunctionBinding.from( + "math_@max_int_uint", + Long.class, + UnsignedLong.class, + CelMathExtensions::maxPair)) + .build()), MIN( CelFunctionDecl.newFunctionDeclaration( MATH_MIN_FUNCTION, @@ -320,51 +328,59 @@ enum Function { MATH_MIN_OVERLOAD_DOC, SimpleType.DYN, ListType.create(SimpleType.DYN))), - ImmutableSet.of( - CelFunctionBinding.from("math_@min_double", Double.class, x -> x), - CelFunctionBinding.from("math_@min_int", Long.class, x -> x), - CelFunctionBinding.from( - "math_@min_double_double", Double.class, Double.class, CelMathExtensions::minPair), - CelFunctionBinding.from( - "math_@min_int_int", Long.class, Long.class, CelMathExtensions::minPair), - CelFunctionBinding.from( - "math_@min_int_double", Long.class, Double.class, CelMathExtensions::minPair), - CelFunctionBinding.from( - "math_@min_double_int", Double.class, Long.class, CelMathExtensions::minPair), - CelFunctionBinding.from("math_@min_list_dyn", List.class, CelMathExtensions::minList)), - ImmutableSet.of( - CelFunctionBinding.from("math_@min_uint", Long.class, x -> x), - CelFunctionBinding.from( - "math_@min_uint_uint", Long.class, Long.class, CelMathExtensions::minPair), - CelFunctionBinding.from( - "math_@min_double_uint", Double.class, Long.class, CelMathExtensions::minPair), - CelFunctionBinding.from( - "math_@min_uint_int", Long.class, Long.class, CelMathExtensions::minPair), - CelFunctionBinding.from( - "math_@min_uint_double", Long.class, Double.class, CelMathExtensions::minPair), - CelFunctionBinding.from( - "math_@min_int_uint", Long.class, Long.class, CelMathExtensions::minPair)), - ImmutableSet.of( - CelFunctionBinding.from("math_@min_uint", UnsignedLong.class, x -> x), - CelFunctionBinding.from( - "math_@min_uint_uint", - UnsignedLong.class, - UnsignedLong.class, - CelMathExtensions::minPair), - CelFunctionBinding.from( - "math_@min_double_uint", - Double.class, - UnsignedLong.class, - CelMathExtensions::minPair), - CelFunctionBinding.from( - "math_@min_uint_int", UnsignedLong.class, Long.class, CelMathExtensions::minPair), - CelFunctionBinding.from( - "math_@min_uint_double", - UnsignedLong.class, - Double.class, - CelMathExtensions::minPair), - CelFunctionBinding.from( - "math_@min_int_uint", Long.class, UnsignedLong.class, CelMathExtensions::minPair))), + ImmutableSet.builder() + .add(CelFunctionBinding.from("math_@min_double", Double.class, x -> x)) + .add(CelFunctionBinding.from("math_@min_int", Long.class, x -> x)) + .add( + CelFunctionBinding.from( + "math_@min_double_double", + Double.class, + Double.class, + CelMathExtensions::minPair)) + .add( + CelFunctionBinding.from( + "math_@min_int_int", Long.class, Long.class, CelMathExtensions::minPair)) + .add( + CelFunctionBinding.from( + "math_@min_int_double", Long.class, Double.class, CelMathExtensions::minPair)) + .add( + CelFunctionBinding.from( + "math_@min_double_int", Double.class, Long.class, CelMathExtensions::minPair)) + .add( + CelFunctionBinding.from( + "math_@min_list_dyn", List.class, CelMathExtensions::minList)) + .add(CelFunctionBinding.from("math_@min_uint", UnsignedLong.class, x -> x)) + .add( + CelFunctionBinding.from( + "math_@min_uint_uint", + UnsignedLong.class, + UnsignedLong.class, + CelMathExtensions::minPair)) + .add( + CelFunctionBinding.from( + "math_@min_double_uint", + Double.class, + UnsignedLong.class, + CelMathExtensions::minPair)) + .add( + CelFunctionBinding.from( + "math_@min_uint_int", + UnsignedLong.class, + Long.class, + CelMathExtensions::minPair)) + .add( + CelFunctionBinding.from( + "math_@min_uint_double", + UnsignedLong.class, + Double.class, + CelMathExtensions::minPair)) + .add( + CelFunctionBinding.from( + "math_@min_int_uint", + Long.class, + UnsignedLong.class, + CelMathExtensions::minPair)) + .build()), CEIL( CelFunctionDecl.newFunctionDeclaration( MATH_CEIL_FUNCTION, @@ -646,36 +662,14 @@ enum Function { private final CelFunctionDecl functionDecl; private final ImmutableSet functionBindings; - private final ImmutableSet functionBindingsULongSigned; - private final ImmutableSet functionBindingsULongUnsigned; String getFunction() { return functionDecl.name(); } Function(CelFunctionDecl functionDecl, ImmutableSet bindings) { - this(functionDecl, bindings, ImmutableSet.of(), ImmutableSet.of()); - } - - Function( - CelFunctionDecl functionDecl, - ImmutableSet functionBindings, - ImmutableSet functionBindingsULongSigned, - ImmutableSet functionBindingsULongUnsigned) { this.functionDecl = functionDecl; - this.functionBindings = - functionBindings.isEmpty() - ? ImmutableSet.of() - : CelFunctionBinding.fromOverloads(functionDecl.name(), functionBindings); - this.functionBindingsULongSigned = - functionBindingsULongSigned.isEmpty() - ? ImmutableSet.of() - : CelFunctionBinding.fromOverloads(functionDecl.name(), functionBindingsULongSigned); - this.functionBindingsULongUnsigned = - functionBindingsULongUnsigned.isEmpty() - ? ImmutableSet.of() - : CelFunctionBinding.fromOverloads( - functionDecl.name(), functionBindingsULongUnsigned); + this.functionBindings = bindings; } } @@ -684,10 +678,8 @@ private static final class Library implements CelExtensionLibrarybuilder() .addAll(version1.functions) .add(Function.SQRT) - .build(), - enableUnsignedLongs); + .build()); } @Override @@ -734,25 +724,20 @@ public ImmutableSet versions() { } } - private static final Library LIBRARY_UNSIGNED_LONGS_ENABLED = new Library(true); - private static final Library LIBRARY_UNSIGNED_LONGS_DISABLED = new Library(false); + private static final Library LIBRARY = new Library(); - static CelExtensionLibrary library(CelOptions celOptions) { - return celOptions.enableUnsignedLongs() - ? LIBRARY_UNSIGNED_LONGS_ENABLED - : LIBRARY_UNSIGNED_LONGS_DISABLED; + static CelExtensionLibrary library() { + return LIBRARY; } - private final boolean enableUnsignedLongs; private final ImmutableSet functions; private final int version; - CelMathExtensions(CelOptions celOptions, Set functions) { - this(-1, functions, celOptions.enableUnsignedLongs()); + CelMathExtensions(Set functions) { + this(-1, functions); } - private CelMathExtensions(int version, Set functions, boolean enableUnsignedLongs) { - this.enableUnsignedLongs = enableUnsignedLongs; + private CelMathExtensions(int version, Set functions) { this.version = version; this.functions = ImmutableSet.copyOf(functions); } @@ -788,11 +773,11 @@ public void setCheckerOptions(CelCheckerBuilder checkerBuilder) { public void setRuntimeOptions(CelRuntimeBuilder runtimeBuilder) { functions.forEach( function -> { - runtimeBuilder.addFunctionBindings(function.functionBindings); - runtimeBuilder.addFunctionBindings( - enableUnsignedLongs - ? function.functionBindingsULongUnsigned - : function.functionBindingsULongSigned); + ImmutableSet combined = function.functionBindings; + if (!combined.isEmpty()) { + runtimeBuilder.addFunctionBindings( + CelFunctionBinding.fromOverloads(function.functionDecl.name(), combined)); + } }); } diff --git a/extensions/src/test/java/dev/cel/extensions/CelMathExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelMathExtensionsTest.java index bcdfb0a21..383e50aa2 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelMathExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelMathExtensionsTest.java @@ -20,8 +20,10 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.primitives.UnsignedLong; +import com.google.testing.junit.testparameterinjector.TestParameter; import com.google.testing.junit.testparameterinjector.TestParameterInjector; import com.google.testing.junit.testparameterinjector.TestParameters; +import dev.cel.bundle.Cel; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelOptions; @@ -35,34 +37,36 @@ import dev.cel.runtime.CelFunctionBinding; import dev.cel.runtime.CelRuntime; import dev.cel.runtime.CelRuntimeFactory; +import dev.cel.testing.CelRuntimeFlavor; +import java.util.Map; +import org.junit.Assume; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @RunWith(TestParameterInjector.class) public class CelMathExtensionsTest { - private static final CelOptions CEL_OPTIONS = - CelOptions.current().enableUnsignedLongs(false).build(); - private static final CelCompiler CEL_COMPILER = - CelCompilerFactory.standardCelCompilerBuilder() - .setOptions(CEL_OPTIONS) - .addLibraries(CelExtensions.math(CEL_OPTIONS)) - .build(); - private static final CelRuntime CEL_RUNTIME = - CelRuntimeFactory.standardCelRuntimeBuilder() - .setOptions(CEL_OPTIONS) - .addLibraries(CelExtensions.math(CEL_OPTIONS)) - .build(); - private static final CelOptions CEL_UNSIGNED_OPTIONS = CelOptions.current().build(); - private static final CelCompiler CEL_UNSIGNED_COMPILER = - CelCompilerFactory.standardCelCompilerBuilder() - .setOptions(CEL_UNSIGNED_OPTIONS) - .addLibraries(CelExtensions.math(CEL_UNSIGNED_OPTIONS)) - .build(); - private static final CelRuntime CEL_UNSIGNED_RUNTIME = - CelRuntimeFactory.standardCelRuntimeBuilder() - .setOptions(CEL_UNSIGNED_OPTIONS) - .addLibraries(CelExtensions.math(CEL_UNSIGNED_OPTIONS)) - .build(); + @TestParameter public CelRuntimeFlavor runtimeFlavor; + @TestParameter public boolean isParseOnly; + + private Cel cel; + + @Before + public void setUp() { + // Legacy runtime does not support parsed-only evaluation mode. + Assume.assumeFalse(runtimeFlavor.equals(CelRuntimeFlavor.LEGACY) && isParseOnly); + this.cel = + runtimeFlavor + .builder() + .setOptions( + CelOptions.current() + .enableHeterogeneousNumericComparisons( + runtimeFlavor.equals(CelRuntimeFlavor.PLANNER)) + .build()) + .addCompilerLibraries(CelExtensions.math()) + .addRuntimeLibraries(CelExtensions.math()) + .build(); + } @Test @TestParameters("{expr: 'math.greatest(-5)', expectedResult: -5}") @@ -97,9 +101,7 @@ public class CelMathExtensionsTest { "{expr: 'math.greatest([dyn(5.4), dyn(10), dyn(3u), dyn(-5.0), dyn(3.5)])', expectedResult:" + " 10}") public void greatest_intResult_success(String expr, long expectedResult) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); - - Object result = CEL_RUNTIME.createProgram(ast).eval(); + Object result = eval(expr); assertThat(result).isEqualTo(expectedResult); } @@ -136,9 +138,7 @@ public void greatest_intResult_success(String expr, long expectedResult) throws "{expr: 'math.greatest([dyn(5.4), dyn(10.0), dyn(3u), dyn(-5.0), dyn(3.5)])', expectedResult:" + " 10.0}") public void greatest_doubleResult_success(String expr, double expectedResult) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); - - Object result = CEL_RUNTIME.createProgram(ast).eval(); + Object result = eval(expr); assertThat(result).isEqualTo(expectedResult); } @@ -163,16 +163,16 @@ public void greatest_doubleResult_success(String expr, double expectedResult) th + " '10.0'}") public void greatest_doubleResult_withUnsignedLongsEnabled_success( String expr, double expectedResult) throws Exception { - CelOptions celOptions = CelOptions.current().enableUnsignedLongs(true).build(); + CelOptions celOptions = CelOptions.DEFAULT; CelCompiler celCompiler = CelCompilerFactory.standardCelCompilerBuilder() .setOptions(celOptions) - .addLibraries(CelExtensions.math(celOptions)) + .addLibraries(CelExtensions.math()) .build(); CelRuntime celRuntime = CelRuntimeFactory.standardCelRuntimeBuilder() .setOptions(celOptions) - .addLibraries(CelExtensions.math(celOptions)) + .addLibraries(CelExtensions.math()) .build(); CelAbstractSyntaxTree ast = celCompiler.compile(expr).getAst(); @@ -182,44 +182,7 @@ public void greatest_doubleResult_withUnsignedLongsEnabled_success( } @Test - @TestParameters("{expr: 'math.greatest(5u)', expectedResult: 5}") - @TestParameters("{expr: 'math.greatest(1u, 1.0)', expectedResult: 1}") - @TestParameters("{expr: 'math.greatest(1u, 1)', expectedResult: 1}") - @TestParameters("{expr: 'math.greatest(1u, 1u)', expectedResult: 1}") - @TestParameters("{expr: 'math.greatest(3u, 3.0)', expectedResult: 3}") - @TestParameters("{expr: 'math.greatest(9u, 10u)', expectedResult: 10}") - @TestParameters("{expr: 'math.greatest(15u, 14u)', expectedResult: 15}") - @TestParameters( - "{expr: 'math.greatest(1, 9223372036854775807u)', expectedResult: 9223372036854775807}") - @TestParameters( - "{expr: 'math.greatest(9223372036854775807u, 1)', expectedResult: 9223372036854775807}") - @TestParameters("{expr: 'math.greatest(1u, 1, 1)', expectedResult: 1}") - @TestParameters("{expr: 'math.greatest(3u, 1u, 10u)', expectedResult: 10}") - @TestParameters("{expr: 'math.greatest(1u, 5u, 2u)', expectedResult: 5}") - @TestParameters("{expr: 'math.greatest(-1, 1u, 0u)', expectedResult: 1}") - @TestParameters("{expr: 'math.greatest(dyn(1u), 1, 1.0)', expectedResult: 1}") - @TestParameters("{expr: 'math.greatest(5u, 1.0, 3u)', expectedResult: 5}") - @TestParameters("{expr: 'math.greatest(5.4, 10u, 3u, -5.0, 3.5)', expectedResult: 10}") - @TestParameters( - "{expr: 'math.greatest(5.4, 10, 3u, -5.0, 9223372036854775807)', expectedResult:" - + " 9223372036854775807}") - @TestParameters( - "{expr: 'math.greatest(9223372036854775807, 10, 3u, -5.0, 0)', expectedResult:" - + " 9223372036854775807}") - @TestParameters("{expr: 'math.greatest([5.4, 10, 3u, -5.0, 3.5])', expectedResult: 10}") - @TestParameters( - "{expr: 'math.greatest([dyn(5.4), dyn(10), dyn(3u), dyn(-5.0), dyn(3.5)])', expectedResult:" - + " 10}") - public void greatest_unsignedLongResult_withSignedLongType_success( - String expr, long expectedResult) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); - - Object result = CEL_RUNTIME.createProgram(ast).eval(); - - assertThat(result).isEqualTo(expectedResult); - } - - @Test + @TestParameters("{expr: 'math.greatest(5u)', expectedResult: '5'}") @TestParameters( "{expr: 'math.greatest(18446744073709551615u)', expectedResult: '18446744073709551615'}") @TestParameters("{expr: 'math.greatest(1u, 1.0)', expectedResult: '1'}") @@ -251,16 +214,16 @@ public void greatest_unsignedLongResult_withSignedLongType_success( + " '10'}") public void greatest_unsignedLongResult_withUnsignedLongType_success( String expr, String expectedResult) throws Exception { - CelOptions celOptions = CelOptions.current().enableUnsignedLongs(true).build(); + CelOptions celOptions = CelOptions.DEFAULT; CelCompiler celCompiler = CelCompilerFactory.standardCelCompilerBuilder() .setOptions(celOptions) - .addLibraries(CelExtensions.math(celOptions)) + .addLibraries(CelExtensions.math()) .build(); CelRuntime celRuntime = CelRuntimeFactory.standardCelRuntimeBuilder() .setOptions(celOptions) - .addLibraries(CelExtensions.math(celOptions)) + .addLibraries(CelExtensions.math()) .build(); CelAbstractSyntaxTree ast = celCompiler.compile(expr).getAst(); @@ -271,9 +234,9 @@ public void greatest_unsignedLongResult_withUnsignedLongType_success( @Test public void greatest_noArgs_throwsCompilationException() { + Assume.assumeFalse(isParseOnly); CelValidationException e = - assertThrows( - CelValidationException.class, () -> CEL_COMPILER.compile("math.greatest()").getAst()); + assertThrows(CelValidationException.class, () -> cel.compile("math.greatest()").getAst()); assertThat(e).hasMessageThat().contains("math.greatest() requires at least one argument"); } @@ -283,8 +246,9 @@ public void greatest_noArgs_throwsCompilationException() { @TestParameters("{expr: 'math.greatest({})'}") @TestParameters("{expr: 'math.greatest([])'}") public void greatest_invalidSingleArg_throwsCompilationException(String expr) { + Assume.assumeFalse(isParseOnly); CelValidationException e = - assertThrows(CelValidationException.class, () -> CEL_COMPILER.compile(expr).getAst()); + assertThrows(CelValidationException.class, () -> cel.compile(expr).getAst()); assertThat(e).hasMessageThat().contains("math.greatest() invalid single argument value"); } @@ -297,8 +261,9 @@ public void greatest_invalidSingleArg_throwsCompilationException(String expr) { @TestParameters("{expr: 'math.greatest([1, {}, 2])'}") @TestParameters("{expr: 'math.greatest([1, [], 2])'}") public void greatest_invalidArgs_throwsCompilationException(String expr) { + Assume.assumeFalse(isParseOnly); CelValidationException e = - assertThrows(CelValidationException.class, () -> CEL_COMPILER.compile(expr).getAst()); + assertThrows(CelValidationException.class, () -> cel.compile(expr).getAst()); assertThat(e) .hasMessageThat() @@ -312,19 +277,16 @@ public void greatest_invalidArgs_throwsCompilationException(String expr) { @TestParameters("{expr: 'math.greatest([1, dyn({}), 2])'}") @TestParameters("{expr: 'math.greatest([1, dyn([]), 2])'}") public void greatest_invalidDynArgs_throwsRuntimeException(String expr) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); - - CelEvaluationException e = - assertThrows(CelEvaluationException.class, () -> CEL_RUNTIME.createProgram(ast).eval()); + CelEvaluationException e = assertThrows(CelEvaluationException.class, () -> eval(expr)); - assertThat(e).hasMessageThat().contains("Function 'math_@max_list_dyn' failed with arg(s)"); + assertThat(e).hasMessageThat().contains("failed with arg(s)"); } @Test public void greatest_listVariableIsEmpty_throwsRuntimeException() throws Exception { CelCompiler celCompiler = CelCompilerFactory.standardCelCompilerBuilder() - .addLibraries(CelExtensions.math(CEL_OPTIONS)) + .addLibraries(CelExtensions.math()) .addVar("listVar", ListType.create(SimpleType.INT)) .build(); CelAbstractSyntaxTree ast = celCompiler.compile("math.greatest(listVar)").getAst(); @@ -332,12 +294,9 @@ public void greatest_listVariableIsEmpty_throwsRuntimeException() throws Excepti CelEvaluationException e = assertThrows( CelEvaluationException.class, - () -> - CEL_RUNTIME - .createProgram(ast) - .eval(ImmutableMap.of("listVar", ImmutableList.of()))); + () -> cel.createProgram(ast).eval(ImmutableMap.of("listVar", ImmutableList.of()))); - assertThat(e).hasMessageThat().contains("Function 'math_@max_list_dyn' failed with arg(s)"); + assertThat(e).hasMessageThat().contains("failed with arg(s)"); assertThat(e) .hasCauseThat() .hasMessageThat() @@ -347,25 +306,25 @@ public void greatest_listVariableIsEmpty_throwsRuntimeException() throws Excepti @Test @TestParameters("{expr: '100.greatest(1) == 1'}") @TestParameters("{expr: 'dyn(100).greatest(1) == 1'}") - public void greatest_nonProtoNamespace_success(String expr) throws Exception { - CelCompiler celCompiler = - CelCompilerFactory.standardCelCompilerBuilder() - .addLibraries(CelExtensions.math(CEL_OPTIONS)) + public void greatest_nonMathNamespace_success(String expr) throws Exception { + Cel cel = + runtimeFlavor + .builder() + .addCompilerLibraries(CelExtensions.math()) + .addRuntimeLibraries(CelExtensions.math()) .addFunctionDeclarations( CelFunctionDecl.newFunctionDeclaration( "greatest", CelOverloadDecl.newMemberOverload( "int_greatest_int", SimpleType.INT, SimpleType.INT, SimpleType.INT))) - .build(); - CelRuntime celRuntime = - CelRuntimeFactory.standardCelRuntimeBuilder() .addFunctionBindings( - CelFunctionBinding.from( - "int_greatest_int", Long.class, Long.class, (arg1, arg2) -> arg2)) + CelFunctionBinding.fromOverloads( + "greatest", + CelFunctionBinding.from( + "int_greatest_int", Long.class, Long.class, (arg1, arg2) -> arg2))) .build(); - CelAbstractSyntaxTree ast = celCompiler.compile(expr).getAst(); - boolean result = (boolean) celRuntime.createProgram(ast).eval(); + boolean result = (boolean) eval(cel, expr); assertThat(result).isTrue(); } @@ -400,13 +359,14 @@ public void greatest_nonProtoNamespace_success(String expr) throws Exception { "{expr: 'math.least(-9223372036854775808, 10, 3u, -5.0, 0)', expectedResult:" + " -9223372036854775808}") @TestParameters("{expr: 'math.least([5.4, -10, 3u, -5.0, 3.5])', expectedResult: -10}") + @TestParameters("{expr: 'math.least(1, 9223372036854775807u)', expectedResult: 1}") + @TestParameters("{expr: 'math.least(9223372036854775807u, 1)', expectedResult: 1}") + @TestParameters("{expr: 'math.least(9223372036854775807, 10, 3u, 5.0, 0)', expectedResult: 0}") @TestParameters( "{expr: 'math.least([dyn(5.4), dyn(-10), dyn(3u), dyn(-5.0), dyn(3.5)])', expectedResult:" + " -10}") public void least_intResult_success(String expr, long expectedResult) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); - - Object result = CEL_RUNTIME.createProgram(ast).eval(); + Object result = eval(expr); assertThat(result).isEqualTo(expectedResult); } @@ -443,9 +403,7 @@ public void least_intResult_success(String expr, long expectedResult) throws Exc "{expr: 'math.least([dyn(5.4), dyn(10.0), dyn(3u), dyn(-5.0), dyn(3.5)])', expectedResult:" + " -5.0}") public void least_doubleResult_success(String expr, double expectedResult) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); - - Object result = CEL_RUNTIME.createProgram(ast).eval(); + Object result = eval(expr); assertThat(result).isEqualTo(expectedResult); } @@ -474,12 +432,12 @@ public void least_doubleResult_withUnsignedLongsEnabled_success( CelCompiler celCompiler = CelCompilerFactory.standardCelCompilerBuilder() .setOptions(celOptions) - .addLibraries(CelExtensions.math(celOptions)) + .addLibraries(CelExtensions.math()) .build(); CelRuntime celRuntime = CelRuntimeFactory.standardCelRuntimeBuilder() .setOptions(celOptions) - .addLibraries(CelExtensions.math(celOptions)) + .addLibraries(CelExtensions.math()) .build(); CelAbstractSyntaxTree ast = celCompiler.compile(expr).getAst(); @@ -489,37 +447,15 @@ public void least_doubleResult_withUnsignedLongsEnabled_success( } @Test - @TestParameters("{expr: 'math.least(5u)', expectedResult: 5}") - @TestParameters("{expr: 'math.least(1u, 1.0)', expectedResult: 1}") - @TestParameters("{expr: 'math.least(1u, 1)', expectedResult: 1}") - @TestParameters("{expr: 'math.least(1u, 1u)', expectedResult: 1}") - @TestParameters("{expr: 'math.least(3u, 3.0)', expectedResult: 3}") - @TestParameters("{expr: 'math.least(9u, 10u)', expectedResult: 9}") - @TestParameters("{expr: 'math.least(15u, 14u)', expectedResult: 14}") - @TestParameters("{expr: 'math.least(1, 9223372036854775807u)', expectedResult: 1}") - @TestParameters("{expr: 'math.least(9223372036854775807u, 1)', expectedResult: 1}") - @TestParameters("{expr: 'math.least(1u, 1, 1)', expectedResult: 1}") - @TestParameters("{expr: 'math.least(3u, 1u, 10u)', expectedResult: 1}") - @TestParameters("{expr: 'math.least(1u, 5u, 2u)', expectedResult: 1}") - @TestParameters("{expr: 'math.least(9, 1u, 0u)', expectedResult: 0}") - @TestParameters("{expr: 'math.least(dyn(1u), 1, 1.0)', expectedResult: 1}") - @TestParameters("{expr: 'math.least(5.0, 1u, 3u)', expectedResult: 1}") - @TestParameters("{expr: 'math.least(5.4, 1u, 3u, 9, 3.5)', expectedResult: 1}") - @TestParameters("{expr: 'math.least(5.4, 10, 3u, 5.0, 9223372036854775807)', expectedResult: 3}") - @TestParameters("{expr: 'math.least(9223372036854775807, 10, 3u, 5.0, 0)', expectedResult: 0}") - @TestParameters("{expr: 'math.least([5.4, 10, 3u, 5.0, 3.5])', expectedResult: 3}") + @TestParameters("{expr: 'math.least(9, 1u, 0u)', expectedResult: '0'}") + @TestParameters("{expr: 'math.least(dyn(1u), 1, 1.0)', expectedResult: '1'}") + @TestParameters("{expr: 'math.least(5.0, 1u, 3u)', expectedResult: '1'}") + @TestParameters("{expr: 'math.least(5.4, 1u, 3u, 9, 3.5)', expectedResult: '1'}") @TestParameters( - "{expr: 'math.least([dyn(5.4), dyn(10), dyn(3u), dyn(5.0), dyn(3.5)])', expectedResult: 3}") - public void least_unsignedLongResult_withSignedLongType_success(String expr, long expectedResult) - throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); - - Object result = CEL_RUNTIME.createProgram(ast).eval(); - - assertThat(result).isEqualTo(expectedResult); - } - - @Test + "{expr: 'math.least(5.4, 10, 3u, 5.0, 9223372036854775807)', expectedResult: '3'}") + @TestParameters("{expr: 'math.least([5.4, 10, 3u, 5.0, 3.5])', expectedResult: '3'}") + @TestParameters( + "{expr: 'math.least([dyn(5.4), dyn(10), dyn(3u), dyn(5.0), dyn(3.5)])', expectedResult: '3'}") @TestParameters( "{expr: 'math.least(18446744073709551615u)', expectedResult: '18446744073709551615'}") @TestParameters("{expr: 'math.least(1u, 1.0)', expectedResult: '1'}") @@ -553,12 +489,12 @@ public void least_unsignedLongResult_withUnsignedLongType_success( CelCompiler celCompiler = CelCompilerFactory.standardCelCompilerBuilder() .setOptions(celOptions) - .addLibraries(CelExtensions.math(celOptions)) + .addLibraries(CelExtensions.math()) .build(); CelRuntime celRuntime = CelRuntimeFactory.standardCelRuntimeBuilder() .setOptions(celOptions) - .addLibraries(CelExtensions.math(celOptions)) + .addLibraries(CelExtensions.math()) .build(); CelAbstractSyntaxTree ast = celCompiler.compile(expr).getAst(); @@ -569,9 +505,9 @@ public void least_unsignedLongResult_withUnsignedLongType_success( @Test public void least_noArgs_throwsCompilationException() { + Assume.assumeFalse(isParseOnly); CelValidationException e = - assertThrows( - CelValidationException.class, () -> CEL_COMPILER.compile("math.least()").getAst()); + assertThrows(CelValidationException.class, () -> cel.compile("math.least()").getAst()); assertThat(e).hasMessageThat().contains("math.least() requires at least one argument"); } @@ -581,8 +517,9 @@ public void least_noArgs_throwsCompilationException() { @TestParameters("{expr: 'math.least({})'}") @TestParameters("{expr: 'math.least([])'}") public void least_invalidSingleArg_throwsCompilationException(String expr) { + Assume.assumeFalse(isParseOnly); CelValidationException e = - assertThrows(CelValidationException.class, () -> CEL_COMPILER.compile(expr).getAst()); + assertThrows(CelValidationException.class, () -> cel.compile(expr).getAst()); assertThat(e).hasMessageThat().contains("math.least() invalid single argument value"); } @@ -595,8 +532,9 @@ public void least_invalidSingleArg_throwsCompilationException(String expr) { @TestParameters("{expr: 'math.least([1, {}, 2])'}") @TestParameters("{expr: 'math.least([1, [], 2])'}") public void least_invalidArgs_throwsCompilationException(String expr) { + Assume.assumeFalse(isParseOnly); CelValidationException e = - assertThrows(CelValidationException.class, () -> CEL_COMPILER.compile(expr).getAst()); + assertThrows(CelValidationException.class, () -> cel.compile(expr).getAst()); assertThat(e) .hasMessageThat() @@ -610,19 +548,16 @@ public void least_invalidArgs_throwsCompilationException(String expr) { @TestParameters("{expr: 'math.least([1, dyn({}), 2])'}") @TestParameters("{expr: 'math.least([1, dyn([]), 2])'}") public void least_invalidDynArgs_throwsRuntimeException(String expr) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); - - CelEvaluationException e = - assertThrows(CelEvaluationException.class, () -> CEL_RUNTIME.createProgram(ast).eval()); + CelEvaluationException e = assertThrows(CelEvaluationException.class, () -> eval(expr)); - assertThat(e).hasMessageThat().contains("Function 'math_@min_list_dyn' failed with arg(s)"); + assertThat(e).hasMessageThat().contains("failed with arg(s)"); } @Test public void least_listVariableIsEmpty_throwsRuntimeException() throws Exception { CelCompiler celCompiler = CelCompilerFactory.standardCelCompilerBuilder() - .addLibraries(CelExtensions.math(CEL_OPTIONS)) + .addLibraries(CelExtensions.math()) .addVar("listVar", ListType.create(SimpleType.INT)) .build(); CelAbstractSyntaxTree ast = celCompiler.compile("math.least(listVar)").getAst(); @@ -630,12 +565,9 @@ public void least_listVariableIsEmpty_throwsRuntimeException() throws Exception CelEvaluationException e = assertThrows( CelEvaluationException.class, - () -> - CEL_RUNTIME - .createProgram(ast) - .eval(ImmutableMap.of("listVar", ImmutableList.of()))); + () -> cel.createProgram(ast).eval(ImmutableMap.of("listVar", ImmutableList.of()))); - assertThat(e).hasMessageThat().contains("Function 'math_@min_list_dyn' failed with arg(s)"); + assertThat(e).hasMessageThat().contains("failed with arg(s)"); assertThat(e) .hasCauseThat() .hasMessageThat() @@ -645,24 +577,25 @@ public void least_listVariableIsEmpty_throwsRuntimeException() throws Exception @Test @TestParameters("{expr: '100.least(1) == 1'}") @TestParameters("{expr: 'dyn(100).least(1) == 1'}") - public void least_nonProtoNamespace_success(String expr) throws Exception { - CelCompiler celCompiler = - CelCompilerFactory.standardCelCompilerBuilder() - .addLibraries(CelExtensions.math(CEL_OPTIONS)) + public void least_nonMathNamespace_success(String expr) throws Exception { + Cel cel = + runtimeFlavor + .builder() + .addCompilerLibraries(CelExtensions.math()) + .addRuntimeLibraries(CelExtensions.math()) .addFunctionDeclarations( CelFunctionDecl.newFunctionDeclaration( "least", CelOverloadDecl.newMemberOverload( "int_least", SimpleType.INT, SimpleType.INT, SimpleType.INT))) - .build(); - CelRuntime celRuntime = - CelRuntimeFactory.standardCelRuntimeBuilder() .addFunctionBindings( - CelFunctionBinding.from("int_least", Long.class, Long.class, (arg1, arg2) -> arg2)) + CelFunctionBinding.fromOverloads( + "least", + CelFunctionBinding.from( + "int_least", Long.class, Long.class, (arg1, arg2) -> arg2))) .build(); - CelAbstractSyntaxTree ast = celCompiler.compile(expr).getAst(); - boolean result = (boolean) celRuntime.createProgram(ast).eval(); + boolean result = (boolean) eval(cel, expr); assertThat(result).isTrue(); } @@ -676,9 +609,9 @@ public void least_nonProtoNamespace_success(String expr) throws Exception { @TestParameters("{expr: 'math.isNaN(math.sign(0.0/0.0))', expectedResult: true}") @TestParameters("{expr: 'math.isNaN(math.sqrt(-4))', expectedResult: true}") public void isNaN_success(String expr, boolean expectedResult) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); + CelAbstractSyntaxTree ast = cel.compile(expr).getAst(); - Object result = CEL_RUNTIME.createProgram(ast).eval(); + Object result = cel.createProgram(ast).eval(); assertThat(result).isEqualTo(expectedResult); } @@ -690,7 +623,7 @@ public void isNaN_success(String expr, boolean expectedResult) throws Exception @TestParameters("{expr: 'math.isNaN(1u)'}") public void isNaN_invalidArgs_throwsException(String expr) { CelValidationException e = - assertThrows(CelValidationException.class, () -> CEL_COMPILER.compile(expr).getAst()); + assertThrows(CelValidationException.class, () -> cel.compile(expr).getAst()); assertThat(e).hasMessageThat().contains("found no matching overload for 'math.isNaN'"); } @@ -701,9 +634,9 @@ public void isNaN_invalidArgs_throwsException(String expr) { @TestParameters("{expr: 'math.isFinite(1.0/0.0)', expectedResult: false}") @TestParameters("{expr: 'math.isFinite(0.0/0.0)', expectedResult: false}") public void isFinite_success(String expr, boolean expectedResult) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); + CelAbstractSyntaxTree ast = cel.compile(expr).getAst(); - Object result = CEL_RUNTIME.createProgram(ast).eval(); + Object result = cel.createProgram(ast).eval(); assertThat(result).isEqualTo(expectedResult); } @@ -715,7 +648,7 @@ public void isFinite_success(String expr, boolean expectedResult) throws Excepti @TestParameters("{expr: 'math.isFinite(1u)'}") public void isFinite_invalidArgs_throwsException(String expr) { CelValidationException e = - assertThrows(CelValidationException.class, () -> CEL_COMPILER.compile(expr).getAst()); + assertThrows(CelValidationException.class, () -> cel.compile(expr).getAst()); assertThat(e).hasMessageThat().contains("found no matching overload for 'math.isFinite'"); } @@ -726,9 +659,9 @@ public void isFinite_invalidArgs_throwsException(String expr) { @TestParameters("{expr: 'math.isInf(0.0/0.0)', expectedResult: false}") @TestParameters("{expr: 'math.isInf(10.0)', expectedResult: false}") public void isInf_success(String expr, boolean expectedResult) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); + CelAbstractSyntaxTree ast = cel.compile(expr).getAst(); - Object result = CEL_RUNTIME.createProgram(ast).eval(); + Object result = cel.createProgram(ast).eval(); assertThat(result).isEqualTo(expectedResult); } @@ -740,7 +673,7 @@ public void isInf_success(String expr, boolean expectedResult) throws Exception @TestParameters("{expr: 'math.isInf(1u)'}") public void isInf_invalidArgs_throwsException(String expr) { CelValidationException e = - assertThrows(CelValidationException.class, () -> CEL_COMPILER.compile(expr).getAst()); + assertThrows(CelValidationException.class, () -> cel.compile(expr).getAst()); assertThat(e).hasMessageThat().contains("found no matching overload for 'math.isInf'"); } @@ -752,9 +685,9 @@ public void isInf_invalidArgs_throwsException(String expr) { @TestParameters("{expr: 'math.ceil(20.0)' , expectedResult: 20.0}") @TestParameters("{expr: 'math.ceil(0.0/0.0)' , expectedResult: NaN}") public void ceil_success(String expr, double expectedResult) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); + CelAbstractSyntaxTree ast = cel.compile(expr).getAst(); - Object result = CEL_RUNTIME.createProgram(ast).eval(); + Object result = cel.createProgram(ast).eval(); assertThat(result).isEqualTo(expectedResult); } @@ -766,7 +699,7 @@ public void ceil_success(String expr, double expectedResult) throws Exception { @TestParameters("{expr: 'math.ceil(1u)'}") public void ceil_invalidArgs_throwsException(String expr) { CelValidationException e = - assertThrows(CelValidationException.class, () -> CEL_COMPILER.compile(expr).getAst()); + assertThrows(CelValidationException.class, () -> cel.compile(expr).getAst()); assertThat(e).hasMessageThat().contains("found no matching overload for 'math.ceil'"); } @@ -777,9 +710,9 @@ public void ceil_invalidArgs_throwsException(String expr) { @TestParameters("{expr: 'math.floor(0.0/0.0)' , expectedResult: NaN}") @TestParameters("{expr: 'math.floor(50.0)' , expectedResult: 50.0}") public void floor_success(String expr, double expectedResult) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); + CelAbstractSyntaxTree ast = cel.compile(expr).getAst(); - Object result = CEL_RUNTIME.createProgram(ast).eval(); + Object result = cel.createProgram(ast).eval(); assertThat(result).isEqualTo(expectedResult); } @@ -791,7 +724,7 @@ public void floor_success(String expr, double expectedResult) throws Exception { @TestParameters("{expr: 'math.floor(1u)'}") public void floor_invalidArgs_throwsException(String expr) { CelValidationException e = - assertThrows(CelValidationException.class, () -> CEL_COMPILER.compile(expr).getAst()); + assertThrows(CelValidationException.class, () -> cel.compile(expr).getAst()); assertThat(e).hasMessageThat().contains("found no matching overload for 'math.floor'"); } @@ -806,9 +739,9 @@ public void floor_invalidArgs_throwsException(String expr) { @TestParameters("{expr: 'math.round(1.0/0.0)' , expectedResult: Infinity}") @TestParameters("{expr: 'math.round(-1.0/0.0)' , expectedResult: -Infinity}") public void round_success(String expr, double expectedResult) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); + CelAbstractSyntaxTree ast = cel.compile(expr).getAst(); - Object result = CEL_RUNTIME.createProgram(ast).eval(); + Object result = cel.createProgram(ast).eval(); assertThat(result).isEqualTo(expectedResult); } @@ -820,7 +753,7 @@ public void round_success(String expr, double expectedResult) throws Exception { @TestParameters("{expr: 'math.round(1u)'}") public void round_invalidArgs_throwsException(String expr) { CelValidationException e = - assertThrows(CelValidationException.class, () -> CEL_COMPILER.compile(expr).getAst()); + assertThrows(CelValidationException.class, () -> cel.compile(expr).getAst()); assertThat(e).hasMessageThat().contains("found no matching overload for 'math.round'"); } @@ -832,9 +765,9 @@ public void round_invalidArgs_throwsException(String expr) { @TestParameters("{expr: 'math.trunc(1.0/0.0)' , expectedResult: Infinity}") @TestParameters("{expr: 'math.trunc(-1.0/0.0)' , expectedResult: -Infinity}") public void trunc_success(String expr, double expectedResult) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); + CelAbstractSyntaxTree ast = cel.compile(expr).getAst(); - Object result = CEL_RUNTIME.createProgram(ast).eval(); + Object result = cel.createProgram(ast).eval(); assertThat(result).isEqualTo(expectedResult); } @@ -846,7 +779,7 @@ public void trunc_success(String expr, double expectedResult) throws Exception { @TestParameters("{expr: 'math.trunc(1u)'}") public void trunc_invalidArgs_throwsException(String expr) { CelValidationException e = - assertThrows(CelValidationException.class, () -> CEL_COMPILER.compile(expr).getAst()); + assertThrows(CelValidationException.class, () -> cel.compile(expr).getAst()); assertThat(e).hasMessageThat().contains("found no matching overload for 'math.trunc'"); } @@ -856,9 +789,9 @@ public void trunc_invalidArgs_throwsException(String expr) { @TestParameters("{expr: 'math.abs(-1657643)', expectedResult: 1657643}") @TestParameters("{expr: 'math.abs(-2147483648)', expectedResult: 2147483648}") public void abs_intResult_success(String expr, long expectedResult) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); + CelAbstractSyntaxTree ast = cel.compile(expr).getAst(); - Object result = CEL_RUNTIME.createProgram(ast).eval(); + Object result = cel.createProgram(ast).eval(); assertThat(result).isEqualTo(expectedResult); } @@ -871,9 +804,9 @@ public void abs_intResult_success(String expr, long expectedResult) throws Excep @TestParameters("{expr: 'math.abs(1.0/0.0)' , expectedResult: Infinity}") @TestParameters("{expr: 'math.abs(-1.0/0.0)' , expectedResult: Infinity}") public void abs_doubleResult_success(String expr, double expectedResult) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); + CelAbstractSyntaxTree ast = cel.compile(expr).getAst(); - Object result = CEL_RUNTIME.createProgram(ast).eval(); + Object result = cel.createProgram(ast).eval(); assertThat(result).isEqualTo(expectedResult); } @@ -883,7 +816,7 @@ public void abs_overflow_throwsException() { CelValidationException e = assertThrows( CelValidationException.class, - () -> CEL_COMPILER.compile("math.abs(-9223372036854775809)").getAst()); + () -> cel.compile("math.abs(-9223372036854775809)").getAst()); assertThat(e) .hasMessageThat() @@ -896,9 +829,9 @@ public void abs_overflow_throwsException() { @TestParameters("{expr: 'math.sign(-0)', expectedResult: 0}") @TestParameters("{expr: 'math.sign(11213)', expectedResult: 1}") public void sign_intResult_success(String expr, int expectedResult) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); + CelAbstractSyntaxTree ast = cel.compile(expr).getAst(); - Object result = CEL_RUNTIME.createProgram(ast).eval(); + Object result = cel.createProgram(ast).eval(); assertThat(result).isEqualTo(expectedResult); } @@ -914,9 +847,9 @@ public void sign_intResult_success(String expr, int expectedResult) throws Excep @TestParameters("{expr: 'math.sign(1.0/0.0)' , expectedResult: 1.0}") @TestParameters("{expr: 'math.sign(-1.0/0.0)' , expectedResult: -1.0}") public void sign_doubleResult_success(String expr, double expectedResult) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); + CelAbstractSyntaxTree ast = cel.compile(expr).getAst(); - Object result = CEL_RUNTIME.createProgram(ast).eval(); + Object result = cel.createProgram(ast).eval(); assertThat(result).isEqualTo(expectedResult); } @@ -926,7 +859,7 @@ public void sign_doubleResult_success(String expr, double expectedResult) throws @TestParameters("{expr: 'math.sign(\"\")'}") public void sign_invalidArgs_throwsException(String expr) { CelValidationException e = - assertThrows(CelValidationException.class, () -> CEL_COMPILER.compile(expr).getAst()); + assertThrows(CelValidationException.class, () -> cel.compile(expr).getAst()); assertThat(e).hasMessageThat().contains("found no matching overload for 'math.sign'"); } @@ -938,9 +871,9 @@ public void sign_invalidArgs_throwsException(String expr) { "{expr: 'math.bitAnd(9223372036854775807,9223372036854775807)' , expectedResult:" + " 9223372036854775807}") public void bitAnd_signedInt_success(String expr, long expectedResult) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); + CelAbstractSyntaxTree ast = cel.compile(expr).getAst(); - Object result = CEL_RUNTIME.createProgram(ast).eval(); + Object result = cel.createProgram(ast).eval(); assertThat(result).isEqualTo(expectedResult); } @@ -950,9 +883,9 @@ public void bitAnd_signedInt_success(String expr, long expectedResult) throws Ex @TestParameters("{expr: 'math.bitAnd(1u,3u)' , expectedResult: 1}") public void bitAnd_unSignedInt_success(String expr, UnsignedLong expectedResult) throws Exception { - CelAbstractSyntaxTree ast = CEL_UNSIGNED_COMPILER.compile(expr).getAst(); + CelAbstractSyntaxTree ast = cel.compile(expr).getAst(); - Object result = CEL_UNSIGNED_RUNTIME.createProgram(ast).eval(); + Object result = cel.createProgram(ast).eval(); assertThat(result).isEqualTo(expectedResult); } @@ -963,7 +896,7 @@ public void bitAnd_unSignedInt_success(String expr, UnsignedLong expectedResult) @TestParameters("{expr: 'math.bitAnd(1)'}") public void bitAnd_invalidArgs_throwsException(String expr) { CelValidationException e = - assertThrows(CelValidationException.class, () -> CEL_COMPILER.compile(expr).getAst()); + assertThrows(CelValidationException.class, () -> cel.compile(expr).getAst()); assertThat(e).hasMessageThat().contains("found no matching overload for 'math.bitAnd'"); } @@ -973,10 +906,7 @@ public void bitAnd_maxValArg_throwsException() { CelValidationException e = assertThrows( CelValidationException.class, - () -> - CEL_COMPILER - .compile("math.bitAnd(9223372036854775807,9223372036854775809)") - .getAst()); + () -> cel.compile("math.bitAnd(9223372036854775807,9223372036854775809)").getAst()); assertThat(e) .hasMessageThat() @@ -987,9 +917,9 @@ public void bitAnd_maxValArg_throwsException() { @TestParameters("{expr: 'math.bitOr(1,2)' , expectedResult: 3}") @TestParameters("{expr: 'math.bitOr(1,-1)' , expectedResult: -1}") public void bitOr_signedInt_success(String expr, long expectedResult) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); + CelAbstractSyntaxTree ast = cel.compile(expr).getAst(); - Object result = CEL_RUNTIME.createProgram(ast).eval(); + Object result = cel.createProgram(ast).eval(); assertThat(result).isEqualTo(expectedResult); } @@ -998,9 +928,9 @@ public void bitOr_signedInt_success(String expr, long expectedResult) throws Exc @TestParameters("{expr: 'math.bitOr(1u,2u)' , expectedResult: 3}") @TestParameters("{expr: 'math.bitOr(1090u,3u)' , expectedResult: 1091}") public void bitOr_unSignedInt_success(String expr, UnsignedLong expectedResult) throws Exception { - CelAbstractSyntaxTree ast = CEL_UNSIGNED_COMPILER.compile(expr).getAst(); + CelAbstractSyntaxTree ast = cel.compile(expr).getAst(); - Object result = CEL_UNSIGNED_RUNTIME.createProgram(ast).eval(); + Object result = cel.createProgram(ast).eval(); assertThat(result).isEqualTo(expectedResult); } @@ -1011,7 +941,7 @@ public void bitOr_unSignedInt_success(String expr, UnsignedLong expectedResult) @TestParameters("{expr: 'math.bitOr(1)'}") public void bitOr_invalidArgs_throwsException(String expr) { CelValidationException e = - assertThrows(CelValidationException.class, () -> CEL_COMPILER.compile(expr).getAst()); + assertThrows(CelValidationException.class, () -> cel.compile(expr).getAst()); assertThat(e).hasMessageThat().contains("found no matching overload for 'math.bitOr'"); } @@ -1020,9 +950,9 @@ public void bitOr_invalidArgs_throwsException(String expr) { @TestParameters("{expr: 'math.bitXor(1,2)' , expectedResult: 3}") @TestParameters("{expr: 'math.bitXor(3,5)' , expectedResult: 6}") public void bitXor_signedInt_success(String expr, long expectedResult) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); + CelAbstractSyntaxTree ast = cel.compile(expr).getAst(); - Object result = CEL_RUNTIME.createProgram(ast).eval(); + Object result = cel.createProgram(ast).eval(); assertThat(result).isEqualTo(expectedResult); } @@ -1032,9 +962,9 @@ public void bitXor_signedInt_success(String expr, long expectedResult) throws Ex @TestParameters("{expr: 'math.bitXor(3u, 5u)' , expectedResult: 6}") public void bitXor_unSignedInt_success(String expr, UnsignedLong expectedResult) throws Exception { - CelAbstractSyntaxTree ast = CEL_UNSIGNED_COMPILER.compile(expr).getAst(); + CelAbstractSyntaxTree ast = cel.compile(expr).getAst(); - Object result = CEL_UNSIGNED_RUNTIME.createProgram(ast).eval(); + Object result = cel.createProgram(ast).eval(); assertThat(result).isEqualTo(expectedResult); } @@ -1045,7 +975,7 @@ public void bitXor_unSignedInt_success(String expr, UnsignedLong expectedResult) @TestParameters("{expr: 'math.bitXor(1)'}") public void bitXor_invalidArgs_throwsException(String expr) { CelValidationException e = - assertThrows(CelValidationException.class, () -> CEL_COMPILER.compile(expr).getAst()); + assertThrows(CelValidationException.class, () -> cel.compile(expr).getAst()); assertThat(e).hasMessageThat().contains("found no matching overload for 'math.bitXor'"); } @@ -1055,9 +985,9 @@ public void bitXor_invalidArgs_throwsException(String expr) { @TestParameters("{expr: 'math.bitNot(0)' , expectedResult: -1}") @TestParameters("{expr: 'math.bitNot(-1)' , expectedResult: 0}") public void bitNot_signedInt_success(String expr, long expectedResult) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); + CelAbstractSyntaxTree ast = cel.compile(expr).getAst(); - Object result = CEL_RUNTIME.createProgram(ast).eval(); + Object result = cel.createProgram(ast).eval(); assertThat(result).isEqualTo(expectedResult); } @@ -1067,9 +997,9 @@ public void bitNot_signedInt_success(String expr, long expectedResult) throws Ex @TestParameters("{expr: 'math.bitNot(12310u)' , expectedResult: 18446744073709539305}") public void bitNot_unSignedInt_success(String expr, UnsignedLong expectedResult) throws Exception { - CelAbstractSyntaxTree ast = CEL_UNSIGNED_COMPILER.compile(expr).getAst(); + CelAbstractSyntaxTree ast = cel.compile(expr).getAst(); - Object result = CEL_UNSIGNED_RUNTIME.createProgram(ast).eval(); + Object result = cel.createProgram(ast).eval(); assertThat(result).isEqualTo(expectedResult); } @@ -1080,7 +1010,7 @@ public void bitNot_unSignedInt_success(String expr, UnsignedLong expectedResult) @TestParameters("{expr: 'math.bitNot(\"\")'}") public void bitNot_invalidArgs_throwsException(String expr) { CelValidationException e = - assertThrows(CelValidationException.class, () -> CEL_COMPILER.compile(expr).getAst()); + assertThrows(CelValidationException.class, () -> cel.compile(expr).getAst()); assertThat(e).hasMessageThat().contains("found no matching overload for 'math.bitNot'"); } @@ -1090,9 +1020,9 @@ public void bitNot_invalidArgs_throwsException(String expr) { @TestParameters("{expr: 'math.bitShiftLeft(12121, 11)' , expectedResult: 24823808}") @TestParameters("{expr: 'math.bitShiftLeft(-1, 64)' , expectedResult: 0}") public void bitShiftLeft_signedInt_success(String expr, long expectedResult) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); + CelAbstractSyntaxTree ast = cel.compile(expr).getAst(); - Object result = CEL_RUNTIME.createProgram(ast).eval(); + Object result = cel.createProgram(ast).eval(); assertThat(result).isEqualTo(expectedResult); } @@ -1103,9 +1033,9 @@ public void bitShiftLeft_signedInt_success(String expr, long expectedResult) thr @TestParameters("{expr: 'math.bitShiftLeft(1u, 65)' , expectedResult: 0}") public void bitShiftLeft_unSignedInt_success(String expr, UnsignedLong expectedResult) throws Exception { - CelAbstractSyntaxTree ast = CEL_UNSIGNED_COMPILER.compile(expr).getAst(); + CelAbstractSyntaxTree ast = cel.compile(expr).getAst(); - Object result = CEL_UNSIGNED_RUNTIME.createProgram(ast).eval(); + Object result = cel.createProgram(ast).eval(); assertThat(result).isEqualTo(expectedResult); } @@ -1114,11 +1044,10 @@ public void bitShiftLeft_unSignedInt_success(String expr, UnsignedLong expectedR @TestParameters("{expr: 'math.bitShiftLeft(1, -2)'}") @TestParameters("{expr: 'math.bitShiftLeft(1u, -2)'}") public void bitShiftLeft_invalidArgs_throwsException(String expr) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); + CelAbstractSyntaxTree ast = cel.compile(expr).getAst(); CelEvaluationException e = - assertThrows( - CelEvaluationException.class, () -> CEL_UNSIGNED_RUNTIME.createProgram(ast).eval()); + assertThrows(CelEvaluationException.class, () -> cel.createProgram(ast).eval()); assertThat(e).hasMessageThat().contains("evaluation error"); assertThat(e).hasCauseThat().isInstanceOf(IllegalArgumentException.class); @@ -1131,9 +1060,9 @@ public void bitShiftLeft_invalidArgs_throwsException(String expr) throws Excepti @TestParameters("{expr: 'math.bitShiftRight(12121, 11)' , expectedResult: 5}") @TestParameters("{expr: 'math.bitShiftRight(-1, 64)' , expectedResult: 0}") public void bitShiftRight_signedInt_success(String expr, long expectedResult) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); + CelAbstractSyntaxTree ast = cel.compile(expr).getAst(); - Object result = CEL_RUNTIME.createProgram(ast).eval(); + Object result = cel.createProgram(ast).eval(); assertThat(result).isEqualTo(expectedResult); } @@ -1144,9 +1073,7 @@ public void bitShiftRight_signedInt_success(String expr, long expectedResult) th @TestParameters("{expr: 'math.bitShiftRight(1u, 65)' , expectedResult: 0}") public void bitShiftRight_unSignedInt_success(String expr, UnsignedLong expectedResult) throws Exception { - CelAbstractSyntaxTree ast = CEL_UNSIGNED_COMPILER.compile(expr).getAst(); - - Object result = CEL_UNSIGNED_RUNTIME.createProgram(ast).eval(); + Object result = eval(expr); assertThat(result).isEqualTo(expectedResult); } @@ -1155,11 +1082,7 @@ public void bitShiftRight_unSignedInt_success(String expr, UnsignedLong expected @TestParameters("{expr: 'math.bitShiftRight(23111u, -212)'}") @TestParameters("{expr: 'math.bitShiftRight(23, -212)'}") public void bitShiftRight_invalidArgs_throwsException(String expr) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); - - CelEvaluationException e = - assertThrows( - CelEvaluationException.class, () -> CEL_UNSIGNED_RUNTIME.createProgram(ast).eval()); + CelEvaluationException e = assertThrows(CelEvaluationException.class, () -> eval(expr)); assertThat(e).hasMessageThat().contains("evaluation error"); assertThat(e).hasCauseThat().isInstanceOf(IllegalArgumentException.class); @@ -1174,10 +1097,26 @@ public void bitShiftRight_invalidArgs_throwsException(String expr) throws Except @TestParameters("{expr: 'math.sqrt(1.0/0.0)', expectedResult: Infinity}") @TestParameters("{expr: 'math.sqrt(-1)', expectedResult: NaN}") public void sqrt_success(String expr, double expectedResult) throws Exception { - CelAbstractSyntaxTree ast = CEL_UNSIGNED_COMPILER.compile(expr).getAst(); - - Object result = CEL_UNSIGNED_RUNTIME.createProgram(ast).eval(); + Object result = eval(expr); assertThat(result).isEqualTo(expectedResult); } + + private Object eval(Cel cel, String expression, Map variables) throws Exception { + CelAbstractSyntaxTree ast; + if (isParseOnly) { + ast = cel.parse(expression).getAst(); + } else { + ast = cel.compile(expression).getAst(); + } + return cel.createProgram(ast).eval(variables); + } + + private Object eval(Cel celInstance, String expression) throws Exception { + return eval(celInstance, expression, ImmutableMap.of()); + } + + private Object eval(String expression) throws Exception { + return eval(this.cel, expression, ImmutableMap.of()); + } } From e7dff9b38fc6f5bb2665f0cd4fb7e27a14672bc4 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Wed, 22 Apr 2026 14:44:11 -0700 Subject: [PATCH 44/66] Create a base class for extension tests PiperOrigin-RevId: 904054337 --- .../main/java/dev/cel/extensions/BUILD.bazel | 1 + .../test/java/dev/cel/extensions/BUILD.bazel | 1 + .../extensions/CelBindingsExtensionsTest.java | 51 ++------- .../CelComprehensionsExtensionsTest.java | 55 +++------- .../extensions/CelEncoderExtensionsTest.java | 39 ++----- .../cel/extensions/CelExtensionTestBase.java | 66 ++++++++++++ .../extensions/CelListsExtensionsTest.java | 46 +++----- .../extensions/CelProtoExtensionsTest.java | 44 +++----- .../extensions/CelRegexExtensionsTest.java | 36 ++----- .../cel/extensions/CelSetsExtensionsTest.java | 85 +++++---------- .../extensions/CelStringExtensionsTest.java | 102 ++++++------------ 11 files changed, 197 insertions(+), 329 deletions(-) create mode 100644 extensions/src/test/java/dev/cel/extensions/CelExtensionTestBase.java diff --git a/extensions/src/main/java/dev/cel/extensions/BUILD.bazel b/extensions/src/main/java/dev/cel/extensions/BUILD.bazel index f8e4bfc8c..454b2a2fd 100644 --- a/extensions/src/main/java/dev/cel/extensions/BUILD.bazel +++ b/extensions/src/main/java/dev/cel/extensions/BUILD.bazel @@ -122,6 +122,7 @@ java_library( ":extension_library", "//checker:checker_builder", "//common:compiler_common", + "//common:options", "//common/ast", "//common/exceptions:numeric_overflow", "//common/internal:comparison_functions", diff --git a/extensions/src/test/java/dev/cel/extensions/BUILD.bazel b/extensions/src/test/java/dev/cel/extensions/BUILD.bazel index 19fd3657e..eed240317 100644 --- a/extensions/src/test/java/dev/cel/extensions/BUILD.bazel +++ b/extensions/src/test/java/dev/cel/extensions/BUILD.bazel @@ -12,6 +12,7 @@ java_library( "//bundle:cel", "//bundle:cel_experimental_factory", "//common:cel_ast", + "//common:cel_exception", "//common:compiler_common", "//common:container", "//common:options", diff --git a/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java index b87967d0e..00fcad473 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java @@ -23,7 +23,6 @@ import com.google.testing.junit.testparameterinjector.TestParameterInjector; import com.google.testing.junit.testparameterinjector.TestParameters; import dev.cel.bundle.Cel; -import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelOptions; import dev.cel.common.CelOverloadDecl; @@ -36,36 +35,24 @@ import dev.cel.parser.CelStandardMacro; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.CelFunctionBinding; -import dev.cel.testing.CelRuntimeFlavor; import java.util.Arrays; import java.util.List; -import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; -import org.junit.Assume; -import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @RunWith(TestParameterInjector.class) -public final class CelBindingsExtensionsTest { - - @TestParameter public CelRuntimeFlavor runtimeFlavor; - @TestParameter public boolean isParseOnly; - - private Cel cel; - - @Before - public void setUp() { - // Legacy runtime does not support parsed-only evaluation mode. - Assume.assumeFalse(runtimeFlavor.equals(CelRuntimeFlavor.LEGACY) && isParseOnly); - cel = - runtimeFlavor - .builder() - .setOptions(CelOptions.current().enableHeterogeneousNumericComparisons(true).build()) - .setStandardMacros(CelStandardMacro.STANDARD_MACROS) - .addCompilerLibraries(CelOptionalLibrary.INSTANCE, CelExtensions.bindings()) - .addRuntimeLibraries(CelOptionalLibrary.INSTANCE) - .build(); +public final class CelBindingsExtensionsTest extends CelExtensionTestBase { + + @Override + protected Cel newCelEnv() { + return runtimeFlavor + .builder() + .setOptions(CelOptions.current().enableHeterogeneousNumericComparisons(true).build()) + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .addCompilerLibraries(CelOptionalLibrary.INSTANCE, CelExtensions.bindings()) + .addRuntimeLibraries(CelOptionalLibrary.INSTANCE) + .build(); } @Test @@ -331,21 +318,5 @@ public void lazyBinding_boundAttributeInNestedComprehension() throws Exception { assertThat(invocation.get()).isEqualTo(1); } - private Object eval(Cel cel, String expression) throws Exception { - return eval(cel, expression, ImmutableMap.of()); - } - - private Object eval(Cel cel, String expression, Map variables) throws Exception { - CelAbstractSyntaxTree ast; - if (isParseOnly) { - ast = cel.parse(expression).getAst(); - } else { - ast = cel.compile(expression).getAst(); - } - return cel.createProgram(ast).eval(variables); - } - private Object eval(String expression) throws Exception { - return eval(this.cel, expression, ImmutableMap.of()); - } } diff --git a/extensions/src/test/java/dev/cel/extensions/CelComprehensionsExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelComprehensionsExtensionsTest.java index 207178cfe..42dc3e07d 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelComprehensionsExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelComprehensionsExtensionsTest.java @@ -19,7 +19,6 @@ import static org.junit.Assert.assertThrows; import com.google.common.base.Throwables; -import com.google.common.collect.ImmutableMap; import com.google.testing.junit.testparameterinjector.TestParameter; import com.google.testing.junit.testparameterinjector.TestParameterInjector; import com.google.testing.junit.testparameterinjector.TestParameters; @@ -40,15 +39,13 @@ import dev.cel.parser.CelUnparserFactory; import dev.cel.runtime.CelEvaluationException; import dev.cel.testing.CelRuntimeFlavor; -import java.util.Map; import org.junit.Assume; -import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; /** Test for {@link CelExtensions#comprehensions()} */ @RunWith(TestParameterInjector.class) -public class CelComprehensionsExtensionsTest { +public class CelComprehensionsExtensionsTest extends CelExtensionTestBase { private static final CelOptions CEL_OPTIONS = CelOptions.current() @@ -57,29 +54,21 @@ public class CelComprehensionsExtensionsTest { .populateMacroCalls(true) .build(); - @TestParameter public CelRuntimeFlavor runtimeFlavor; - @TestParameter public boolean isParseOnly; - - private Cel cel; - - @Before - public void setUp() { - // Legacy runtime does not support parsed-only evaluation mode. - Assume.assumeFalse(runtimeFlavor.equals(CelRuntimeFlavor.LEGACY) && isParseOnly); - this.cel = - runtimeFlavor - .builder() - .setOptions(CEL_OPTIONS) - .setStandardMacros(CelStandardMacro.STANDARD_MACROS) - .addCompilerLibraries(CelExtensions.comprehensions()) - .addCompilerLibraries(CelExtensions.lists()) - .addCompilerLibraries(CelExtensions.strings()) - .addCompilerLibraries(CelOptionalLibrary.INSTANCE, CelExtensions.bindings()) - .addRuntimeLibraries(CelOptionalLibrary.INSTANCE) - .addRuntimeLibraries(CelExtensions.lists()) - .addRuntimeLibraries(CelExtensions.strings()) - .addRuntimeLibraries(CelExtensions.comprehensions()) - .build(); + @Override + protected Cel newCelEnv() { + return runtimeFlavor + .builder() + .setOptions(CEL_OPTIONS) + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .addCompilerLibraries(CelExtensions.comprehensions()) + .addCompilerLibraries(CelExtensions.lists()) + .addCompilerLibraries(CelExtensions.strings()) + .addCompilerLibraries(CelOptionalLibrary.INSTANCE, CelExtensions.bindings()) + .addRuntimeLibraries(CelOptionalLibrary.INSTANCE) + .addRuntimeLibraries(CelExtensions.lists()) + .addRuntimeLibraries(CelExtensions.strings()) + .addRuntimeLibraries(CelExtensions.comprehensions()) + .build(); } private static final CelUnparser UNPARSER = CelUnparserFactory.newUnparser(); @@ -376,17 +365,5 @@ public void mutableMapValue_select_missingKeyException() throws Exception { assertThat(e).hasCauseThat().hasMessageThat().contains("key 'b' is not present in map."); } - private Object eval(String expression) throws Exception { - return eval(this.cel, expression, ImmutableMap.of()); - } - private Object eval(Cel cel, String expression, Map variables) throws Exception { - CelAbstractSyntaxTree ast; - if (isParseOnly) { - ast = cel.parse(expression).getAst(); - } else { - ast = cel.compile(expression).getAst(); - } - return cel.createProgram(ast).eval(variables); - } } diff --git a/extensions/src/test/java/dev/cel/extensions/CelEncoderExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelEncoderExtensionsTest.java index b0a501ddb..afeaa9105 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelEncoderExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelEncoderExtensionsTest.java @@ -19,44 +19,32 @@ import static org.junit.Assert.assertThrows; import com.google.common.collect.ImmutableMap; -import com.google.testing.junit.testparameterinjector.TestParameter; import com.google.testing.junit.testparameterinjector.TestParameterInjector; import dev.cel.bundle.Cel; -import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelOptions; import dev.cel.common.CelValidationException; import dev.cel.common.types.SimpleType; import dev.cel.common.values.CelByteString; import dev.cel.runtime.CelEvaluationException; -import dev.cel.testing.CelRuntimeFlavor; import org.junit.Assume; -import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @RunWith(TestParameterInjector.class) -public class CelEncoderExtensionsTest { +public class CelEncoderExtensionsTest extends CelExtensionTestBase { private static final CelOptions CEL_OPTIONS = CelOptions.current().enableHeterogeneousNumericComparisons(true).build(); - @TestParameter public CelRuntimeFlavor runtimeFlavor; - @TestParameter public boolean isParseOnly; - - private Cel cel; - - @Before - public void setUp() { - // Legacy runtime does not support parsed-only evaluation mode. - Assume.assumeFalse(runtimeFlavor.equals(CelRuntimeFlavor.LEGACY) && isParseOnly); - this.cel = - runtimeFlavor - .builder() - .setOptions(CEL_OPTIONS) - .addCompilerLibraries(CelExtensions.encoders(CEL_OPTIONS)) - .addRuntimeLibraries(CelExtensions.encoders(CEL_OPTIONS)) - .addVar("stringVar", SimpleType.STRING) - .build(); + @Override + protected Cel newCelEnv() { + return runtimeFlavor + .builder() + .setOptions(CEL_OPTIONS) + .addCompilerLibraries(CelExtensions.encoders(CEL_OPTIONS)) + .addRuntimeLibraries(CelExtensions.encoders(CEL_OPTIONS)) + .addVar("stringVar", SimpleType.STRING) + .build(); } @Test @@ -132,12 +120,5 @@ public void decode_malformedBase64Char_throwsEvaluationException() throws Except assertThat(e).hasCauseThat().hasMessageThat().contains("Illegal base64 character"); } - private Object eval(String expr) throws Exception { - return eval(expr, ImmutableMap.of()); - } - private Object eval(String expr, ImmutableMap vars) throws Exception { - CelAbstractSyntaxTree ast = isParseOnly ? cel.parse(expr).getAst() : cel.compile(expr).getAst(); - return cel.createProgram(ast).eval(vars); - } } diff --git a/extensions/src/test/java/dev/cel/extensions/CelExtensionTestBase.java b/extensions/src/test/java/dev/cel/extensions/CelExtensionTestBase.java new file mode 100644 index 000000000..c80ee38b6 --- /dev/null +++ b/extensions/src/test/java/dev/cel/extensions/CelExtensionTestBase.java @@ -0,0 +1,66 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.extensions; + +import com.google.common.collect.ImmutableMap; +import com.google.testing.junit.testparameterinjector.TestParameter; +import dev.cel.bundle.Cel; +import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.CelException; +import dev.cel.testing.CelRuntimeFlavor; +import java.util.Map; +import org.junit.Assume; +import org.junit.Before; + +/** + * Abstract base class for extension tests to facilitate executing tests with both legacy and + * planner runtime, along with parsed-only and checked expression evaluations for the planner. + */ +abstract class CelExtensionTestBase { + @TestParameter public CelRuntimeFlavor runtimeFlavor; + @TestParameter public boolean isParseOnly; + + @Before + public void setUpBase() { + // Legacy runtime does not support parsed-only evaluation. + Assume.assumeFalse(runtimeFlavor.equals(CelRuntimeFlavor.LEGACY) && isParseOnly); + this.cel = newCelEnv(); + } + + protected Cel cel; + + /** + * Subclasses must implement this to provide a Cel instance configured with the specific + * extensions being tested. + */ + protected abstract Cel newCelEnv(); + + protected Object eval(String expr) throws CelException { + return eval(cel, expr, ImmutableMap.of()); + } + + protected Object eval(String expr, Map variables) throws CelException { + return eval(cel, expr, variables); + } + + protected Object eval(Cel cel, String expr) throws CelException { + return eval(cel, expr, ImmutableMap.of()); + } + + protected Object eval(Cel cel, String expr, Map variables) throws CelException { + CelAbstractSyntaxTree ast = isParseOnly ? cel.parse(expr).getAst() : cel.compile(expr).getAst(); + return cel.createProgram(ast).eval(variables); + } +} diff --git a/extensions/src/test/java/dev/cel/extensions/CelListsExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelListsExtensionsTest.java index f36d90e2d..f5536da4e 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelListsExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelListsExtensionsTest.java @@ -19,12 +19,9 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSortedMultiset; import com.google.common.collect.ImmutableSortedSet; -import com.google.testing.junit.testparameterinjector.TestParameter; import com.google.testing.junit.testparameterinjector.TestParameterInjector; import com.google.testing.junit.testparameterinjector.TestParameters; import dev.cel.bundle.Cel; -import dev.cel.bundle.CelBuilder; -import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelContainer; import dev.cel.common.CelValidationException; import dev.cel.common.CelValidationResult; @@ -32,25 +29,24 @@ import dev.cel.expr.conformance.test.SimpleTest; import dev.cel.parser.CelStandardMacro; import dev.cel.runtime.CelEvaluationException; -import dev.cel.testing.CelRuntimeFlavor; -import java.util.Map; import org.junit.Assume; -import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @RunWith(TestParameterInjector.class) -public class CelListsExtensionsTest { - @TestParameter public CelRuntimeFlavor runtimeFlavor; - @TestParameter public boolean isParseOnly; +public class CelListsExtensionsTest extends CelExtensionTestBase { - private Cel cel; - - @Before - public void setUp() { - // Legacy runtime does not support parsed-only evaluation mode. - Assume.assumeFalse(runtimeFlavor.equals(CelRuntimeFlavor.LEGACY) && isParseOnly); - this.cel = setupEnv(runtimeFlavor.builder()); + @Override + protected Cel newCelEnv() { + return runtimeFlavor + .builder() + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .addCompilerLibraries(CelExtensions.lists()) + .addRuntimeLibraries(CelExtensions.lists()) + .setContainer(CelContainer.ofName("cel.expr.conformance.test")) + .addMessageTypes(SimpleTest.getDescriptor()) + .addVar("non_list", SimpleType.DYN) + .build(); } @Test @@ -322,23 +318,5 @@ public void sortBy_throws_evaluationException(String expression, String expected .contains(expectedError); } - private static Cel setupEnv(CelBuilder celBuilder) { - return celBuilder - .setStandardMacros(CelStandardMacro.STANDARD_MACROS) - .addCompilerLibraries(CelExtensions.lists()) - .addRuntimeLibraries(CelExtensions.lists()) - .setContainer(CelContainer.ofName("cel.expr.conformance.test")) - .addMessageTypes(SimpleTest.getDescriptor()) - .addVar("non_list", SimpleType.DYN) - .build(); - } - - private Object eval(Cel cel, String expr) throws Exception { - return eval(cel, expr, ImmutableMap.of()); - } - private Object eval(Cel cel, String expr, Map vars) throws Exception { - CelAbstractSyntaxTree ast = isParseOnly ? cel.parse(expr).getAst() : cel.compile(expr).getAst(); - return cel.createProgram(ast).eval(vars); - } } diff --git a/extensions/src/test/java/dev/cel/extensions/CelProtoExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelProtoExtensionsTest.java index 2e55619db..f46ea5b1a 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelProtoExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelProtoExtensionsTest.java @@ -26,7 +26,6 @@ import com.google.testing.junit.testparameterinjector.TestParameterInjector; import com.google.testing.junit.testparameterinjector.TestParameters; import dev.cel.bundle.Cel; -import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelContainer; import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelOptions; @@ -41,34 +40,23 @@ import dev.cel.parser.CelMacro; import dev.cel.parser.CelStandardMacro; import dev.cel.runtime.CelFunctionBinding; -import dev.cel.testing.CelRuntimeFlavor; -import java.util.Map; import org.junit.Assume; -import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @RunWith(TestParameterInjector.class) -public final class CelProtoExtensionsTest { - - @TestParameter public CelRuntimeFlavor runtimeFlavor; - @TestParameter public boolean isParseOnly; - - private Cel cel; - - @Before - public void setUp() { - // Legacy runtime does not support parsed-only evaluation mode. - Assume.assumeFalse(runtimeFlavor.equals(CelRuntimeFlavor.LEGACY) && isParseOnly); - this.cel = - runtimeFlavor - .builder() - .addCompilerLibraries(CelExtensions.protos()) - .setStandardMacros(CelStandardMacro.STANDARD_MACROS) - .addFileTypes(TestAllTypesExtensions.getDescriptor()) - .addVar("msg", StructTypeReference.create("cel.expr.conformance.proto2.TestAllTypes")) - .setContainer(CelContainer.ofName("cel.expr.conformance.proto2")) - .build(); +public final class CelProtoExtensionsTest extends CelExtensionTestBase { + + @Override + protected Cel newCelEnv() { + return runtimeFlavor + .builder() + .addCompilerLibraries(CelExtensions.protos()) + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .addFileTypes(TestAllTypesExtensions.getDescriptor()) + .addVar("msg", StructTypeReference.create("cel.expr.conformance.proto2.TestAllTypes")) + .setContainer(CelContainer.ofName("cel.expr.conformance.proto2")) + .build(); } private static final TestAllTypes PACKAGE_SCOPED_EXT_MSG = @@ -342,13 +330,5 @@ public void parseErrors(@TestParameter ParseErrorTestCase testcase) { assertThat(e).hasMessageThat().isEqualTo(testcase.error); } - private Object eval(String expression, Map variables) throws Exception { - return eval(this.cel, expression, variables); - } - private Object eval(Cel cel, String expression, Map variables) throws Exception { - CelAbstractSyntaxTree ast = - this.isParseOnly ? cel.parse(expression).getAst() : cel.compile(expression).getAst(); - return cel.createProgram(ast).eval(variables); - } } diff --git a/extensions/src/test/java/dev/cel/extensions/CelRegexExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelRegexExtensionsTest.java index 924344b25..97d0cc90c 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelRegexExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelRegexExtensionsTest.java @@ -21,35 +21,23 @@ import com.google.testing.junit.testparameterinjector.TestParameterInjector; import com.google.testing.junit.testparameterinjector.TestParameters; import dev.cel.bundle.Cel; -import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelOptions; import dev.cel.runtime.CelEvaluationException; -import dev.cel.testing.CelRuntimeFlavor; import java.util.Optional; -import org.junit.Assume; -import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @RunWith(TestParameterInjector.class) -public final class CelRegexExtensionsTest { - - @TestParameter public CelRuntimeFlavor runtimeFlavor; - @TestParameter public boolean isParseOnly; - - private Cel cel; - - @Before - public void setUp() { - // Legacy runtime does not support parsed-only evaluation mode. - Assume.assumeFalse(runtimeFlavor.equals(CelRuntimeFlavor.LEGACY) && isParseOnly); - this.cel = - runtimeFlavor - .builder() - .addCompilerLibraries(CelExtensions.regex()) - .addRuntimeLibraries(CelExtensions.regex()) - .build(); +public final class CelRegexExtensionsTest extends CelExtensionTestBase { + + @Override + protected Cel newCelEnv() { + return runtimeFlavor + .builder() + .addCompilerLibraries(CelExtensions.regex()) + .addRuntimeLibraries(CelExtensions.regex()) + .build(); } @@ -276,9 +264,5 @@ public void extractAll_multipleCaptureGroups_throwsException(String target, Stri .contains("Regular expression has more than one capturing group:"); } - private Object eval(String expr) throws Exception { - CelAbstractSyntaxTree ast = - isParseOnly ? cel.parse(expr).getAst() : cel.compile(expr).getAst(); - return cel.createProgram(ast).eval(); - } + } diff --git a/extensions/src/test/java/dev/cel/extensions/CelSetsExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelSetsExtensionsTest.java index 9007bba2e..091d456f5 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelSetsExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelSetsExtensionsTest.java @@ -23,7 +23,6 @@ import com.google.testing.junit.testparameterinjector.TestParameterInjector; import com.google.testing.junit.testparameterinjector.TestParameters; import dev.cel.bundle.Cel; -import dev.cel.bundle.CelBuilder; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelContainer; import dev.cel.common.CelFunctionDecl; @@ -39,27 +38,39 @@ import dev.cel.runtime.CelRuntime; import dev.cel.testing.CelRuntimeFlavor; import java.util.List; -import java.util.Map; import org.junit.Assume; -import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @RunWith(TestParameterInjector.class) -public final class CelSetsExtensionsTest { +public final class CelSetsExtensionsTest extends CelExtensionTestBase { private static final CelOptions CEL_OPTIONS = CelOptions.current().enableHeterogeneousNumericComparisons(true).build(); - @TestParameter public CelRuntimeFlavor runtimeFlavor; - @TestParameter public boolean isParseOnly; - - private Cel cel; - - @Before - public void setUp() { - // Legacy runtime does not support parsed-only evaluation mode. - Assume.assumeFalse(runtimeFlavor.equals(CelRuntimeFlavor.LEGACY) && isParseOnly); - this.cel = setupEnv(runtimeFlavor.builder()); + @Override + protected Cel newCelEnv() { + return runtimeFlavor + .builder() + .addMessageTypes(TestAllTypes.getDescriptor()) + .setOptions(CEL_OPTIONS) + .setContainer(CelContainer.ofName("cel.expr.conformance.proto3")) + .addCompilerLibraries(CelExtensions.sets(CEL_OPTIONS)) + .addRuntimeLibraries(CelExtensions.sets(CEL_OPTIONS)) + .addVar("list", ListType.create(SimpleType.INT)) + .addVar("subList", ListType.create(SimpleType.INT)) + .addFunctionDeclarations( + CelFunctionDecl.newFunctionDeclaration( + "new_int", + CelOverloadDecl.newGlobalOverload("new_int_int64", SimpleType.INT, SimpleType.INT))) + .addFunctionBindings( + CelFunctionBinding.fromOverloads( + "new_int", + CelFunctionBinding.from( + "new_int_int64", + Long.class, + // Intentionally return java.lang.Integer to test primitive type adaptation + Math::toIntExact))) + .build(); } @Test @@ -375,7 +386,7 @@ public void setsExtension_containsFunctionSubset_succeeds() throws Exception { .addRuntimeLibraries(setsExtensions) .build(); - Object evaluatedResult = eval(cel, "sets.contains([1, 2], [2])", ImmutableMap.of()); + Object evaluatedResult = eval(cel, "sets.contains([1, 2], [2])"); assertThat(evaluatedResult).isEqualTo(true); } @@ -391,7 +402,7 @@ public void setsExtension_equivalentFunctionSubset_succeeds() throws Exception { .addRuntimeLibraries(setsExtensions) .build(); - Object evaluatedResult = eval(cel, "sets.equivalent([1, 1], [1])", ImmutableMap.of()); + Object evaluatedResult = eval(cel, "sets.equivalent([1, 1], [1])"); assertThat(evaluatedResult).isEqualTo(true); } @@ -407,7 +418,7 @@ public void setsExtension_intersectsFunctionSubset_succeeds() throws Exception { .addRuntimeLibraries(setsExtensions) .build(); - Object evaluatedResult = eval(cel, "sets.intersects([1, 1], [1])", ImmutableMap.of()); + Object evaluatedResult = eval(cel, "sets.intersects([1, 1], [1])"); assertThat(evaluatedResult).isEqualTo(true); } @@ -450,45 +461,5 @@ public void setsExtension_evaluateUnallowedFunction_throws() throws Exception { } } - private Object eval(Cel cel, String expression, Map variables) throws Exception { - CelAbstractSyntaxTree ast; - if (isParseOnly) { - ast = cel.parse(expression).getAst(); - } else { - ast = cel.compile(expression).getAst(); - } - return cel.createProgram(ast).eval(variables); - } - - private Object eval(String expression) throws Exception { - return eval(this.cel, expression, ImmutableMap.of()); - } - - private Object eval(String expression, Map variables) throws Exception { - return eval(this.cel, expression, variables); - } - private static Cel setupEnv(CelBuilder celBuilder) { - return celBuilder - .addMessageTypes(TestAllTypes.getDescriptor()) - .setOptions(CEL_OPTIONS) - .setContainer(CelContainer.ofName("cel.expr.conformance.proto3")) - .addCompilerLibraries(CelExtensions.sets(CEL_OPTIONS)) - .addRuntimeLibraries(CelExtensions.sets(CEL_OPTIONS)) - .addVar("list", ListType.create(SimpleType.INT)) - .addVar("subList", ListType.create(SimpleType.INT)) - .addFunctionDeclarations( - CelFunctionDecl.newFunctionDeclaration( - "new_int", - CelOverloadDecl.newGlobalOverload("new_int_int64", SimpleType.INT, SimpleType.INT))) - .addFunctionBindings( - CelFunctionBinding.fromOverloads( - "new_int", - CelFunctionBinding.from( - "new_int_int64", - Long.class, - // Intentionally return java.lang.Integer to test primitive type adaptation - Math::toIntExact))) - .build(); - } } diff --git a/extensions/src/test/java/dev/cel/extensions/CelStringExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelStringExtensionsTest.java index e7542b7b7..4b242ddcd 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelStringExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelStringExtensionsTest.java @@ -33,40 +33,29 @@ import dev.cel.extensions.CelStringExtensions.Function; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.CelRuntime; -import dev.cel.testing.CelRuntimeFlavor; import java.util.List; -import java.util.Map; import org.junit.Assume; -import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @RunWith(TestParameterInjector.class) -public final class CelStringExtensionsTest { - - @TestParameter public CelRuntimeFlavor runtimeFlavor; - @TestParameter public boolean isParseOnly; - - private Cel cel; - - @Before - public void setUp() { - // Legacy runtime does not support parsed-only evaluation mode. - Assume.assumeFalse(runtimeFlavor.equals(CelRuntimeFlavor.LEGACY) && isParseOnly); - this.cel = - runtimeFlavor - .builder() - .addCompilerLibraries(CelExtensions.strings()) - .addRuntimeLibraries(CelExtensions.strings()) - .addVar("s", SimpleType.STRING) - .addVar("separator", SimpleType.STRING) - .addVar("index", SimpleType.INT) - .addVar("offset", SimpleType.INT) - .addVar("indexOfParam", SimpleType.STRING) - .addVar("beginIndex", SimpleType.INT) - .addVar("endIndex", SimpleType.INT) - .addVar("limit", SimpleType.INT) - .build(); +public final class CelStringExtensionsTest extends CelExtensionTestBase { + + @Override + protected Cel newCelEnv() { + return runtimeFlavor + .builder() + .addCompilerLibraries(CelExtensions.strings()) + .addRuntimeLibraries(CelExtensions.strings()) + .addVar("s", SimpleType.STRING) + .addVar("separator", SimpleType.STRING) + .addVar("index", SimpleType.INT) + .addVar("offset", SimpleType.INT) + .addVar("indexOfParam", SimpleType.STRING) + .addVar("beginIndex", SimpleType.INT) + .addVar("endIndex", SimpleType.INT) + .addVar("limit", SimpleType.INT) + .build(); } @Test @@ -388,13 +377,10 @@ public void split_withLimit_separatorIsNonString_throwsException() { @Test public void split_withLimitOverflow_throwsException() throws Exception { + ImmutableMap variables = ImmutableMap.of("limit", 2147483648L); // INT_MAX + 1 CelEvaluationException exception = assertThrows( - CelEvaluationException.class, - () -> - eval( - "'test'.split('', limit)", - ImmutableMap.of("limit", 2147483648L))); // INT_MAX + 1 + CelEvaluationException.class, () -> eval("'test'.split('', limit)", variables)); assertThat(exception) .hasMessageThat() @@ -454,13 +440,10 @@ public void substring_beginAndEndIndex_unicode_success( @TestParameters("{string: '', beginIndex: 2}") public void substring_beginIndexOutOfRange_ascii_throwsException(String string, int beginIndex) throws Exception { + ImmutableMap variables = ImmutableMap.of("s", string, "beginIndex", beginIndex); CelEvaluationException exception = assertThrows( - CelEvaluationException.class, - () -> - eval( - "s.substring(beginIndex)", - ImmutableMap.of("s", string, "beginIndex", beginIndex))); + CelEvaluationException.class, () -> eval("s.substring(beginIndex)", variables)); String exceptionMessage = String.format( @@ -478,13 +461,10 @@ public void substring_beginIndexOutOfRange_ascii_throwsException(String string, @TestParameters("{string: '😁가나', beginIndex: 4, uniqueCharCount: 3}") public void substring_beginIndexOutOfRange_unicode_throwsException( String string, int beginIndex, int uniqueCharCount) throws Exception { + ImmutableMap variables = ImmutableMap.of("s", string, "beginIndex", beginIndex); CelEvaluationException exception = assertThrows( - CelEvaluationException.class, - () -> - eval( - "s.substring(beginIndex)", - ImmutableMap.of("s", string, "beginIndex", beginIndex))); + CelEvaluationException.class, () -> eval("s.substring(beginIndex)", variables)); String exceptionMessage = String.format( @@ -501,13 +481,12 @@ public void substring_beginIndexOutOfRange_unicode_throwsException( @TestParameters("{string: '😁😑😦', beginIndex: 2, endIndex: 1}") public void substring_beginAndEndIndexOutOfRange_throwsException( String string, int beginIndex, int endIndex) throws Exception { + ImmutableMap variables = + ImmutableMap.of("s", string, "beginIndex", beginIndex, "endIndex", endIndex); CelEvaluationException exception = assertThrows( CelEvaluationException.class, - () -> - eval( - "s.substring(beginIndex, endIndex)", - ImmutableMap.of("s", string, "beginIndex", beginIndex, "endIndex", endIndex))); + () -> eval("s.substring(beginIndex, endIndex)", variables)); String exceptionMessage = String.format("substring failure: Range [%d, %d) out of bounds", beginIndex, endIndex); @@ -516,13 +495,11 @@ public void substring_beginAndEndIndexOutOfRange_throwsException( @Test public void substring_beginIndexOverflow_throwsException() throws Exception { + ImmutableMap variables = + ImmutableMap.of("beginIndex", 2147483648L); // INT_MAX + 1 CelEvaluationException exception = assertThrows( - CelEvaluationException.class, - () -> - eval( - "'abcd'.substring(beginIndex)", - ImmutableMap.of("beginIndex", 2147483648L))); // INT_MAX + 1 + CelEvaluationException.class, () -> eval("'abcd'.substring(beginIndex)", variables)); assertThat(exception) .hasMessageThat() @@ -1381,10 +1358,7 @@ public void stringExtension_functionSubset_success() throws Exception { .build(); Object evaluatedResult = - eval( - customCel, - "'test'.substring(2) == 'st' && 'hello'.charAt(1) == 'e'", - ImmutableMap.of()); + eval(customCel, "'test'.substring(2) == 'st' && 'hello'.charAt(1) == 'e'"); assertThat(evaluatedResult).isEqualTo(true); } @@ -1499,21 +1473,5 @@ public void stringExtension_evaluateUnallowedFunction_throws() throws Exception assertThrows(CelEvaluationException.class, () -> customRuntimeCel.createProgram(ast).eval()); } - private Object eval(Cel cel, String expression, Map variables) throws Exception { - CelAbstractSyntaxTree ast; - if (isParseOnly) { - ast = cel.parse(expression).getAst(); - } else { - ast = cel.compile(expression).getAst(); - } - return cel.createProgram(ast).eval(variables); - } - - private Object eval(String expression) throws Exception { - return eval(this.cel, expression, ImmutableMap.of()); - } - private Object eval(String expression, Map variables) throws Exception { - return eval(this.cel, expression, variables); - } } From 61a01d800b24dfb79bc7035119166d2822f3ff1b Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Wed, 22 Apr 2026 15:44:18 -0700 Subject: [PATCH 45/66] Change function failures (dispatch / overload match) to always include the function name PiperOrigin-RevId: 904088792 --- .../extensions/CelListsExtensionsTest.java | 3 +- .../src/main/java/dev/cel/runtime/BUILD.bazel | 4 +- .../dev/cel/runtime/CelFunctionBinding.java | 1 + .../cel/runtime/CelLateFunctionBindings.java | 5 +++ .../dev/cel/runtime/CelResolvedOverload.java | 26 ++++++++--- .../java/dev/cel/runtime/CelRuntimeImpl.java | 10 +++++ .../dev/cel/runtime/CelRuntimeLegacyImpl.java | 10 +++++ .../dev/cel/runtime/DefaultDispatcher.java | 23 +++++++--- .../dev/cel/runtime/FunctionBindingImpl.java | 35 +++++++++++++-- .../runtime/InternalCelFunctionBinding.java | 29 ++++++++++++ .../java/dev/cel/runtime/LiteRuntimeImpl.java | 15 +++++-- .../dev/cel/runtime/planner/EvalBinary.java | 9 +++- .../dev/cel/runtime/planner/EvalHelpers.java | 19 ++++++-- .../runtime/planner/EvalLateBoundCall.java | 2 +- .../dev/cel/runtime/planner/EvalUnary.java | 8 +++- .../cel/runtime/planner/EvalVarArgsCall.java | 8 +++- .../cel/runtime/planner/EvalZeroArity.java | 16 +++++-- .../cel/runtime/planner/ProgramPlanner.java | 14 ++++-- .../src/test/java/dev/cel/runtime/BUILD.bazel | 1 + .../cel/runtime/CelResolvedOverloadTest.java | 45 ++++++++++++------- .../cel/runtime/DefaultDispatcherTest.java | 12 ++++- .../cel/runtime/DefaultInterpreterTest.java | 10 ++++- .../cel/runtime/PlannerInterpreterTest.java | 15 ++++--- .../runtime/planner/ProgramPlannerTest.java | 18 +++----- .../planner_optional_errors.baseline | 5 +++ 25 files changed, 266 insertions(+), 77 deletions(-) create mode 100644 runtime/src/main/java/dev/cel/runtime/InternalCelFunctionBinding.java create mode 100644 runtime/src/test/resources/planner_optional_errors.baseline diff --git a/extensions/src/test/java/dev/cel/extensions/CelListsExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelListsExtensionsTest.java index f5536da4e..4520f81ba 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelListsExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelListsExtensionsTest.java @@ -29,6 +29,7 @@ import dev.cel.expr.conformance.test.SimpleTest; import dev.cel.parser.CelStandardMacro; import dev.cel.runtime.CelEvaluationException; +import dev.cel.testing.CelRuntimeFlavor; import org.junit.Assume; import org.junit.Test; import org.junit.runner.RunWith; @@ -143,7 +144,7 @@ public void flatten_negativeDepth_throws() { CelEvaluationException e = assertThrows(CelEvaluationException.class, () -> eval(cel, "[1,2,3,4].flatten(-1)")); - if (isParseOnly) { + if (runtimeFlavor.equals(CelRuntimeFlavor.PLANNER)) { assertThat(e) .hasMessageThat() .contains("evaluation error at :17: Function 'flatten' failed"); diff --git a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel index 6f0607de4..ef0ac71d4 100644 --- a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel +++ b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel @@ -53,6 +53,7 @@ LITE_PROGRAM_IMPL_SOURCES = [ FUNCTION_BINDING_SOURCES = [ "CelFunctionBinding.java", "FunctionBindingImpl.java", + "InternalCelFunctionBinding.java", ] # keep sorted @@ -740,6 +741,7 @@ java_library( deps = [ ":evaluation_exception", ":function_overload", + "//common/annotations", "//common/exceptions:overload_not_found", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", @@ -754,6 +756,7 @@ cel_android_library( deps = [ ":evaluation_exception", ":function_overload_android", + "//common/annotations", "//common/exceptions:overload_not_found", "@maven//:com_google_errorprone_error_prone_annotations", "@maven_android//:com_google_guava_guava", @@ -890,7 +893,6 @@ java_library( "//common/types:type_providers", "//common/values:cel_value_provider", "//common/values:proto_message_value_provider", - "//runtime/standard:add", "//runtime/standard:int", "//runtime/standard:timestamp", "@maven//:com_google_code_findbugs_annotations", diff --git a/runtime/src/main/java/dev/cel/runtime/CelFunctionBinding.java b/runtime/src/main/java/dev/cel/runtime/CelFunctionBinding.java index 88be0d3c3..98991d383 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelFunctionBinding.java +++ b/runtime/src/main/java/dev/cel/runtime/CelFunctionBinding.java @@ -100,6 +100,7 @@ static CelFunctionBinding from( overloadId, ImmutableList.copyOf(argTypes), impl, /* isStrict= */ true); } + /** See {@link #fromOverloads(String, Collection)}. */ static ImmutableSet fromOverloads( String functionName, CelFunctionBinding... overloadBindings) { diff --git a/runtime/src/main/java/dev/cel/runtime/CelLateFunctionBindings.java b/runtime/src/main/java/dev/cel/runtime/CelLateFunctionBindings.java index 3d75845cf..2da08120c 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelLateFunctionBindings.java +++ b/runtime/src/main/java/dev/cel/runtime/CelLateFunctionBindings.java @@ -63,7 +63,12 @@ public static CelLateFunctionBindings from(Collection functi } private static CelResolvedOverload createResolvedOverload(CelFunctionBinding binding) { + String functionName = binding.getOverloadId(); + if (binding instanceof InternalCelFunctionBinding) { + functionName = ((InternalCelFunctionBinding) binding).getFunctionName(); + } return CelResolvedOverload.of( + functionName, binding.getOverloadId(), binding.getDefinition(), binding.isStrict(), diff --git a/runtime/src/main/java/dev/cel/runtime/CelResolvedOverload.java b/runtime/src/main/java/dev/cel/runtime/CelResolvedOverload.java index 7063720a1..fbe9a3289 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelResolvedOverload.java +++ b/runtime/src/main/java/dev/cel/runtime/CelResolvedOverload.java @@ -30,6 +30,9 @@ @Internal public abstract class CelResolvedOverload { + /** The base function name. */ + public abstract String getFunctionName(); + /** The overload id of the function. */ public abstract String getOverloadId(); @@ -61,7 +64,7 @@ public Object invoke(Object[] args) throws CelEvaluationException { || CelFunctionOverload.canHandle(args, getParameterTypes(), isStrict())) { return getDefinition().apply(args); } - throw new CelOverloadNotFoundException(getOverloadId()); + throw new CelOverloadNotFoundException(getFunctionName(), ImmutableList.of(getOverloadId())); } public Object invoke(Object arg) throws CelEvaluationException { @@ -69,7 +72,7 @@ public Object invoke(Object arg) throws CelEvaluationException { || CelFunctionOverload.canHandle(arg, getParameterTypes(), isStrict())) { return getOptimizedDefinition().apply(arg); } - throw new CelOverloadNotFoundException(getOverloadId()); + throw new CelOverloadNotFoundException(getFunctionName(), ImmutableList.of(getOverloadId())); } public Object invoke(Object arg1, Object arg2) throws CelEvaluationException { @@ -77,24 +80,28 @@ public Object invoke(Object arg1, Object arg2) throws CelEvaluationException { || CelFunctionOverload.canHandle(arg1, arg2, getParameterTypes(), isStrict())) { return getOptimizedDefinition().apply(arg1, arg2); } - throw new CelOverloadNotFoundException(getOverloadId()); + throw new CelOverloadNotFoundException(getFunctionName(), ImmutableList.of(getOverloadId())); } /** - * Creates a new resolved overload from the given overload id, parameter types, and definition. + * Creates a new resolved overload from the given function name, overload id, parameter types, and + * definition. */ public static CelResolvedOverload of( + String functionName, String overloadId, CelFunctionOverload definition, boolean isStrict, Class... parameterTypes) { - return of(overloadId, definition, isStrict, ImmutableList.copyOf(parameterTypes)); + return of(functionName, overloadId, definition, isStrict, ImmutableList.copyOf(parameterTypes)); } /** - * Creates a new resolved overload from the given overload id, parameter types, and definition. + * Creates a new resolved overload from the given function name, overload id, parameter types, and + * definition. */ public static CelResolvedOverload of( + String functionName, String overloadId, CelFunctionOverload definition, boolean isStrict, @@ -104,7 +111,12 @@ public static CelResolvedOverload of( ? (OptimizedFunctionOverload) definition : definition::apply; return new AutoValue_CelResolvedOverload( - overloadId, ImmutableList.copyOf(parameterTypes), isStrict, definition, optimizedDef); + functionName, + overloadId, + ImmutableList.copyOf(parameterTypes), + isStrict, + definition, + optimizedDef); } /** diff --git a/runtime/src/main/java/dev/cel/runtime/CelRuntimeImpl.java b/runtime/src/main/java/dev/cel/runtime/CelRuntimeImpl.java index cab2c666e..43b223fa0 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelRuntimeImpl.java +++ b/runtime/src/main/java/dev/cel/runtime/CelRuntimeImpl.java @@ -381,7 +381,12 @@ private static DefaultDispatcher newDispatcher( DefaultDispatcher.Builder builder = DefaultDispatcher.newBuilder(); for (CelFunctionBinding binding : standardFunctions.newFunctionBindings(runtimeEquality, options)) { + String functionName = binding.getOverloadId(); + if (binding instanceof InternalCelFunctionBinding) { + functionName = ((InternalCelFunctionBinding) binding).getFunctionName(); + } builder.addOverload( + functionName, binding.getOverloadId(), binding.getArgTypes(), binding.isStrict(), @@ -389,7 +394,12 @@ private static DefaultDispatcher newDispatcher( } for (CelFunctionBinding binding : customFunctionBindings) { + String functionName = binding.getOverloadId(); + if (binding instanceof InternalCelFunctionBinding) { + functionName = ((InternalCelFunctionBinding) binding).getFunctionName(); + } builder.addOverload( + functionName, binding.getOverloadId(), binding.getArgTypes(), binding.isStrict(), diff --git a/runtime/src/main/java/dev/cel/runtime/CelRuntimeLegacyImpl.java b/runtime/src/main/java/dev/cel/runtime/CelRuntimeLegacyImpl.java index 8ae4a9e3e..33702b2c6 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelRuntimeLegacyImpl.java +++ b/runtime/src/main/java/dev/cel/runtime/CelRuntimeLegacyImpl.java @@ -305,7 +305,12 @@ public CelRuntimeLegacyImpl build() { DefaultDispatcher.Builder dispatcherBuilder = DefaultDispatcher.newBuilder(); for (CelFunctionBinding standardFunctionBinding : newStandardFunctionBindings(runtimeEquality)) { + String functionName = standardFunctionBinding.getOverloadId(); + if (standardFunctionBinding instanceof InternalCelFunctionBinding) { + functionName = ((InternalCelFunctionBinding) standardFunctionBinding).getFunctionName(); + } dispatcherBuilder.addOverload( + functionName, standardFunctionBinding.getOverloadId(), standardFunctionBinding.getArgTypes(), standardFunctionBinding.isStrict(), @@ -313,7 +318,12 @@ public CelRuntimeLegacyImpl build() { } for (CelFunctionBinding customBinding : customFunctionBindings.values()) { + String functionName = customBinding.getOverloadId(); + if (customBinding instanceof InternalCelFunctionBinding) { + functionName = ((InternalCelFunctionBinding) customBinding).getFunctionName(); + } dispatcherBuilder.addOverload( + functionName, customBinding.getOverloadId(), customBinding.getArgTypes(), customBinding.isStrict(), diff --git a/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java b/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java index d6ddf3965..0a467db81 100644 --- a/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java +++ b/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java @@ -134,6 +134,8 @@ public static class Builder { @AutoValue @Immutable abstract static class OverloadEntry { + abstract String functionName(); + abstract ImmutableList> argTypes(); abstract boolean isStrict(); @@ -141,8 +143,12 @@ abstract static class OverloadEntry { abstract CelFunctionOverload overload(); private static OverloadEntry of( - ImmutableList> argTypes, boolean isStrict, CelFunctionOverload overload) { - return new AutoValue_DefaultDispatcher_Builder_OverloadEntry(argTypes, isStrict, overload); + String functionName, + ImmutableList> argTypes, + boolean isStrict, + CelFunctionOverload overload) { + return new AutoValue_DefaultDispatcher_Builder_OverloadEntry( + functionName, argTypes, isStrict, overload); } } @@ -150,16 +156,19 @@ private static OverloadEntry of( @CanIgnoreReturnValue public Builder addOverload( + String functionName, String overloadId, ImmutableList> argTypes, boolean isStrict, CelFunctionOverload overload) { + checkNotNull(functionName); + checkArgument(!functionName.isEmpty(), "Function name cannot be empty."); checkNotNull(overloadId); checkArgument(!overloadId.isEmpty(), "Overload ID cannot be empty."); checkNotNull(argTypes); checkNotNull(overload); - OverloadEntry newEntry = OverloadEntry.of(argTypes, isStrict, overload); + OverloadEntry newEntry = OverloadEntry.of(functionName, argTypes, isStrict, overload); overloads.merge( overloadId, @@ -188,7 +197,7 @@ private OverloadEntry mergeDynamicDispatchesOrThrow( boolean isStrict = mergedOverload.getOverloadBindings().stream().allMatch(CelFunctionBinding::isStrict); - return OverloadEntry.of(incoming.argTypes(), isStrict, mergedOverload); + return OverloadEntry.of(overloadId, incoming.argTypes(), isStrict, mergedOverload); } throw new IllegalArgumentException("Duplicate overload ID binding: " + overloadId); @@ -204,7 +213,11 @@ public DefaultDispatcher build() { resolvedOverloads.put( overloadId, CelResolvedOverload.of( - overloadId, overloadImpl, overloadEntry.isStrict(), overloadEntry.argTypes())); + overloadEntry.functionName(), + overloadId, + overloadImpl, + overloadEntry.isStrict(), + overloadEntry.argTypes())); } return new DefaultDispatcher(resolvedOverloads.buildOrThrow()); diff --git a/runtime/src/main/java/dev/cel/runtime/FunctionBindingImpl.java b/runtime/src/main/java/dev/cel/runtime/FunctionBindingImpl.java index c1306ce19..7b8efe8fd 100644 --- a/runtime/src/main/java/dev/cel/runtime/FunctionBindingImpl.java +++ b/runtime/src/main/java/dev/cel/runtime/FunctionBindingImpl.java @@ -23,7 +23,9 @@ import dev.cel.common.exceptions.CelOverloadNotFoundException; @Immutable -final class FunctionBindingImpl implements CelFunctionBinding { +final class FunctionBindingImpl implements InternalCelFunctionBinding { + + private final String functionName; private final String overloadId; @@ -33,6 +35,11 @@ final class FunctionBindingImpl implements CelFunctionBinding { private final boolean isStrict; + @Override + public String getFunctionName() { + return functionName; + } + @Override public String getOverloadId() { return overloadId; @@ -54,20 +61,34 @@ public boolean isStrict() { } FunctionBindingImpl( + String functionName, String overloadId, ImmutableList> argTypes, CelFunctionOverload definition, boolean isStrict) { + this.functionName = functionName; this.overloadId = overloadId; this.argTypes = argTypes; this.definition = definition; this.isStrict = isStrict; } + FunctionBindingImpl( + String overloadId, + ImmutableList> argTypes, + CelFunctionOverload definition, + boolean isStrict) { + this(overloadId, overloadId, argTypes, definition, isStrict); + } + static ImmutableSet groupOverloadsToFunction( String functionName, ImmutableSet overloadBindings) { ImmutableSet.Builder builder = ImmutableSet.builder(); - builder.addAll(overloadBindings); + for (CelFunctionBinding b : overloadBindings) { + builder.add( + new FunctionBindingImpl( + functionName, b.getOverloadId(), b.getArgTypes(), b.getDefinition(), b.isStrict())); + } // If there is already a binding with the same name as the function, we treat it as a // "Singleton" binding and do not create a dynamic dispatch wrapper for it. @@ -80,11 +101,12 @@ static ImmutableSet groupOverloadsToFunction( CelFunctionBinding singleBinding = Iterables.getOnlyElement(overloadBindings); builder.add( new FunctionBindingImpl( + functionName, functionName, singleBinding.getArgTypes(), singleBinding.getDefinition(), singleBinding.isStrict())); - } else { + } else if (overloadBindings.size() > 1) { builder.add(new DynamicDispatchBinding(functionName, overloadBindings)); } } @@ -93,7 +115,7 @@ static ImmutableSet groupOverloadsToFunction( } @Immutable - static final class DynamicDispatchBinding implements CelFunctionBinding { + static final class DynamicDispatchBinding implements InternalCelFunctionBinding { private final boolean isStrict; private final DynamicDispatchOverload dynamicDispatchOverload; @@ -103,6 +125,11 @@ public String getOverloadId() { return dynamicDispatchOverload.functionName; } + @Override + public String getFunctionName() { + return dynamicDispatchOverload.functionName; + } + @Override public ImmutableList> getArgTypes() { return ImmutableList.of(); diff --git a/runtime/src/main/java/dev/cel/runtime/InternalCelFunctionBinding.java b/runtime/src/main/java/dev/cel/runtime/InternalCelFunctionBinding.java new file mode 100644 index 000000000..48a0f36d1 --- /dev/null +++ b/runtime/src/main/java/dev/cel/runtime/InternalCelFunctionBinding.java @@ -0,0 +1,29 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.runtime; + +import com.google.errorprone.annotations.Immutable; +import dev.cel.common.annotations.Internal; + +/** + * Internal interface to expose the function name associated with a binding. + * + *

CEL Library Internals. Do Not Use. + */ +@Internal +@Immutable +public interface InternalCelFunctionBinding extends CelFunctionBinding { + String getFunctionName(); +} diff --git a/runtime/src/main/java/dev/cel/runtime/LiteRuntimeImpl.java b/runtime/src/main/java/dev/cel/runtime/LiteRuntimeImpl.java index 0e5c5cf30..d58eb3be4 100644 --- a/runtime/src/main/java/dev/cel/runtime/LiteRuntimeImpl.java +++ b/runtime/src/main/java/dev/cel/runtime/LiteRuntimeImpl.java @@ -162,9 +162,18 @@ public CelLiteRuntime build() { functionBindingsBuilder .buildOrThrow() .forEach( - (String overloadId, CelFunctionBinding func) -> - dispatcherBuilder.addOverload( - overloadId, func.getArgTypes(), func.isStrict(), func.getDefinition())); + (String overloadId, CelFunctionBinding func) -> { + String functionName = func.getOverloadId(); + if (func instanceof InternalCelFunctionBinding) { + functionName = ((InternalCelFunctionBinding) func).getFunctionName(); + } + dispatcherBuilder.addOverload( + functionName, + overloadId, + func.getArgTypes(), + func.isStrict(), + func.getDefinition()); + }); Interpreter interpreter = new DefaultInterpreter( diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalBinary.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalBinary.java index 7771da3e6..16eba3cce 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalBinary.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalBinary.java @@ -25,6 +25,7 @@ final class EvalBinary extends PlannedInterpretable { + private final String functionName; private final CelResolvedOverload resolvedOverload; private final PlannedInterpretable arg1; private final PlannedInterpretable arg2; @@ -48,25 +49,29 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEval return unknowns; } - return EvalHelpers.dispatch(resolvedOverload, celValueConverter, argVal1, argVal2); + return EvalHelpers.dispatch( + functionName, resolvedOverload, celValueConverter, argVal1, argVal2); } static EvalBinary create( long exprId, + String functionName, CelResolvedOverload resolvedOverload, PlannedInterpretable arg1, PlannedInterpretable arg2, CelValueConverter celValueConverter) { - return new EvalBinary(exprId, resolvedOverload, arg1, arg2, celValueConverter); + return new EvalBinary(exprId, functionName, resolvedOverload, arg1, arg2, celValueConverter); } private EvalBinary( long exprId, + String functionName, CelResolvedOverload resolvedOverload, PlannedInterpretable arg1, PlannedInterpretable arg2, CelValueConverter celValueConverter) { super(exprId); + this.functionName = functionName; this.resolvedOverload = resolvedOverload; this.arg1 = arg1; this.arg2 = arg2; diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalHelpers.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalHelpers.java index a30f91880..220642f4a 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalHelpers.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalHelpers.java @@ -56,7 +56,10 @@ static Object evalStrictly( } static Object dispatch( - CelResolvedOverload overload, CelValueConverter valueConverter, Object[] args) + String functionName, + CelResolvedOverload overload, + CelValueConverter valueConverter, + Object[] args) throws CelEvaluationException { try { Object result = overload.invoke(args); @@ -66,7 +69,11 @@ static Object dispatch( } } - static Object dispatch(CelResolvedOverload overload, CelValueConverter valueConverter, Object arg) + static Object dispatch( + String functionName, + CelResolvedOverload overload, + CelValueConverter valueConverter, + Object arg) throws CelEvaluationException { try { Object result = overload.invoke(arg); @@ -77,7 +84,11 @@ static Object dispatch(CelResolvedOverload overload, CelValueConverter valueConv } static Object dispatch( - CelResolvedOverload overload, CelValueConverter valueConverter, Object arg1, Object arg2) + String functionName, + CelResolvedOverload overload, + CelValueConverter valueConverter, + Object arg1, + Object arg2) throws CelEvaluationException { try { Object result = overload.invoke(arg1, arg2); @@ -97,7 +108,7 @@ private static RuntimeException handleDispatchException( return new IllegalArgumentException( String.format( "Function '%s' failed with arg(s) '%s'", - overload.getOverloadId(), Joiner.on(", ").join(args)), + overload.getFunctionName(), Joiner.on(", ").join(args)), e); } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalLateBoundCall.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalLateBoundCall.java index cdee878ee..0bd251185 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalLateBoundCall.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalLateBoundCall.java @@ -55,7 +55,7 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEval .findOverload(functionName, overloadIds, argVals) .orElseThrow(() -> new CelOverloadNotFoundException(functionName, overloadIds)); - return EvalHelpers.dispatch(resolvedOverload, celValueConverter, argVals); + return EvalHelpers.dispatch(functionName, resolvedOverload, celValueConverter, argVals); } static EvalLateBoundCall create( diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalUnary.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalUnary.java index 322648ee3..57834161f 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalUnary.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalUnary.java @@ -24,6 +24,7 @@ final class EvalUnary extends PlannedInterpretable { + private final String functionName; private final CelResolvedOverload resolvedOverload; private final PlannedInterpretable arg; private final CelValueConverter celValueConverter; @@ -34,23 +35,26 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEval resolvedOverload.isStrict() ? evalStrictly(arg, resolver, frame) : evalNonstrictly(arg, resolver, frame); - return EvalHelpers.dispatch(resolvedOverload, celValueConverter, argVal); + return EvalHelpers.dispatch(functionName, resolvedOverload, celValueConverter, argVal); } static EvalUnary create( long exprId, + String functionName, CelResolvedOverload resolvedOverload, PlannedInterpretable arg, CelValueConverter celValueConverter) { - return new EvalUnary(exprId, resolvedOverload, arg, celValueConverter); + return new EvalUnary(exprId, functionName, resolvedOverload, arg, celValueConverter); } private EvalUnary( long exprId, + String functionName, CelResolvedOverload resolvedOverload, PlannedInterpretable arg, CelValueConverter celValueConverter) { super(exprId); + this.functionName = functionName; this.resolvedOverload = resolvedOverload; this.arg = arg; this.celValueConverter = celValueConverter; diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalVarArgsCall.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalVarArgsCall.java index eb8745632..fe7c6c430 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalVarArgsCall.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalVarArgsCall.java @@ -25,6 +25,7 @@ final class EvalVarArgsCall extends PlannedInterpretable { + private final String functionName; private final CelResolvedOverload resolvedOverload; @SuppressWarnings("Immutable") @@ -50,23 +51,26 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEval return unknowns; } - return EvalHelpers.dispatch(resolvedOverload, celValueConverter, argVals); + return EvalHelpers.dispatch(functionName, resolvedOverload, celValueConverter, argVals); } static EvalVarArgsCall create( long exprId, + String functionName, CelResolvedOverload resolvedOverload, PlannedInterpretable[] args, CelValueConverter celValueConverter) { - return new EvalVarArgsCall(exprId, resolvedOverload, args, celValueConverter); + return new EvalVarArgsCall(exprId, functionName, resolvedOverload, args, celValueConverter); } private EvalVarArgsCall( long exprId, + String functionName, CelResolvedOverload resolvedOverload, PlannedInterpretable[] args, CelValueConverter celValueConverter) { super(exprId); + this.functionName = functionName; this.resolvedOverload = resolvedOverload; this.args = args; this.celValueConverter = celValueConverter; diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalZeroArity.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalZeroArity.java index 5b3138207..7798c8253 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalZeroArity.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalZeroArity.java @@ -22,22 +22,30 @@ final class EvalZeroArity extends PlannedInterpretable { private static final Object[] EMPTY_ARRAY = new Object[0]; + private final String functionName; private final CelResolvedOverload resolvedOverload; private final CelValueConverter celValueConverter; @Override public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { - return EvalHelpers.dispatch(resolvedOverload, celValueConverter, EMPTY_ARRAY); + return EvalHelpers.dispatch(functionName, resolvedOverload, celValueConverter, EMPTY_ARRAY); } static EvalZeroArity create( - long exprId, CelResolvedOverload resolvedOverload, CelValueConverter celValueConverter) { - return new EvalZeroArity(exprId, resolvedOverload, celValueConverter); + long exprId, + String functionName, + CelResolvedOverload resolvedOverload, + CelValueConverter celValueConverter) { + return new EvalZeroArity(exprId, functionName, resolvedOverload, celValueConverter); } private EvalZeroArity( - long exprId, CelResolvedOverload resolvedOverload, CelValueConverter celValueConverter) { + long exprId, + String functionName, + CelResolvedOverload resolvedOverload, + CelValueConverter celValueConverter) { super(exprId); + this.functionName = functionName; this.resolvedOverload = resolvedOverload; this.celValueConverter = celValueConverter; } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java b/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java index 9bd5f3ecd..a0b74fc99 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java @@ -308,15 +308,21 @@ private PlannedInterpretable planCall(CelExpr expr, PlannerContext ctx) { switch (argCount) { case 0: - return EvalZeroArity.create(expr.id(), resolvedOverload, celValueConverter); + return EvalZeroArity.create(expr.id(), functionName, resolvedOverload, celValueConverter); case 1: - return EvalUnary.create(expr.id(), resolvedOverload, evaluatedArgs[0], celValueConverter); + return EvalUnary.create( + expr.id(), functionName, resolvedOverload, evaluatedArgs[0], celValueConverter); case 2: return EvalBinary.create( - expr.id(), resolvedOverload, evaluatedArgs[0], evaluatedArgs[1], celValueConverter); + expr.id(), + functionName, + resolvedOverload, + evaluatedArgs[0], + evaluatedArgs[1], + celValueConverter); default: return EvalVarArgsCall.create( - expr.id(), resolvedOverload, evaluatedArgs, celValueConverter); + expr.id(), functionName, resolvedOverload, evaluatedArgs, celValueConverter); } } diff --git a/runtime/src/test/java/dev/cel/runtime/BUILD.bazel b/runtime/src/test/java/dev/cel/runtime/BUILD.bazel index 577010971..7cd24f040 100644 --- a/runtime/src/test/java/dev/cel/runtime/BUILD.bazel +++ b/runtime/src/test/java/dev/cel/runtime/BUILD.bazel @@ -2,6 +2,7 @@ load("@rules_java//java:defs.bzl", "java_library") load("//:cel_android_rules.bzl", "cel_android_local_test") load("//:testing.bzl", "junit4_test_suites") +# Invalidate cache after file removal package( default_applicable_licenses = ["//:license"], default_testonly = True, diff --git a/runtime/src/test/java/dev/cel/runtime/CelResolvedOverloadTest.java b/runtime/src/test/java/dev/cel/runtime/CelResolvedOverloadTest.java index c1210c1ba..471282117 100644 --- a/runtime/src/test/java/dev/cel/runtime/CelResolvedOverloadTest.java +++ b/runtime/src/test/java/dev/cel/runtime/CelResolvedOverloadTest.java @@ -27,11 +27,13 @@ public final class CelResolvedOverloadTest { CelResolvedOverload getIncrementIntOverload() { return CelResolvedOverload.of( - "increment_int", - (args) -> { - Long arg = (Long) args[0]; - return arg + 1; - }, + /* functionName= */ "increment_int", + /* overloadId= */ "increment_int_overload", + (CelFunctionOverload) + (args) -> { + Long arg = (Long) args[0]; + return arg + 1; + }, /* isStrict= */ true, Long.class); } @@ -45,14 +47,23 @@ public void canHandle_matchingTypes_returnsTrue() { public void canHandle_nullMessageType_returnsFalse() { CelResolvedOverload overload = CelResolvedOverload.of( - "identity", (args) -> args[0], /* isStrict= */ true, TestAllTypes.class); + /* functionName= */ "identity", + /* overloadId= */ "identity_overload", + (CelFunctionOverload) (args) -> args[0], + /* isStrict= */ true, + TestAllTypes.class); assertThat(overload.canHandle(new Object[] {null})).isFalse(); } @Test public void canHandle_nullPrimitive_returnsFalse() { CelResolvedOverload overload = - CelResolvedOverload.of("identity", (args) -> args[0], /* isStrict= */ true, Long.class); + CelResolvedOverload.of( + /* functionName= */ "identity", + /* overloadId= */ "identity_overload", + (CelFunctionOverload) (args) -> args[0], + /* isStrict= */ true, + Long.class); assertThat(overload.canHandle(new Object[] {null})).isFalse(); } @@ -70,10 +81,12 @@ public void canHandle_nonMatchingArgCount_returnsFalse() { public void canHandle_nonStrictOverload_returnsTrue() { CelResolvedOverload nonStrictOverload = CelResolvedOverload.of( - "non_strict", - (args) -> { - return false; - }, + /* functionName= */ "non_strict", + /* overloadId= */ "non_strict_overload", + (CelFunctionOverload) + (args) -> { + return false; + }, /* isStrict= */ false, Long.class, Long.class); @@ -87,10 +100,12 @@ public void canHandle_nonStrictOverload_returnsTrue() { public void canHandle_nonStrictOverload_returnsFalse() { CelResolvedOverload nonStrictOverload = CelResolvedOverload.of( - "non_strict", - (args) -> { - return false; - }, + /* functionName= */ "non_strict", + /* overloadId= */ "non_strict_overload", + (CelFunctionOverload) + (args) -> { + return false; + }, /* isStrict= */ false, Long.class, Long.class); diff --git a/runtime/src/test/java/dev/cel/runtime/DefaultDispatcherTest.java b/runtime/src/test/java/dev/cel/runtime/DefaultDispatcherTest.java index 255360ee1..d862ddb33 100644 --- a/runtime/src/test/java/dev/cel/runtime/DefaultDispatcherTest.java +++ b/runtime/src/test/java/dev/cel/runtime/DefaultDispatcherTest.java @@ -37,11 +37,19 @@ public void setup() { overloads.put( "overload_1", CelResolvedOverload.of( - "overload_1", args -> (Long) args[0] + 1, /* isStrict= */ true, Long.class)); + /* functionName= */ "overload_1", + /* overloadId= */ "overload_1", + args -> (Long) args[0] + 1, + /* isStrict= */ true, + Long.class)); overloads.put( "overload_2", CelResolvedOverload.of( - "overload_2", args -> (Long) args[0] + 2, /* isStrict= */ true, Long.class)); + /* functionName= */ "overload_2", + /* overloadId= */ "overload_2", + args -> (Long) args[0] + 2, + /* isStrict= */ true, + Long.class)); } @Test diff --git a/runtime/src/test/java/dev/cel/runtime/DefaultInterpreterTest.java b/runtime/src/test/java/dev/cel/runtime/DefaultInterpreterTest.java index 1a8f45161..bd0e96856 100644 --- a/runtime/src/test/java/dev/cel/runtime/DefaultInterpreterTest.java +++ b/runtime/src/test/java/dev/cel/runtime/DefaultInterpreterTest.java @@ -77,15 +77,21 @@ public Object adapt(String messageName, Object message) { CelAbstractSyntaxTree ast = celCompiler.compile("[1].all(x, [2].all(y, error()))").getAst(); DefaultDispatcher.Builder dispatcherBuilder = DefaultDispatcher.newBuilder(); dispatcherBuilder.addOverload( - "error", - ImmutableList.of(long.class), + /* functionName= */ "error", + /* overloadId= */ "error_overload", + ImmutableList.>of(long.class), /* isStrict= */ true, (args) -> new IllegalArgumentException("Always throws")); CelFunctionBinding notStrictlyFalseBinding = NotStrictlyFalseOverload.NOT_STRICTLY_FALSE.newFunctionBinding( CelOptions.DEFAULT, RuntimeEquality.create(RuntimeHelpers.create(), CelOptions.DEFAULT)); + String functionName = notStrictlyFalseBinding.getOverloadId(); + if (notStrictlyFalseBinding instanceof InternalCelFunctionBinding) { + functionName = ((InternalCelFunctionBinding) notStrictlyFalseBinding).getFunctionName(); + } dispatcherBuilder.addOverload( + functionName, notStrictlyFalseBinding.getOverloadId(), notStrictlyFalseBinding.getArgTypes(), notStrictlyFalseBinding.isStrict(), diff --git a/runtime/src/test/java/dev/cel/runtime/PlannerInterpreterTest.java b/runtime/src/test/java/dev/cel/runtime/PlannerInterpreterTest.java index 2b0e53298..c0b0f76c4 100644 --- a/runtime/src/test/java/dev/cel/runtime/PlannerInterpreterTest.java +++ b/runtime/src/test/java/dev/cel/runtime/PlannerInterpreterTest.java @@ -82,13 +82,14 @@ protected CelAbstractSyntaxTree prepareTest(CelTypeProvider typeProvider) { @Override public void optional_errors() { - if (isParseOnly) { - // Parsed-only evaluation contains function name in the - // error message instead of the function overload. - skipBaselineVerification(); - } else { - super.optional_errors(); - } + // Exercised in planner_optional_errors instead + skipBaselineVerification(); + } + + @Test + public void planner_optional_errors() { + source = "optional.unwrap([dyn(1)])"; + runTest(ImmutableMap.of()); } @Override diff --git a/runtime/src/test/java/dev/cel/runtime/planner/ProgramPlannerTest.java b/runtime/src/test/java/dev/cel/runtime/planner/ProgramPlannerTest.java index de30902d3..c58ae782b 100644 --- a/runtime/src/test/java/dev/cel/runtime/planner/ProgramPlannerTest.java +++ b/runtime/src/test/java/dev/cel/runtime/planner/ProgramPlannerTest.java @@ -75,6 +75,7 @@ import dev.cel.runtime.CelUnknownSet; import dev.cel.runtime.DefaultDispatcher; import dev.cel.runtime.DescriptorTypeResolver; +import dev.cel.runtime.InternalCelFunctionBinding; import dev.cel.runtime.PartialVars; import dev.cel.runtime.Program; import dev.cel.runtime.RuntimeEquality; @@ -244,6 +245,7 @@ private static void addBindingsToDispatcher( overloadBindings.forEach( overload -> builder.addOverload( + ((InternalCelFunctionBinding) overload).getFunctionName(), overload.getOverloadId(), overload.getArgTypes(), overload.isStrict(), @@ -494,15 +496,11 @@ public void plan_call_zeroArgs() throws Exception { public void plan_call_throws() throws Exception { CelAbstractSyntaxTree ast = compile("error()"); Program program = PLANNER.plan(ast); - String expectedOverloadId = isParseOnly ? "error" : "error_overload"; CelEvaluationException e = assertThrows(CelEvaluationException.class, program::eval); assertThat(e) .hasMessageThat() - .contains( - "evaluation error at :5: Function '" - + expectedOverloadId - + "' failed with arg(s) ''"); + .contains("evaluation error at :5: Function 'error' failed with arg(s) ''"); assertThat(e).hasCauseThat().isInstanceOf(IllegalArgumentException.class); assertThat(e.getCause()).hasMessageThat().contains("Intentional error"); } @@ -562,13 +560,11 @@ public void plan_call_mapIndex() throws Exception { public void plan_call_noMatchingOverload_throws() throws Exception { CelAbstractSyntaxTree ast = compile("concat(b'abc', dyn_var)"); Program program = PLANNER.plan(ast); - String errorMsg; + String errorMsg = + "No matching overload for function 'concat'. Overload candidates: concat_bytes_bytes"; if (isParseOnly) { - errorMsg = - "No matching overload for function 'concat'. Overload candidates: concat_bytes_bytes," - + " bytes_concat_bytes"; - } else { - errorMsg = "No matching overload for function 'concat_bytes_bytes'"; + // Parsed-only evaluation includes both overloads as candidates due to dynamic dispatch + errorMsg += ", bytes_concat_bytes"; } CelEvaluationException e = diff --git a/runtime/src/test/resources/planner_optional_errors.baseline b/runtime/src/test/resources/planner_optional_errors.baseline new file mode 100644 index 000000000..3d59fefca --- /dev/null +++ b/runtime/src/test/resources/planner_optional_errors.baseline @@ -0,0 +1,5 @@ +Source: optional.unwrap([dyn(1)]) +=====> +bindings: {} +error: evaluation error at test_location:15: Function 'optional.unwrap' failed with arg(s) '[1]' +error_code: INTERNAL_ERROR From f2d69d9358ed6a3d04c9c50e931adda426eedb6d Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Wed, 22 Apr 2026 17:07:16 -0700 Subject: [PATCH 46/66] Add planner test coverage for policy compilation PiperOrigin-RevId: 904128325 --- .../main/java/dev/cel/bundle/CelBuilder.java | 8 +++ .../src/main/java/dev/cel/bundle/CelImpl.java | 12 +++++ .../cel/common/values/CelValueConverter.java | 4 +- .../optimizers/ConstantFoldingOptimizer.java | 4 +- .../src/test/java/dev/cel/policy/BUILD.bazel | 2 +- .../cel/policy/CelPolicyCompilerImplTest.java | 53 +++++++++++-------- 6 files changed, 56 insertions(+), 27 deletions(-) diff --git a/bundle/src/main/java/dev/cel/bundle/CelBuilder.java b/bundle/src/main/java/dev/cel/bundle/CelBuilder.java index 1dadaeb39..f603b479f 100644 --- a/bundle/src/main/java/dev/cel/bundle/CelBuilder.java +++ b/bundle/src/main/java/dev/cel/bundle/CelBuilder.java @@ -165,6 +165,14 @@ public interface CelBuilder { @CanIgnoreReturnValue CelBuilder addFunctionBindings(Iterable bindings); + /** Adds bindings for functions that are allowed to be late-bound (resolved at execution time). */ + @CanIgnoreReturnValue + CelBuilder addLateBoundFunctions(String... lateBoundFunctionNames); + + /** Adds bindings for functions that are allowed to be late-bound (resolved at execution time). */ + @CanIgnoreReturnValue + CelBuilder addLateBoundFunctions(Iterable lateBoundFunctionNames); + /** Set the expected {@code resultType} for the type-checked expression. */ @CanIgnoreReturnValue CelBuilder setResultType(CelType resultType); diff --git a/bundle/src/main/java/dev/cel/bundle/CelImpl.java b/bundle/src/main/java/dev/cel/bundle/CelImpl.java index ae0ab2395..f6b985065 100644 --- a/bundle/src/main/java/dev/cel/bundle/CelImpl.java +++ b/bundle/src/main/java/dev/cel/bundle/CelImpl.java @@ -281,6 +281,18 @@ public CelBuilder addFunctionBindings(Iterable lateBoundFunctionNames) { + runtimeBuilder.addLateBoundFunctions(lateBoundFunctionNames); + return this; + } + @Override public CelBuilder setResultType(CelType resultType) { checkNotNull(resultType); diff --git a/common/src/main/java/dev/cel/common/values/CelValueConverter.java b/common/src/main/java/dev/cel/common/values/CelValueConverter.java index 2af0a76cb..70d04acc8 100644 --- a/common/src/main/java/dev/cel/common/values/CelValueConverter.java +++ b/common/src/main/java/dev/cel/common/values/CelValueConverter.java @@ -117,7 +117,7 @@ protected Object normalizePrimitive(Object value) { } /** Adapts a {@link CelValue} to a plain old Java Object. */ - private static Object unwrap(CelValue celValue) { + private Object unwrap(CelValue celValue) { Preconditions.checkNotNull(celValue); if (celValue instanceof OptionalValue) { @@ -126,7 +126,7 @@ private static Object unwrap(CelValue celValue) { return Optional.empty(); } - return Optional.of(optionalValue.value()); + return Optional.of(maybeUnwrap(optionalValue.value())); } if (celValue instanceof ErrorValue) { diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java index c017911f9..8a8786ce8 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java @@ -46,7 +46,6 @@ import dev.cel.optimizer.AstMutator; import dev.cel.optimizer.CelAstOptimizer; import dev.cel.optimizer.CelOptimizationException; -import dev.cel.runtime.CelAttribute.Qualifier; import dev.cel.runtime.CelAttributePattern; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.PartialVars; @@ -683,8 +682,7 @@ private static Object evaluateExpr(Cel cel, CelNavigableMutableExpr navigableMut .allNodes() .filter(node -> node.getKind().equals(Kind.IDENT)) .map(node -> node.expr().ident().name()) - .filter(Qualifier::isLegalIdentifier) - .map(CelAttributePattern::create) + .map(CelAttributePattern::fromQualifiedIdentifier) .collect(toImmutableList()); CelAbstractSyntaxTree ast = CelAbstractSyntaxTree.newParsedAst( diff --git a/policy/src/test/java/dev/cel/policy/BUILD.bazel b/policy/src/test/java/dev/cel/policy/BUILD.bazel index 9106caf70..8a28caee1 100644 --- a/policy/src/test/java/dev/cel/policy/BUILD.bazel +++ b/policy/src/test/java/dev/cel/policy/BUILD.bazel @@ -35,7 +35,7 @@ java_library( "//policy:validation_exception", "//runtime", "//runtime:function_binding", - "//runtime:late_function_binding", + "//testing:cel_runtime_flavor", "//testing/protos:single_file_java_proto", "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto", "@maven//:com_google_guava_guava", diff --git a/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java b/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java index fec5f9b94..d5254571d 100644 --- a/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java +++ b/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java @@ -26,9 +26,9 @@ import com.google.testing.junit.testparameterinjector.TestParameterValue; import com.google.testing.junit.testparameterinjector.TestParameterValuesProvider; import dev.cel.bundle.Cel; +import dev.cel.bundle.CelBuilder; import dev.cel.bundle.CelEnvironment; import dev.cel.bundle.CelEnvironmentYamlParser; -import dev.cel.bundle.CelFactory; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelOptions; import dev.cel.common.types.OptionalType; @@ -45,6 +45,7 @@ import dev.cel.policy.PolicyTestHelper.TestYamlPolicy; import dev.cel.runtime.CelFunctionBinding; import dev.cel.runtime.CelLateFunctionBindings; +import dev.cel.testing.CelRuntimeFlavor; import dev.cel.testing.testdata.SingleFile; import dev.cel.testing.testdata.proto3.StandaloneGlobalEnum; import java.io.IOException; @@ -61,7 +62,12 @@ public final class CelPolicyCompilerImplTest { private static final CelEnvironmentYamlParser ENVIRONMENT_PARSER = CelEnvironmentYamlParser.newInstance(); private static final CelOptions CEL_OPTIONS = - CelOptions.current().populateMacroCalls(true).build(); + CelOptions.current() + .populateMacroCalls(true) + .enableHeterogeneousNumericComparisons(true) + .build(); + + @TestParameter public CelRuntimeFlavor runtimeFlavor; @Test public void compileYamlPolicy_success(@TestParameter TestYamlPolicy yamlPolicy) throws Exception { @@ -258,7 +264,6 @@ public void evaluateYamlPolicy_nestedRuleProducesOptionalOutput() throws Excepti CelPolicy policy = POLICY_PARSER.parse(policySource); CelAbstractSyntaxTree compiledPolicyAst = CelPolicyCompilerFactory.newPolicyCompiler(cel).build().compile(policy); - Optional evalResult = (Optional) cel.createProgram(compiledPolicyAst).eval(); // Result is Optional> @@ -278,7 +283,12 @@ public void evaluateYamlPolicy_lateBoundFunction() throws Exception { + " return:\n" + " type_name: 'string'\n"; CelEnvironment celEnvironment = ENVIRONMENT_PARSER.parse(configSource); - Cel cel = celEnvironment.extend(newCel(), CelOptions.DEFAULT); + CelBuilder celBuilder = newCel().toCelBuilder(); + if (runtimeFlavor == CelRuntimeFlavor.PLANNER) { + celBuilder.addLateBoundFunctions("lateBoundFunc"); + } + Cel cel = celEnvironment.extend(celBuilder.build(), CEL_OPTIONS); + String policySource = "name: late_bound_function_policy\n" + "rule:\n" @@ -298,7 +308,6 @@ public void evaluateYamlPolicy_lateBoundFunction() throws Exception { (String) cel.createProgram(compiledPolicyAst) .eval((unused) -> Optional.empty(), lateFunctionBindings); - assertThat(evalResult).isEqualTo("foo" + exampleValue); } @@ -319,7 +328,6 @@ public void evaluateYamlPolicy_withSimpleVariable() throws Exception { CelAbstractSyntaxTree compiledPolicyAst = CelPolicyCompilerFactory.newPolicyCompiler(cel).build().compile(policy); - boolean evalResult = (boolean) cel.createProgram(compiledPolicyAst).eval(); assertThat(evalResult).isFalse(); @@ -358,8 +366,9 @@ protected ImmutableList provideValues(Context context) throw } } - private static Cel newCel() { - return CelFactory.standardCelBuilder() + private Cel newCel() { + return runtimeFlavor + .builder() .setStandardMacros(CelStandardMacro.STANDARD_MACROS) .addCompilerLibraries(CelOptionalLibrary.INSTANCE) .addRuntimeLibraries(CelOptionalLibrary.INSTANCE) @@ -367,19 +376,21 @@ private static Cel newCel() { .addMessageTypes(TestAllTypes.getDescriptor(), SingleFile.getDescriptor()) .setOptions(CEL_OPTIONS) .addFunctionBindings( - CelFunctionBinding.from( - "locationCode_string", - String.class, - (ip) -> { - switch (ip) { - case "10.0.0.1": - return "us"; - case "10.0.0.2": - return "de"; - default: - return "ir"; - } - })) + CelFunctionBinding.fromOverloads( + "locationCode", + CelFunctionBinding.from( + "locationCode_string", + String.class, + (ip) -> { + switch (ip) { + case "10.0.0.1": + return "us"; + case "10.0.0.2": + return "de"; + default: + return "ir"; + } + }))) .build(); } From 1cd28068e3f4fe6fcc881fb8007ee3fb2796b18e Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Wed, 22 Apr 2026 17:34:49 -0700 Subject: [PATCH 47/66] Implement trace for planner PiperOrigin-RevId: 904138053 --- .../main/java/dev/cel/extensions/BUILD.bazel | 1 - .../cel/extensions/CelExtensionTestBase.java | 4 +- .../cel/extensions/CelMathExtensionsTest.java | 4 +- .../cel/policy/CelPolicyCompilerImplTest.java | 2 +- runtime/planner/BUILD.bazel | 6 + .../src/main/java/dev/cel/runtime/BUILD.bazel | 5 + .../java/dev/cel/runtime/CelRuntimeImpl.java | 56 ++++++- .../java/dev/cel/runtime/planner/BUILD.bazel | 30 +++- .../cel/runtime/planner/BlockMemoizer.java | 2 +- .../java/dev/cel/runtime/planner/EvalAnd.java | 13 +- .../cel/runtime/planner/EvalAttribute.java | 19 ++- .../dev/cel/runtime/planner/EvalBinary.java | 11 +- .../dev/cel/runtime/planner/EvalBlock.java | 21 +-- .../cel/runtime/planner/EvalConditional.java | 11 +- .../dev/cel/runtime/planner/EvalConstant.java | 11 +- .../cel/runtime/planner/EvalCreateList.java | 12 +- .../cel/runtime/planner/EvalCreateMap.java | 15 +- .../cel/runtime/planner/EvalCreateStruct.java | 12 +- .../dev/cel/runtime/planner/EvalFold.java | 11 +- .../dev/cel/runtime/planner/EvalHelpers.java | 6 +- .../runtime/planner/EvalLateBoundCall.java | 11 +- .../cel/runtime/planner/EvalOptionalOr.java | 11 +- .../runtime/planner/EvalOptionalOrValue.java | 11 +- .../planner/EvalOptionalSelectField.java | 11 +- .../java/dev/cel/runtime/planner/EvalOr.java | 13 +- .../dev/cel/runtime/planner/EvalTestOnly.java | 15 +- .../dev/cel/runtime/planner/EvalUnary.java | 11 +- .../cel/runtime/planner/EvalVarArgsCall.java | 11 +- .../cel/runtime/planner/EvalZeroArity.java | 11 +- .../cel/runtime/planner/ExecutionFrame.java | 21 ++- .../planner/InterpretableAttribute.java | 7 +- .../runtime/planner/PlannedInterpretable.java | 24 ++- .../cel/runtime/planner/PlannedProgram.java | 68 ++++++-- .../cel/runtime/planner/ProgramPlanner.java | 88 +++++----- .../src/test/java/dev/cel/runtime/BUILD.bazel | 1 + .../java/dev/cel/runtime/CelRuntimeTest.java | 155 ++++++++++++++---- 36 files changed, 483 insertions(+), 238 deletions(-) diff --git a/extensions/src/main/java/dev/cel/extensions/BUILD.bazel b/extensions/src/main/java/dev/cel/extensions/BUILD.bazel index 454b2a2fd..f8e4bfc8c 100644 --- a/extensions/src/main/java/dev/cel/extensions/BUILD.bazel +++ b/extensions/src/main/java/dev/cel/extensions/BUILD.bazel @@ -122,7 +122,6 @@ java_library( ":extension_library", "//checker:checker_builder", "//common:compiler_common", - "//common:options", "//common/ast", "//common/exceptions:numeric_overflow", "//common/internal:comparison_functions", diff --git a/extensions/src/test/java/dev/cel/extensions/CelExtensionTestBase.java b/extensions/src/test/java/dev/cel/extensions/CelExtensionTestBase.java index c80ee38b6..3a509b003 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelExtensionTestBase.java +++ b/extensions/src/test/java/dev/cel/extensions/CelExtensionTestBase.java @@ -29,8 +29,8 @@ * planner runtime, along with parsed-only and checked expression evaluations for the planner. */ abstract class CelExtensionTestBase { - @TestParameter public CelRuntimeFlavor runtimeFlavor; - @TestParameter public boolean isParseOnly; + @TestParameter CelRuntimeFlavor runtimeFlavor; + @TestParameter boolean isParseOnly; @Before public void setUpBase() { diff --git a/extensions/src/test/java/dev/cel/extensions/CelMathExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelMathExtensionsTest.java index 383e50aa2..16d5c4c83 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelMathExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelMathExtensionsTest.java @@ -46,8 +46,8 @@ @RunWith(TestParameterInjector.class) public class CelMathExtensionsTest { - @TestParameter public CelRuntimeFlavor runtimeFlavor; - @TestParameter public boolean isParseOnly; + @TestParameter private CelRuntimeFlavor runtimeFlavor; + @TestParameter private boolean isParseOnly; private Cel cel; diff --git a/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java b/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java index d5254571d..d4ca76324 100644 --- a/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java +++ b/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java @@ -67,7 +67,7 @@ public final class CelPolicyCompilerImplTest { .enableHeterogeneousNumericComparisons(true) .build(); - @TestParameter public CelRuntimeFlavor runtimeFlavor; + @TestParameter private CelRuntimeFlavor runtimeFlavor; @Test public void compileYamlPolicy_success(@TestParameter TestYamlPolicy yamlPolicy) throws Exception { diff --git a/runtime/planner/BUILD.bazel b/runtime/planner/BUILD.bazel index 8da29f270..9b5dbee6a 100644 --- a/runtime/planner/BUILD.bazel +++ b/runtime/planner/BUILD.bazel @@ -9,3 +9,9 @@ java_library( name = "program_planner", exports = ["//runtime/src/main/java/dev/cel/runtime/planner:program_planner"], ) + +java_library( + name = "planned_program", + visibility = ["//:internal"], + exports = ["//runtime/src/main/java/dev/cel/runtime/planner:planned_program"], +) diff --git a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel index ef0ac71d4..0da68d548 100644 --- a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel +++ b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel @@ -853,6 +853,11 @@ java_library( "//common/values:cel_value_provider", "//common/values:combined_cel_value_provider", "//common/values:proto_message_value_provider", + "//runtime:activation", + "//runtime:interpretable", + "//runtime:proto_message_activation_factory", + "//runtime:resolved_overload", + "//runtime/planner:planned_program", "//runtime/planner:program_planner", "//runtime/standard:type", "@maven//:com_google_errorprone_error_prone_annotations", diff --git a/runtime/src/main/java/dev/cel/runtime/CelRuntimeImpl.java b/runtime/src/main/java/dev/cel/runtime/CelRuntimeImpl.java index 43b223fa0..4cc738b6d 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelRuntimeImpl.java +++ b/runtime/src/main/java/dev/cel/runtime/CelRuntimeImpl.java @@ -45,12 +45,14 @@ import dev.cel.common.values.CelValueProvider; import dev.cel.common.values.CombinedCelValueProvider; import dev.cel.common.values.ProtoMessageValueProvider; +import dev.cel.runtime.planner.PlannedProgram; import dev.cel.runtime.planner.ProgramPlanner; import dev.cel.runtime.standard.TypeFunction; import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.Map; +import java.util.Optional; import java.util.function.Function; import org.jspecify.annotations.Nullable; @@ -98,6 +100,21 @@ public Program createProgram(CelAbstractSyntaxTree ast) throws CelEvaluationExce return toRuntimeProgram(planner().plan(ast)); } + private static final CelFunctionResolver EMPTY_FUNCTION_RESOLVER = + new CelFunctionResolver() { + @Override + public Optional findOverloadMatchingArgs( + String functionName, Collection overloadIds, Object[] args) { + return Optional.empty(); + } + + @Override + public Optional findOverloadMatchingArgs( + String functionName, Object[] args) { + return Optional.empty(); + } + }; + public Program toRuntimeProgram(dev.cel.runtime.Program program) { return new Program() { @@ -119,7 +136,13 @@ public Object eval(Map mapValue, CelFunctionResolver lateBoundFunctio @Override public Object eval(Message message) throws CelEvaluationException { - throw new UnsupportedOperationException("Not yet supported."); + PlannedProgram plannedProgram = (PlannedProgram) program; + return plannedProgram.evalOrThrow( + plannedProgram.interpretable(), + ProtoMessageActivationFactory.fromProto(message, plannedProgram.options()), + EMPTY_FUNCTION_RESOLVER, + /* partialVars= */ null, + /* listener= */ null); } @Override @@ -141,25 +164,38 @@ public Object eval(PartialVars partialVars) throws CelEvaluationException { @Override public Object trace(CelEvaluationListener listener) throws CelEvaluationException { - throw new UnsupportedOperationException("Trace is not yet supported."); + return ((PlannedProgram) program) + .trace(GlobalResolver.EMPTY, EMPTY_FUNCTION_RESOLVER, null, listener); } @Override public Object trace(Map mapValue, CelEvaluationListener listener) throws CelEvaluationException { - throw new UnsupportedOperationException("Trace is not yet supported."); + return ((PlannedProgram) program) + .trace(Activation.copyOf(mapValue), EMPTY_FUNCTION_RESOLVER, null, listener); } @Override public Object trace(Message message, CelEvaluationListener listener) throws CelEvaluationException { - throw new UnsupportedOperationException("Trace is not yet supported."); + PlannedProgram plannedProgram = (PlannedProgram) program; + return plannedProgram.evalOrThrow( + plannedProgram.interpretable(), + ProtoMessageActivationFactory.fromProto(message, plannedProgram.options()), + EMPTY_FUNCTION_RESOLVER, + /* partialVars= */ null, + listener); } @Override public Object trace(CelVariableResolver resolver, CelEvaluationListener listener) throws CelEvaluationException { - throw new UnsupportedOperationException("Trace is not yet supported."); + return ((PlannedProgram) program) + .trace( + (name) -> resolver.find(name).orElse(null), + EMPTY_FUNCTION_RESOLVER, + null, + listener); } @Override @@ -168,7 +204,12 @@ public Object trace( CelFunctionResolver lateBoundFunctionResolver, CelEvaluationListener listener) throws CelEvaluationException { - throw new UnsupportedOperationException("Trace is not yet supported."); + return ((PlannedProgram) program) + .trace( + (name) -> resolver.find(name).orElse(null), + lateBoundFunctionResolver, + null, + listener); } @Override @@ -177,7 +218,8 @@ public Object trace( CelFunctionResolver lateBoundFunctionResolver, CelEvaluationListener listener) throws CelEvaluationException { - throw new UnsupportedOperationException("Trace is not yet supported."); + return ((PlannedProgram) program) + .trace(Activation.copyOf(mapValue), lateBoundFunctionResolver, null, listener); } @Override diff --git a/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel b/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel index 96382b9a9..c13d5857f 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel +++ b/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel @@ -66,17 +66,21 @@ java_library( java_library( name = "planned_program", srcs = ["PlannedProgram.java"], + tags = [ + ], deps = [ ":error_metadata", ":localized_evaluation_exception", ":planned_interpretable", "//:auto_value", "//common:options", + "//common/annotations", "//common/exceptions:runtime_exception", "//common/values", "//runtime:activation", "//runtime:evaluation_exception", "//runtime:evaluation_exception_builder", + "//runtime:evaluation_listener", "//runtime:function_resolver", "//runtime:interpretable", "//runtime:interpreter_util", @@ -85,6 +89,7 @@ java_library( "//runtime:resolved_overload", "//runtime:variable_resolver", "@maven//:com_google_errorprone_error_prone_annotations", + "@maven//:org_jspecify_jspecify", ], ) @@ -93,6 +98,7 @@ java_library( srcs = ["EvalConstant.java"], deps = [ ":planned_interpretable", + "//common/ast", "//runtime:interpretable", "@maven//:com_google_errorprone_error_prone_annotations", ], @@ -104,6 +110,7 @@ java_library( deps = [ ":planned_interpretable", ":qualifier", + "//common/ast", "@maven//:com_google_errorprone_error_prone_annotations", ], ) @@ -183,6 +190,7 @@ java_library( ":interpretable_attribute", ":planned_interpretable", ":qualifier", + "//common/ast", "//runtime:interpretable", "@maven//:com_google_errorprone_error_prone_annotations", ], @@ -196,6 +204,7 @@ java_library( ":planned_interpretable", ":presence_test_qualifier", ":qualifier", + "//common/ast", "//runtime:evaluation_exception", "//runtime:interpretable", "@maven//:com_google_errorprone_error_prone_annotations", @@ -208,6 +217,7 @@ java_library( deps = [ ":eval_helpers", ":planned_interpretable", + "//common/ast", "//common/values", "//runtime:evaluation_exception", "//runtime:interpretable", @@ -221,6 +231,7 @@ java_library( deps = [ ":eval_helpers", ":planned_interpretable", + "//common/ast", "//common/values", "//runtime:evaluation_exception", "//runtime:interpretable", @@ -234,6 +245,7 @@ java_library( deps = [ ":eval_helpers", ":planned_interpretable", + "//common/ast", "//common/values", "//runtime:accumulated_unknowns", "//runtime:evaluation_exception", @@ -248,6 +260,7 @@ java_library( deps = [ ":eval_helpers", ":planned_interpretable", + "//common/ast", "//common/values", "//runtime:accumulated_unknowns", "//runtime:evaluation_exception", @@ -262,6 +275,7 @@ java_library( deps = [ ":eval_helpers", ":planned_interpretable", + "//common/ast", "//common/exceptions:overload_not_found", "//common/values", "//runtime:accumulated_unknowns", @@ -278,6 +292,7 @@ java_library( deps = [ ":eval_helpers", ":planned_interpretable", + "//common/ast", "//common/values", "//runtime:accumulated_unknowns", "//runtime:interpretable", @@ -291,6 +306,7 @@ java_library( deps = [ ":eval_helpers", ":planned_interpretable", + "//common/ast", "//common/values", "//runtime:accumulated_unknowns", "//runtime:interpretable", @@ -303,6 +319,7 @@ java_library( srcs = ["EvalConditional.java"], deps = [ ":planned_interpretable", + "//common/ast", "//runtime:accumulated_unknowns", "//runtime:evaluation_exception", "//runtime:interpretable", @@ -316,11 +333,11 @@ java_library( deps = [ ":eval_helpers", ":planned_interpretable", + "//common/ast", "//common/types:type_providers", "//common/values", "//common/values:cel_value_provider", "//runtime:accumulated_unknowns", - "//runtime:evaluation_exception", "//runtime:interpretable", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", @@ -333,8 +350,8 @@ java_library( deps = [ ":eval_helpers", ":planned_interpretable", + "//common/ast", "//runtime:accumulated_unknowns", - "//runtime:evaluation_exception", "//runtime:interpretable", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", @@ -348,6 +365,7 @@ java_library( ":eval_helpers", ":localized_evaluation_exception", ":planned_interpretable", + "//common/ast", "//common/exceptions:duplicate_key", "//common/exceptions:invalid_argument", "//runtime:accumulated_unknowns", @@ -364,6 +382,7 @@ java_library( deps = [ ":activation_wrapper", ":planned_interpretable", + "//common/ast", "//common/exceptions:runtime_exception", "//common/values:mutable_map_value", "//runtime:accumulated_unknowns", @@ -421,13 +440,16 @@ java_library( deps = [ ":localized_evaluation_exception", "//common:options", + "//common/ast", "//common/exceptions:iteration_budget_exceeded", "//runtime:evaluation_exception", + "//runtime:evaluation_listener", "//runtime:function_resolver", "//runtime:interpretable", "//runtime:partial_vars", "//runtime:resolved_overload", "@maven//:com_google_errorprone_error_prone_annotations", + "@maven//:org_jspecify_jspecify", ], ) @@ -437,6 +459,7 @@ java_library( deps = [ ":eval_helpers", ":planned_interpretable", + "//common/ast", "//common/exceptions:overload_not_found", "//runtime:accumulated_unknowns", "//runtime:interpretable", @@ -451,6 +474,7 @@ java_library( deps = [ ":eval_helpers", ":planned_interpretable", + "//common/ast", "//common/exceptions:overload_not_found", "//runtime:accumulated_unknowns", "//runtime:interpretable", @@ -465,6 +489,7 @@ java_library( deps = [ ":eval_helpers", ":planned_interpretable", + "//common/ast", "//common/values", "//runtime:accumulated_unknowns", "//runtime:interpretable", @@ -478,6 +503,7 @@ java_library( srcs = ["EvalBlock.java"], deps = [ ":planned_interpretable", + "//common/ast", "//runtime:evaluation_exception", "//runtime:interpretable", "@maven//:com_google_errorprone_error_prone_annotations", diff --git a/runtime/src/main/java/dev/cel/runtime/planner/BlockMemoizer.java b/runtime/src/main/java/dev/cel/runtime/planner/BlockMemoizer.java index 978029b3d..80a0a5de0 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/BlockMemoizer.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/BlockMemoizer.java @@ -61,7 +61,7 @@ Object resolveSlot(int idx, GlobalResolver resolver) { return result; } catch (CelEvaluationException e) { LocalizedEvaluationException localizedException = - new LocalizedEvaluationException(e, e.getErrorCode(), slotExprs[idx].exprId()); + new LocalizedEvaluationException(e, e.getErrorCode(), slotExprs[idx].expr().id()); slotVals[idx] = localizedException; throw localizedException; } catch (RuntimeException e) { diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalAnd.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalAnd.java index 91f5b2ff4..11da26a50 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalAnd.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalAnd.java @@ -17,6 +17,7 @@ import static dev.cel.runtime.planner.EvalHelpers.evalNonstrictly; import com.google.common.base.Preconditions; +import dev.cel.common.ast.CelExpr; import dev.cel.common.values.ErrorValue; import dev.cel.runtime.AccumulatedUnknowns; import dev.cel.runtime.GlobalResolver; @@ -27,7 +28,7 @@ final class EvalAnd extends PlannedInterpretable { private final PlannedInterpretable[] args; @Override - public Object eval(GlobalResolver resolver, ExecutionFrame frame) { + Object evalInternal(GlobalResolver resolver, ExecutionFrame frame) { ErrorValue errorValue = null; AccumulatedUnknowns unknowns = null; for (PlannedInterpretable arg : args) { @@ -47,7 +48,7 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) { } else { errorValue = ErrorValue.create( - arg.exprId(), + arg.expr().id(), new IllegalArgumentException( String.format("Expected boolean value, found: %s", argVal))); } @@ -64,12 +65,12 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) { return true; } - static EvalAnd create(long exprId, PlannedInterpretable[] args) { - return new EvalAnd(exprId, args); + static EvalAnd create(CelExpr expr, PlannedInterpretable[] args) { + return new EvalAnd(expr, args); } - private EvalAnd(long exprId, PlannedInterpretable[] args) { - super(exprId); + private EvalAnd(CelExpr expr, PlannedInterpretable[] args) { + super(expr); Preconditions.checkArgument(args.length == 2); this.args = args; } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalAttribute.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalAttribute.java index a0a95c47a..56ea8a832 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalAttribute.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalAttribute.java @@ -15,6 +15,7 @@ package dev.cel.runtime.planner; import com.google.errorprone.annotations.Immutable; +import dev.cel.common.ast.CelExpr; import dev.cel.runtime.GlobalResolver; @Immutable @@ -23,27 +24,27 @@ final class EvalAttribute extends InterpretableAttribute { private final Attribute attr; @Override - public Object eval(GlobalResolver resolver, ExecutionFrame frame) { - Object resolved = attr.resolve(exprId(), resolver, frame); + Object evalInternal(GlobalResolver resolver, ExecutionFrame frame) { + Object resolved = attr.resolve(expr().id(), resolver, frame); if (resolved instanceof MissingAttribute) { - ((MissingAttribute) resolved).resolve(exprId(), resolver, frame); + ((MissingAttribute) resolved).resolve(expr().id(), resolver, frame); } return resolved; } @Override - public EvalAttribute addQualifier(long exprId, Qualifier qualifier) { + public EvalAttribute addQualifier(CelExpr expr, Qualifier qualifier) { Attribute newAttribute = attr.addQualifier(qualifier); - return create(exprId, newAttribute); + return create(expr, newAttribute); } - static EvalAttribute create(long exprId, Attribute attr) { - return new EvalAttribute(exprId, attr); + static EvalAttribute create(CelExpr expr, Attribute attr) { + return new EvalAttribute(expr, attr); } - private EvalAttribute(long exprId, Attribute attr) { - super(exprId); + private EvalAttribute(CelExpr expr, Attribute attr) { + super(expr); this.attr = attr; } } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalBinary.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalBinary.java index 16eba3cce..fcade7789 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalBinary.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalBinary.java @@ -17,6 +17,7 @@ import static dev.cel.runtime.planner.EvalHelpers.evalNonstrictly; import static dev.cel.runtime.planner.EvalHelpers.evalStrictly; +import dev.cel.common.ast.CelExpr; import dev.cel.common.values.CelValueConverter; import dev.cel.runtime.AccumulatedUnknowns; import dev.cel.runtime.CelEvaluationException; @@ -32,7 +33,7 @@ final class EvalBinary extends PlannedInterpretable { private final CelValueConverter celValueConverter; @Override - public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { + Object evalInternal(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { Object argVal1 = resolvedOverload.isStrict() ? evalStrictly(arg1, resolver, frame) @@ -54,23 +55,23 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEval } static EvalBinary create( - long exprId, + CelExpr expr, String functionName, CelResolvedOverload resolvedOverload, PlannedInterpretable arg1, PlannedInterpretable arg2, CelValueConverter celValueConverter) { - return new EvalBinary(exprId, functionName, resolvedOverload, arg1, arg2, celValueConverter); + return new EvalBinary(expr, functionName, resolvedOverload, arg1, arg2, celValueConverter); } private EvalBinary( - long exprId, + CelExpr expr, String functionName, CelResolvedOverload resolvedOverload, PlannedInterpretable arg1, PlannedInterpretable arg2, CelValueConverter celValueConverter) { - super(exprId); + super(expr); this.functionName = functionName; this.resolvedOverload = resolvedOverload; this.arg1 = arg1; diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalBlock.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalBlock.java index 41ad4034e..eed8791d4 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalBlock.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalBlock.java @@ -15,6 +15,7 @@ package dev.cel.runtime.planner; import com.google.errorprone.annotations.Immutable; +import dev.cel.common.ast.CelExpr; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.GlobalResolver; @@ -28,19 +29,19 @@ final class EvalBlock extends PlannedInterpretable { private final PlannedInterpretable resultExpr; static EvalBlock create( - long exprId, PlannedInterpretable[] slotExprs, PlannedInterpretable resultExpr) { - return new EvalBlock(exprId, slotExprs, resultExpr); + CelExpr expr, PlannedInterpretable[] slotExprs, PlannedInterpretable resultExpr) { + return new EvalBlock(expr, slotExprs, resultExpr); } private EvalBlock( - long exprId, PlannedInterpretable[] slotExprs, PlannedInterpretable resultExpr) { - super(exprId); + CelExpr expr, PlannedInterpretable[] slotExprs, PlannedInterpretable resultExpr) { + super(expr); this.slotExprs = slotExprs; this.resultExpr = resultExpr; } @Override - public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { + Object evalInternal(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { BlockMemoizer memoizer = BlockMemoizer.create(slotExprs, frame); frame.setBlockMemoizer(memoizer); return resultExpr.eval(resolver, frame); @@ -50,17 +51,17 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEval static final class EvalBlockSlot extends PlannedInterpretable { private final int slotIndex; - static EvalBlockSlot create(long exprId, int slotIndex) { - return new EvalBlockSlot(exprId, slotIndex); + static EvalBlockSlot create(CelExpr expr, int slotIndex) { + return new EvalBlockSlot(expr, slotIndex); } - private EvalBlockSlot(long exprId, int slotIndex) { - super(exprId); + private EvalBlockSlot(CelExpr expr, int slotIndex) { + super(expr); this.slotIndex = slotIndex; } @Override - public Object eval(GlobalResolver resolver, ExecutionFrame frame) { + Object evalInternal(GlobalResolver resolver, ExecutionFrame frame) { return frame.getBlockMemoizer().resolveSlot(slotIndex, resolver); } } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalConditional.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalConditional.java index 3be1f016a..c2d730cdf 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalConditional.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalConditional.java @@ -15,6 +15,7 @@ package dev.cel.runtime.planner; import com.google.common.base.Preconditions; +import dev.cel.common.ast.CelExpr; import dev.cel.runtime.AccumulatedUnknowns; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.GlobalResolver; @@ -25,7 +26,7 @@ final class EvalConditional extends PlannedInterpretable { private final PlannedInterpretable[] args; @Override - public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { + Object evalInternal(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { PlannedInterpretable condition = args[0]; PlannedInterpretable truthy = args[1]; PlannedInterpretable falsy = args[2]; @@ -46,12 +47,12 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEval return falsy.eval(resolver, frame); } - static EvalConditional create(long exprId, PlannedInterpretable[] args) { - return new EvalConditional(exprId, args); + static EvalConditional create(CelExpr expr, PlannedInterpretable[] args) { + return new EvalConditional(expr, args); } - private EvalConditional(long exprId, PlannedInterpretable[] args) { - super(exprId); + private EvalConditional(CelExpr expr, PlannedInterpretable[] args) { + super(expr); Preconditions.checkArgument(args.length == 3); this.args = args; } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalConstant.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalConstant.java index 2bebb059b..55554069b 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalConstant.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalConstant.java @@ -15,6 +15,7 @@ package dev.cel.runtime.planner; import com.google.errorprone.annotations.Immutable; +import dev.cel.common.ast.CelExpr; import dev.cel.runtime.GlobalResolver; @Immutable @@ -24,16 +25,16 @@ final class EvalConstant extends PlannedInterpretable { private final Object constant; @Override - public Object eval(GlobalResolver resolver, ExecutionFrame frame) { + Object evalInternal(GlobalResolver resolver, ExecutionFrame frame) { return constant; } - static EvalConstant create(long exprId, Object value) { - return new EvalConstant(exprId, value); + static EvalConstant create(CelExpr expr, Object value) { + return new EvalConstant(expr, value); } - private EvalConstant(long exprId, Object constant) { - super(exprId); + private EvalConstant(CelExpr expr, Object constant) { + super(expr); this.constant = constant; } } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalCreateList.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalCreateList.java index bae1e9302..265da3ab3 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalCreateList.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalCreateList.java @@ -16,8 +16,8 @@ import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.Immutable; +import dev.cel.common.ast.CelExpr; import dev.cel.runtime.AccumulatedUnknowns; -import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.GlobalResolver; import java.util.Optional; @@ -31,7 +31,7 @@ final class EvalCreateList extends PlannedInterpretable { private final boolean[] isOptional; @Override - public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { + Object evalInternal(GlobalResolver resolver, ExecutionFrame frame) { ImmutableList.Builder builder = ImmutableList.builderWithExpectedSize(values.length); AccumulatedUnknowns unknowns = null; for (int i = 0; i < values.length; i++) { @@ -66,12 +66,12 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEval return builder.build(); } - static EvalCreateList create(long exprId, PlannedInterpretable[] values, boolean[] isOptional) { - return new EvalCreateList(exprId, values, isOptional); + static EvalCreateList create(CelExpr expr, PlannedInterpretable[] values, boolean[] isOptional) { + return new EvalCreateList(expr, values, isOptional); } - private EvalCreateList(long exprId, PlannedInterpretable[] values, boolean[] isOptional) { - super(exprId); + private EvalCreateList(CelExpr expr, PlannedInterpretable[] values, boolean[] isOptional) { + super(expr); this.values = values; this.isOptional = isOptional; } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalCreateMap.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalCreateMap.java index 1e1b831bb..8d34c10d0 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalCreateMap.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalCreateMap.java @@ -19,6 +19,7 @@ import com.google.common.collect.Sets; import com.google.common.primitives.UnsignedLong; import com.google.errorprone.annotations.Immutable; +import dev.cel.common.ast.CelExpr; import dev.cel.common.exceptions.CelDuplicateKeyException; import dev.cel.common.exceptions.CelInvalidArgumentException; import dev.cel.runtime.AccumulatedUnknowns; @@ -43,7 +44,7 @@ final class EvalCreateMap extends PlannedInterpretable { private final boolean[] isOptional; @Override - public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { + Object evalInternal(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { ImmutableMap.Builder builder = ImmutableMap.builderWithExpectedSize(keys.length); HashSet keysSeen = Sets.newHashSetWithExpectedSize(keys.length); @@ -62,7 +63,7 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEval || key instanceof Boolean)) { throw new LocalizedEvaluationException( new CelInvalidArgumentException("Unsupported key type: " + key), - keyInterpretable.exprId()); + keyInterpretable.expr().id()); } boolean isDuplicate = !keysSeen.add(key); @@ -80,7 +81,7 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEval if (isDuplicate) { throw new LocalizedEvaluationException( - CelDuplicateKeyException.of(key), keyInterpretable.exprId()); + CelDuplicateKeyException.of(key), keyInterpretable.expr().id()); } } @@ -119,19 +120,19 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEval } static EvalCreateMap create( - long exprId, + CelExpr expr, PlannedInterpretable[] keys, PlannedInterpretable[] values, boolean[] isOptional) { - return new EvalCreateMap(exprId, keys, values, isOptional); + return new EvalCreateMap(expr, keys, values, isOptional); } private EvalCreateMap( - long exprId, + CelExpr expr, PlannedInterpretable[] keys, PlannedInterpretable[] values, boolean[] isOptional) { - super(exprId); + super(expr); Preconditions.checkArgument(keys.length == values.length); Preconditions.checkArgument(keys.length == isOptional.length); this.keys = keys; diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalCreateStruct.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalCreateStruct.java index cdeb0c574..36485d5be 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalCreateStruct.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalCreateStruct.java @@ -16,11 +16,11 @@ import com.google.common.collect.Maps; import com.google.errorprone.annotations.Immutable; +import dev.cel.common.ast.CelExpr; import dev.cel.common.types.CelType; import dev.cel.common.values.CelValueProvider; import dev.cel.common.values.StructValue; import dev.cel.runtime.AccumulatedUnknowns; -import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.GlobalResolver; import java.util.Collections; import java.util.Map; @@ -45,7 +45,7 @@ final class EvalCreateStruct extends PlannedInterpretable { private final boolean[] isOptional; @Override - public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { + Object evalInternal(GlobalResolver resolver, ExecutionFrame frame) { Map fieldValues = Maps.newHashMapWithExpectedSize(keys.length); AccumulatedUnknowns unknowns = null; for (int i = 0; i < keys.length; i++) { @@ -96,23 +96,23 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEval } static EvalCreateStruct create( - long exprId, + CelExpr expr, CelValueProvider valueProvider, CelType structType, String[] keys, PlannedInterpretable[] values, boolean[] isOptional) { - return new EvalCreateStruct(exprId, valueProvider, structType, keys, values, isOptional); + return new EvalCreateStruct(expr, valueProvider, structType, keys, values, isOptional); } private EvalCreateStruct( - long exprId, + CelExpr expr, CelValueProvider valueProvider, CelType structType, String[] keys, PlannedInterpretable[] values, boolean[] isOptional) { - super(exprId); + super(expr); this.valueProvider = valueProvider; this.structType = structType; this.keys = keys; diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalFold.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalFold.java index 090a8bfae..2de52e982 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalFold.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalFold.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.errorprone.annotations.Immutable; +import dev.cel.common.ast.CelExpr; import dev.cel.common.exceptions.CelRuntimeException; import dev.cel.common.values.MutableMapValue; import dev.cel.runtime.AccumulatedUnknowns; @@ -40,7 +41,7 @@ final class EvalFold extends PlannedInterpretable { private final PlannedInterpretable result; static EvalFold create( - long exprId, + CelExpr expr, String accuVar, PlannedInterpretable accuInit, String iterVar, @@ -50,11 +51,11 @@ static EvalFold create( PlannedInterpretable loopStep, PlannedInterpretable result) { return new EvalFold( - exprId, accuVar, accuInit, iterVar, iterVar2, iterRange, loopCondition, loopStep, result); + expr, accuVar, accuInit, iterVar, iterVar2, iterRange, loopCondition, loopStep, result); } private EvalFold( - long exprId, + CelExpr expr, String accuVar, PlannedInterpretable accuInit, String iterVar, @@ -63,7 +64,7 @@ private EvalFold( PlannedInterpretable condition, PlannedInterpretable loopStep, PlannedInterpretable result) { - super(exprId); + super(expr); this.accuVar = accuVar; this.accuInit = accuInit; this.iterVar = iterVar; @@ -75,7 +76,7 @@ private EvalFold( } @Override - public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { + Object evalInternal(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { Object iterRangeRaw = iterRange.eval(resolver, frame); if (iterRangeRaw instanceof AccumulatedUnknowns) { return iterRangeRaw; diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalHelpers.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalHelpers.java index 220642f4a..f9812793e 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalHelpers.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalHelpers.java @@ -34,7 +34,7 @@ static Object evalNonstrictly( // Example: foo [1] && strict_err [2] -> ID 2 is propagated. return ErrorValue.create(e.exprId(), e); } catch (Exception e) { - return ErrorValue.create(interpretable.exprId(), e); + return ErrorValue.create(interpretable.expr().id(), e); } } @@ -47,11 +47,11 @@ static Object evalStrictly( throw e; } catch (CelRuntimeException e) { // Wrap with current interpretable's location - throw new LocalizedEvaluationException(e, interpretable.exprId()); + throw new LocalizedEvaluationException(e, interpretable.expr().id()); } catch (Exception e) { // Wrap generic exceptions with location throw new LocalizedEvaluationException( - e, CelErrorCode.INTERNAL_ERROR, interpretable.exprId()); + e, CelErrorCode.INTERNAL_ERROR, interpretable.expr().id()); } } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalLateBoundCall.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalLateBoundCall.java index 0bd251185..719b4af21 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalLateBoundCall.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalLateBoundCall.java @@ -17,6 +17,7 @@ import static dev.cel.runtime.planner.EvalHelpers.evalStrictly; import com.google.common.collect.ImmutableList; +import dev.cel.common.ast.CelExpr; import dev.cel.common.exceptions.CelOverloadNotFoundException; import dev.cel.common.values.CelValueConverter; import dev.cel.runtime.AccumulatedUnknowns; @@ -35,7 +36,7 @@ final class EvalLateBoundCall extends PlannedInterpretable { private final CelValueConverter celValueConverter; @Override - public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { + Object evalInternal(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { Object[] argVals = new Object[args.length]; AccumulatedUnknowns unknowns = null; for (int i = 0; i < args.length; i++) { @@ -59,21 +60,21 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEval } static EvalLateBoundCall create( - long exprId, + CelExpr expr, String functionName, ImmutableList overloadIds, PlannedInterpretable[] args, CelValueConverter celValueConverter) { - return new EvalLateBoundCall(exprId, functionName, overloadIds, args, celValueConverter); + return new EvalLateBoundCall(expr, functionName, overloadIds, args, celValueConverter); } private EvalLateBoundCall( - long exprId, + CelExpr expr, String functionName, ImmutableList overloadIds, PlannedInterpretable[] args, CelValueConverter celValueConverter) { - super(exprId); + super(expr); this.functionName = functionName; this.overloadIds = overloadIds; this.args = args; diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalOptionalOr.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalOptionalOr.java index 5ad1933d7..37e3a8ccb 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalOptionalOr.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalOptionalOr.java @@ -16,6 +16,7 @@ import com.google.common.base.Preconditions; import com.google.errorprone.annotations.Immutable; +import dev.cel.common.ast.CelExpr; import dev.cel.common.exceptions.CelOverloadNotFoundException; import dev.cel.runtime.AccumulatedUnknowns; import dev.cel.runtime.GlobalResolver; @@ -27,7 +28,7 @@ final class EvalOptionalOr extends PlannedInterpretable { private final PlannedInterpretable rhs; @Override - public Object eval(GlobalResolver resolver, ExecutionFrame frame) { + Object evalInternal(GlobalResolver resolver, ExecutionFrame frame) { Object lhsValue = EvalHelpers.evalStrictly(lhs, resolver, frame); if (lhsValue instanceof AccumulatedUnknowns) { @@ -46,12 +47,12 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) { return EvalHelpers.evalStrictly(rhs, resolver, frame); } - static EvalOptionalOr create(long exprId, PlannedInterpretable lhs, PlannedInterpretable rhs) { - return new EvalOptionalOr(exprId, lhs, rhs); + static EvalOptionalOr create(CelExpr expr, PlannedInterpretable lhs, PlannedInterpretable rhs) { + return new EvalOptionalOr(expr, lhs, rhs); } - private EvalOptionalOr(long exprId, PlannedInterpretable lhs, PlannedInterpretable rhs) { - super(exprId); + private EvalOptionalOr(CelExpr expr, PlannedInterpretable lhs, PlannedInterpretable rhs) { + super(expr); this.lhs = Preconditions.checkNotNull(lhs); this.rhs = Preconditions.checkNotNull(rhs); } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalOptionalOrValue.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalOptionalOrValue.java index 6634d60f6..b64c6d433 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalOptionalOrValue.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalOptionalOrValue.java @@ -16,6 +16,7 @@ import com.google.common.base.Preconditions; import com.google.errorprone.annotations.Immutable; +import dev.cel.common.ast.CelExpr; import dev.cel.common.exceptions.CelOverloadNotFoundException; import dev.cel.runtime.AccumulatedUnknowns; import dev.cel.runtime.GlobalResolver; @@ -27,7 +28,7 @@ final class EvalOptionalOrValue extends PlannedInterpretable { private final PlannedInterpretable rhs; @Override - public Object eval(GlobalResolver resolver, ExecutionFrame frame) { + Object evalInternal(GlobalResolver resolver, ExecutionFrame frame) { Object lhsValue = EvalHelpers.evalStrictly(lhs, resolver, frame); if (lhsValue instanceof AccumulatedUnknowns) { return lhsValue; @@ -46,12 +47,12 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) { } static EvalOptionalOrValue create( - long exprId, PlannedInterpretable lhs, PlannedInterpretable rhs) { - return new EvalOptionalOrValue(exprId, lhs, rhs); + CelExpr expr, PlannedInterpretable lhs, PlannedInterpretable rhs) { + return new EvalOptionalOrValue(expr, lhs, rhs); } - private EvalOptionalOrValue(long exprId, PlannedInterpretable lhs, PlannedInterpretable rhs) { - super(exprId); + private EvalOptionalOrValue(CelExpr expr, PlannedInterpretable lhs, PlannedInterpretable rhs) { + super(expr); this.lhs = Preconditions.checkNotNull(lhs); this.rhs = Preconditions.checkNotNull(rhs); } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalOptionalSelectField.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalOptionalSelectField.java index 8887aa697..4122a6e8e 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalOptionalSelectField.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalOptionalSelectField.java @@ -16,6 +16,7 @@ import com.google.common.base.Preconditions; import com.google.errorprone.annotations.Immutable; +import dev.cel.common.ast.CelExpr; import dev.cel.common.values.CelValueConverter; import dev.cel.common.values.SelectableValue; import dev.cel.runtime.AccumulatedUnknowns; @@ -31,7 +32,7 @@ final class EvalOptionalSelectField extends PlannedInterpretable { private final CelValueConverter celValueConverter; @Override - public Object eval(GlobalResolver resolver, ExecutionFrame frame) { + Object evalInternal(GlobalResolver resolver, ExecutionFrame frame) { Object operandValue = EvalHelpers.evalStrictly(operand, resolver, frame); if (operandValue instanceof Optional) { @@ -75,21 +76,21 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) { } static EvalOptionalSelectField create( - long exprId, + CelExpr expr, PlannedInterpretable operand, String field, PlannedInterpretable selectAttribute, CelValueConverter celValueConverter) { - return new EvalOptionalSelectField(exprId, operand, field, selectAttribute, celValueConverter); + return new EvalOptionalSelectField(expr, operand, field, selectAttribute, celValueConverter); } private EvalOptionalSelectField( - long exprId, + CelExpr expr, PlannedInterpretable operand, String field, PlannedInterpretable selectAttribute, CelValueConverter celValueConverter) { - super(exprId); + super(expr); this.operand = Preconditions.checkNotNull(operand); this.field = Preconditions.checkNotNull(field); this.selectAttribute = Preconditions.checkNotNull(selectAttribute); diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalOr.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalOr.java index 62e617d9d..849b6e7b4 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalOr.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalOr.java @@ -17,6 +17,7 @@ import static dev.cel.runtime.planner.EvalHelpers.evalNonstrictly; import com.google.common.base.Preconditions; +import dev.cel.common.ast.CelExpr; import dev.cel.common.values.ErrorValue; import dev.cel.runtime.AccumulatedUnknowns; import dev.cel.runtime.GlobalResolver; @@ -27,7 +28,7 @@ final class EvalOr extends PlannedInterpretable { private final PlannedInterpretable[] args; @Override - public Object eval(GlobalResolver resolver, ExecutionFrame frame) { + Object evalInternal(GlobalResolver resolver, ExecutionFrame frame) { ErrorValue errorValue = null; AccumulatedUnknowns unknowns = null; for (PlannedInterpretable arg : args) { @@ -47,7 +48,7 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) { } else { errorValue = ErrorValue.create( - arg.exprId(), + arg.expr().id(), new IllegalArgumentException( String.format("Expected boolean value, found: %s", argVal))); } @@ -64,12 +65,12 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) { return false; } - static EvalOr create(long exprId, PlannedInterpretable[] args) { - return new EvalOr(exprId, args); + static EvalOr create(CelExpr expr, PlannedInterpretable[] args) { + return new EvalOr(expr, args); } - private EvalOr(long exprId, PlannedInterpretable[] args) { - super(exprId); + private EvalOr(CelExpr expr, PlannedInterpretable[] args) { + super(expr); Preconditions.checkArgument(args.length == 2); this.args = args; } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalTestOnly.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalTestOnly.java index 30ecdbd83..b3d2563f0 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalTestOnly.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalTestOnly.java @@ -15,6 +15,7 @@ package dev.cel.runtime.planner; import com.google.errorprone.annotations.Immutable; +import dev.cel.common.ast.CelExpr; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.GlobalResolver; @@ -24,22 +25,22 @@ final class EvalTestOnly extends InterpretableAttribute { private final InterpretableAttribute attr; @Override - public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { + Object evalInternal(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { return attr.eval(resolver, frame); } @Override - public EvalTestOnly addQualifier(long exprId, Qualifier qualifier) { + public EvalTestOnly addQualifier(CelExpr expr, Qualifier qualifier) { PresenceTestQualifier presenceTestQualifier = PresenceTestQualifier.create(qualifier.value()); - return new EvalTestOnly(exprId(), attr.addQualifier(exprId, presenceTestQualifier)); + return new EvalTestOnly(expr(), attr.addQualifier(expr, presenceTestQualifier)); } - static EvalTestOnly create(long exprId, InterpretableAttribute attr) { - return new EvalTestOnly(exprId, attr); + static EvalTestOnly create(CelExpr expr, InterpretableAttribute attr) { + return new EvalTestOnly(expr, attr); } - private EvalTestOnly(long exprId, InterpretableAttribute attr) { - super(exprId); + private EvalTestOnly(CelExpr expr, InterpretableAttribute attr) { + super(expr); this.attr = attr; } } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalUnary.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalUnary.java index 57834161f..867371ff1 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalUnary.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalUnary.java @@ -17,6 +17,7 @@ import static dev.cel.runtime.planner.EvalHelpers.evalNonstrictly; import static dev.cel.runtime.planner.EvalHelpers.evalStrictly; +import dev.cel.common.ast.CelExpr; import dev.cel.common.values.CelValueConverter; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.CelResolvedOverload; @@ -30,7 +31,7 @@ final class EvalUnary extends PlannedInterpretable { private final CelValueConverter celValueConverter; @Override - public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { + Object evalInternal(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { Object argVal = resolvedOverload.isStrict() ? evalStrictly(arg, resolver, frame) @@ -39,21 +40,21 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEval } static EvalUnary create( - long exprId, + CelExpr expr, String functionName, CelResolvedOverload resolvedOverload, PlannedInterpretable arg, CelValueConverter celValueConverter) { - return new EvalUnary(exprId, functionName, resolvedOverload, arg, celValueConverter); + return new EvalUnary(expr, functionName, resolvedOverload, arg, celValueConverter); } private EvalUnary( - long exprId, + CelExpr expr, String functionName, CelResolvedOverload resolvedOverload, PlannedInterpretable arg, CelValueConverter celValueConverter) { - super(exprId); + super(expr); this.functionName = functionName; this.resolvedOverload = resolvedOverload; this.arg = arg; diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalVarArgsCall.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalVarArgsCall.java index fe7c6c430..4b0171b8f 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalVarArgsCall.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalVarArgsCall.java @@ -17,6 +17,7 @@ import static dev.cel.runtime.planner.EvalHelpers.evalNonstrictly; import static dev.cel.runtime.planner.EvalHelpers.evalStrictly; +import dev.cel.common.ast.CelExpr; import dev.cel.common.values.CelValueConverter; import dev.cel.runtime.AccumulatedUnknowns; import dev.cel.runtime.CelEvaluationException; @@ -34,7 +35,7 @@ final class EvalVarArgsCall extends PlannedInterpretable { private final CelValueConverter celValueConverter; @Override - public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { + Object evalInternal(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { Object[] argVals = new Object[args.length]; AccumulatedUnknowns unknowns = null; for (int i = 0; i < args.length; i++) { @@ -55,21 +56,21 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEval } static EvalVarArgsCall create( - long exprId, + CelExpr expr, String functionName, CelResolvedOverload resolvedOverload, PlannedInterpretable[] args, CelValueConverter celValueConverter) { - return new EvalVarArgsCall(exprId, functionName, resolvedOverload, args, celValueConverter); + return new EvalVarArgsCall(expr, functionName, resolvedOverload, args, celValueConverter); } private EvalVarArgsCall( - long exprId, + CelExpr expr, String functionName, CelResolvedOverload resolvedOverload, PlannedInterpretable[] args, CelValueConverter celValueConverter) { - super(exprId); + super(expr); this.functionName = functionName; this.resolvedOverload = resolvedOverload; this.args = args; diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalZeroArity.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalZeroArity.java index 7798c8253..6e35bc22b 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalZeroArity.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalZeroArity.java @@ -14,6 +14,7 @@ package dev.cel.runtime.planner; +import dev.cel.common.ast.CelExpr; import dev.cel.common.values.CelValueConverter; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.CelResolvedOverload; @@ -27,24 +28,24 @@ final class EvalZeroArity extends PlannedInterpretable { private final CelValueConverter celValueConverter; @Override - public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { + Object evalInternal(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { return EvalHelpers.dispatch(functionName, resolvedOverload, celValueConverter, EMPTY_ARRAY); } static EvalZeroArity create( - long exprId, + CelExpr expr, String functionName, CelResolvedOverload resolvedOverload, CelValueConverter celValueConverter) { - return new EvalZeroArity(exprId, functionName, resolvedOverload, celValueConverter); + return new EvalZeroArity(expr, functionName, resolvedOverload, celValueConverter); } private EvalZeroArity( - long exprId, + CelExpr expr, String functionName, CelResolvedOverload resolvedOverload, CelValueConverter celValueConverter) { - super(exprId); + super(expr); this.functionName = functionName; this.resolvedOverload = resolvedOverload; this.celValueConverter = celValueConverter; diff --git a/runtime/src/main/java/dev/cel/runtime/planner/ExecutionFrame.java b/runtime/src/main/java/dev/cel/runtime/planner/ExecutionFrame.java index 282b7c83a..b67f5520c 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/ExecutionFrame.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/ExecutionFrame.java @@ -17,11 +17,13 @@ import dev.cel.common.CelOptions; import dev.cel.common.exceptions.CelIterationLimitExceededException; import dev.cel.runtime.CelEvaluationException; +import dev.cel.runtime.CelEvaluationListener; import dev.cel.runtime.CelFunctionResolver; import dev.cel.runtime.CelResolvedOverload; import dev.cel.runtime.PartialVars; import java.util.Collection; import java.util.Optional; +import org.jspecify.annotations.Nullable; /** Tracks execution context within a planned program. */ final class ExecutionFrame { @@ -29,6 +31,7 @@ final class ExecutionFrame { private final int comprehensionIterationLimit; private final CelFunctionResolver functionResolver; private final PartialVars partialVars; + private final @Nullable CelEvaluationListener listener; private int iterationCount; private BlockMemoizer blockMemoizer; @@ -62,18 +65,30 @@ BlockMemoizer getBlockMemoizer() { } static ExecutionFrame create( - CelFunctionResolver functionResolver, PartialVars partialVars, CelOptions celOptions) { + CelFunctionResolver functionResolver, + CelOptions celOptions, + @Nullable PartialVars partialVars, + @Nullable CelEvaluationListener listener) { return new ExecutionFrame( - functionResolver, partialVars, celOptions.comprehensionMaxIterations()); + functionResolver, celOptions.comprehensionMaxIterations(), partialVars, listener); } Optional partialVars() { return Optional.ofNullable(partialVars); } - private ExecutionFrame(CelFunctionResolver functionResolver, PartialVars partialVars, int limit) { + @Nullable CelEvaluationListener getListener() { + return listener; + } + + private ExecutionFrame( + CelFunctionResolver functionResolver, + int limit, + @Nullable PartialVars partialVars, + @Nullable CelEvaluationListener listener) { this.comprehensionIterationLimit = limit; this.functionResolver = functionResolver; this.partialVars = partialVars; + this.listener = listener; } } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/InterpretableAttribute.java b/runtime/src/main/java/dev/cel/runtime/planner/InterpretableAttribute.java index 547380c11..9ce726f0e 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/InterpretableAttribute.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/InterpretableAttribute.java @@ -15,13 +15,14 @@ package dev.cel.runtime.planner; import com.google.errorprone.annotations.Immutable; +import dev.cel.common.ast.CelExpr; @Immutable abstract class InterpretableAttribute extends PlannedInterpretable { - abstract InterpretableAttribute addQualifier(long exprId, Qualifier qualifier); + abstract InterpretableAttribute addQualifier(CelExpr expr, Qualifier qualifier); - InterpretableAttribute(long exprId) { - super(exprId); + InterpretableAttribute(CelExpr expr) { + super(expr); } } \ No newline at end of file diff --git a/runtime/src/main/java/dev/cel/runtime/planner/PlannedInterpretable.java b/runtime/src/main/java/dev/cel/runtime/planner/PlannedInterpretable.java index 6f3a9d7ff..8fa52db97 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/PlannedInterpretable.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/PlannedInterpretable.java @@ -15,21 +15,33 @@ package dev.cel.runtime.planner; import com.google.errorprone.annotations.Immutable; +import dev.cel.common.ast.CelExpr; import dev.cel.runtime.CelEvaluationException; +import dev.cel.runtime.CelEvaluationListener; import dev.cel.runtime.GlobalResolver; @Immutable abstract class PlannedInterpretable { - private final long exprId; + private final CelExpr expr; /** Runs interpretation with the given activation which supplies name/value bindings. */ - abstract Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException; + final Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { + Object result = evalInternal(resolver, frame); + CelEvaluationListener listener = frame.getListener(); + if (listener != null) { + listener.callback(expr, result); + } + return result; + } + + abstract Object evalInternal(GlobalResolver resolver, ExecutionFrame frame) + throws CelEvaluationException; - long exprId() { - return exprId; + CelExpr expr() { + return expr; } - PlannedInterpretable(long exprId) { - this.exprId = exprId; + PlannedInterpretable(CelExpr expr) { + this.expr = expr; } } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/PlannedProgram.java b/runtime/src/main/java/dev/cel/runtime/planner/PlannedProgram.java index 34fc34b50..1470e4909 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/PlannedProgram.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/PlannedProgram.java @@ -17,11 +17,13 @@ import com.google.auto.value.AutoValue; import com.google.errorprone.annotations.Immutable; import dev.cel.common.CelOptions; +import dev.cel.common.annotations.Internal; import dev.cel.common.exceptions.CelRuntimeException; import dev.cel.common.values.ErrorValue; import dev.cel.runtime.Activation; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.CelEvaluationExceptionBuilder; +import dev.cel.runtime.CelEvaluationListener; import dev.cel.runtime.CelFunctionResolver; import dev.cel.runtime.CelResolvedOverload; import dev.cel.runtime.CelVariableResolver; @@ -32,10 +34,17 @@ import java.util.Collection; import java.util.Map; import java.util.Optional; - +import org.jspecify.annotations.Nullable; + +/** + * Internal implementation of a {@link Program} that executes a planned interpretable tree. + * + *

CEL-Java internals. Do not use. + */ +@Internal @Immutable @AutoValue -abstract class PlannedProgram implements Program { +public abstract class PlannedProgram implements Program { private static final CelFunctionResolver EMPTY_FUNCTION_RESOLVER = new CelFunctionResolver() { @@ -52,33 +61,51 @@ public Optional findOverloadMatchingArgs( } }; - abstract PlannedInterpretable interpretable(); + public abstract PlannedInterpretable interpretable(); abstract ErrorMetadata metadata(); - abstract CelOptions options(); + public abstract CelOptions options(); @Override public Object eval() throws CelEvaluationException { - return evalOrThrow(interpretable(), GlobalResolver.EMPTY, EMPTY_FUNCTION_RESOLVER, null); + return evalOrThrow( + interpretable(), + GlobalResolver.EMPTY, + EMPTY_FUNCTION_RESOLVER, + /* partialVars= */ null, + /* listener= */ null); } @Override public Object eval(Map mapValue) throws CelEvaluationException { - return evalOrThrow(interpretable(), Activation.copyOf(mapValue), EMPTY_FUNCTION_RESOLVER, null); + return evalOrThrow( + interpretable(), + Activation.copyOf(mapValue), + EMPTY_FUNCTION_RESOLVER, + /* partialVars= */ null, + /* listener= */ null); } @Override public Object eval(Map mapValue, CelFunctionResolver lateBoundFunctionResolver) throws CelEvaluationException { return evalOrThrow( - interpretable(), Activation.copyOf(mapValue), lateBoundFunctionResolver, null); + interpretable(), + Activation.copyOf(mapValue), + lateBoundFunctionResolver, + /* partialVars= */ null, + /* listener= */ null); } @Override public Object eval(CelVariableResolver resolver) throws CelEvaluationException { return evalOrThrow( - interpretable(), (name) -> resolver.find(name).orElse(null), EMPTY_FUNCTION_RESOLVER, null); + interpretable(), + (name) -> resolver.find(name).orElse(null), + EMPTY_FUNCTION_RESOLVER, + /* partialVars= */ null, + /* listener= */ null); } @Override @@ -88,7 +115,8 @@ public Object eval(CelVariableResolver resolver, CelFunctionResolver lateBoundFu interpretable(), (name) -> resolver.find(name).orElse(null), lateBoundFunctionResolver, - null); + /* partialVars= */ null, + /* listener= */ null); } @Override @@ -97,17 +125,20 @@ public Object eval(PartialVars partialVars) throws CelEvaluationException { interpretable(), (name) -> partialVars.resolver().find(name).orElse(null), EMPTY_FUNCTION_RESOLVER, - partialVars); + partialVars, + /* listener= */ null); } - private Object evalOrThrow( + public Object evalOrThrow( PlannedInterpretable interpretable, GlobalResolver resolver, CelFunctionResolver functionResolver, - PartialVars partialVars) + @Nullable PartialVars partialVars, + @Nullable CelEvaluationListener listener) throws CelEvaluationException { try { - ExecutionFrame frame = ExecutionFrame.create(functionResolver, partialVars, options()); + ExecutionFrame frame = + ExecutionFrame.create(functionResolver, options(), partialVars, listener); Object evalResult = interpretable.eval(resolver, frame); if (evalResult instanceof ErrorValue) { ErrorValue errorValue = (ErrorValue) evalResult; @@ -116,10 +147,19 @@ private Object evalOrThrow( return InterpreterUtil.maybeAdaptToCelUnknownSet(evalResult); } catch (RuntimeException e) { - throw newCelEvaluationException(interpretable.exprId(), e); + throw newCelEvaluationException(interpretable.expr().id(), e); } } + public Object trace( + GlobalResolver resolver, + CelFunctionResolver functionResolver, + PartialVars partialVars, + CelEvaluationListener listener) + throws CelEvaluationException { + return evalOrThrow(interpretable(), resolver, functionResolver, partialVars, listener); + } + private CelEvaluationException newCelEvaluationException(long exprId, Exception e) { CelEvaluationExceptionBuilder builder; if (e instanceof LocalizedEvaluationException) { diff --git a/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java b/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java index a0b74fc99..affe64381 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java @@ -93,7 +93,7 @@ public Program plan(CelAbstractSyntaxTree ast) throws CelEvaluationException { private PlannedInterpretable plan(CelExpr celExpr, PlannerContext ctx) { switch (celExpr.getKind()) { case CONSTANT: - return planConstant(celExpr.id(), celExpr.constant()); + return planConstant(celExpr, celExpr.constant()); case IDENT: return planIdent(celExpr, ctx); case SELECT: @@ -123,35 +123,34 @@ private PlannedInterpretable planSelect(CelExpr celExpr, PlannerContext ctx) { if (operand instanceof EvalAttribute) { attribute = (EvalAttribute) operand; } else { - attribute = - EvalAttribute.create(celExpr.id(), attributeFactory.newRelativeAttribute(operand)); + attribute = EvalAttribute.create(celExpr, attributeFactory.newRelativeAttribute(operand)); } if (select.testOnly()) { - attribute = EvalTestOnly.create(celExpr.id(), attribute); + attribute = EvalTestOnly.create(celExpr, attribute); } Qualifier qualifier = StringQualifier.create(select.field()); - return attribute.addQualifier(celExpr.id(), qualifier); + return attribute.addQualifier(celExpr, qualifier); } - private PlannedInterpretable planConstant(long exprId, CelConstant celConstant) { + private PlannedInterpretable planConstant(CelExpr expr, CelConstant celConstant) { switch (celConstant.getKind()) { case NULL_VALUE: - return EvalConstant.create(exprId, celConstant.nullValue()); + return EvalConstant.create(expr, celConstant.nullValue()); case BOOLEAN_VALUE: - return EvalConstant.create(exprId, celConstant.booleanValue()); + return EvalConstant.create(expr, celConstant.booleanValue()); case INT64_VALUE: - return EvalConstant.create(exprId, celConstant.int64Value()); + return EvalConstant.create(expr, celConstant.int64Value()); case UINT64_VALUE: - return EvalConstant.create(exprId, celConstant.uint64Value()); + return EvalConstant.create(expr, celConstant.uint64Value()); case DOUBLE_VALUE: - return EvalConstant.create(exprId, celConstant.doubleValue()); + return EvalConstant.create(expr, celConstant.doubleValue()); case STRING_VALUE: - return EvalConstant.create(exprId, celConstant.stringValue()); + return EvalConstant.create(expr, celConstant.stringValue()); case BYTES_VALUE: - return EvalConstant.create(exprId, celConstant.bytesValue()); + return EvalConstant.create(expr, celConstant.bytesValue()); default: throw new IllegalStateException("Unsupported kind: " + celConstant.getKind()); } @@ -160,29 +159,29 @@ private PlannedInterpretable planConstant(long exprId, CelConstant celConstant) private PlannedInterpretable planIdent(CelExpr celExpr, PlannerContext ctx) { CelReference ref = ctx.referenceMap().get(celExpr.id()); if (ref != null) { - return planCheckedIdent(celExpr.id(), ref, ctx.typeMap()); + return planCheckedIdent(celExpr, ref, ctx.typeMap()); } String identName = celExpr.ident().name(); - PlannedInterpretable blockSlot = maybeInterceptBlockSlot(celExpr.id(), identName).orElse(null); + PlannedInterpretable blockSlot = maybeInterceptBlockSlot(celExpr, identName).orElse(null); if (blockSlot != null) { return blockSlot; } if (ctx.isLocalVar(identName)) { - return EvalAttribute.create(celExpr.id(), attributeFactory.newAbsoluteAttribute(identName)); + return EvalAttribute.create(celExpr, attributeFactory.newAbsoluteAttribute(identName)); } - return EvalAttribute.create(celExpr.id(), attributeFactory.newMaybeAttribute(identName)); + return EvalAttribute.create(celExpr, attributeFactory.newMaybeAttribute(identName)); } private PlannedInterpretable planCheckedIdent( - long id, CelReference identRef, ImmutableMap typeMap) { + CelExpr expr, CelReference identRef, ImmutableMap typeMap) { if (identRef.value().isPresent()) { - return planConstant(id, identRef.value().get()); + return planConstant(expr, identRef.value().get()); } - CelType type = typeMap.get(id); + CelType type = typeMap.get(expr.id()); if (type.kind().equals(CelKind.TYPE)) { TypeType identType = typeProvider @@ -198,19 +197,19 @@ private PlannedInterpretable planCheckedIdent( () -> new NoSuchElementException( "Reference to an undefined type: " + identRef.name())); - return EvalConstant.create(id, identType); + return EvalConstant.create(expr, identType); } String identName = identRef.name(); - PlannedInterpretable blockSlot = maybeInterceptBlockSlot(id, identName).orElse(null); + PlannedInterpretable blockSlot = maybeInterceptBlockSlot(expr, identName).orElse(null); if (blockSlot != null) { return blockSlot; } - return EvalAttribute.create(id, attributeFactory.newAbsoluteAttribute(identRef.name())); + return EvalAttribute.create(expr, attributeFactory.newAbsoluteAttribute(identRef.name())); } - private Optional maybeInterceptBlockSlot(long id, String identName) { + private Optional maybeInterceptBlockSlot(CelExpr expr, String identName) { if (!identName.startsWith("@index")) { return Optional.empty(); } @@ -222,7 +221,7 @@ private Optional maybeInterceptBlockSlot(long id, String i if (slotIndex < 0) { throw new IllegalArgumentException("Negative block slot index: " + identName); } - return Optional.of(EvalBlock.EvalBlockSlot.create(id, slotIndex)); + return Optional.of(EvalBlock.EvalBlockSlot.create(expr, slotIndex)); } catch (NumberFormatException e) { throw new IllegalArgumentException("Invalid block slot index: " + identName, e); } @@ -260,11 +259,11 @@ private PlannedInterpretable planCall(CelExpr expr, PlannerContext ctx) { if (operator != null) { switch (operator) { case LOGICAL_OR: - return EvalOr.create(expr.id(), evaluatedArgs); + return EvalOr.create(expr, evaluatedArgs); case LOGICAL_AND: - return EvalAnd.create(expr.id(), evaluatedArgs); + return EvalAnd.create(expr, evaluatedArgs); case CONDITIONAL: - return EvalConditional.create(expr.id(), evaluatedArgs); + return EvalConditional.create(expr, evaluatedArgs); default: // fall-through } @@ -303,18 +302,18 @@ private PlannedInterpretable planCall(CelExpr expr, PlannerContext ctx) { } return EvalLateBoundCall.create( - expr.id(), functionName, overloadIds, evaluatedArgs, celValueConverter); + expr, functionName, overloadIds, evaluatedArgs, celValueConverter); } switch (argCount) { case 0: - return EvalZeroArity.create(expr.id(), functionName, resolvedOverload, celValueConverter); + return EvalZeroArity.create(expr, functionName, resolvedOverload, celValueConverter); case 1: return EvalUnary.create( - expr.id(), functionName, resolvedOverload, evaluatedArgs[0], celValueConverter); + expr, functionName, resolvedOverload, evaluatedArgs[0], celValueConverter); case 2: return EvalBinary.create( - expr.id(), + expr, functionName, resolvedOverload, evaluatedArgs[0], @@ -322,7 +321,7 @@ private PlannedInterpretable planCall(CelExpr expr, PlannerContext ctx) { celValueConverter); default: return EvalVarArgsCall.create( - expr.id(), functionName, resolvedOverload, evaluatedArgs, celValueConverter); + expr, functionName, resolvedOverload, evaluatedArgs, celValueConverter); } } @@ -345,7 +344,7 @@ private Optional maybeInterceptBlockCall( slotExprs[i] = plan(exprList.elements().get(i), ctx); } PlannedInterpretable resultExpr = plan(blockCall.args().get(1), ctx); - return Optional.of(EvalBlock.create(expr.id(), slotExprs, resultExpr)); + return Optional.of(EvalBlock.create(expr, slotExprs, resultExpr)); } /** @@ -368,14 +367,13 @@ private Optional maybeInterceptOptionalCalls( switch (functionName) { case "or": if (overloadId.isEmpty() || overloadId.equals("optional_or_optional")) { - return Optional.of(EvalOptionalOr.create(expr.id(), evaluatedArgs[0], evaluatedArgs[1])); + return Optional.of(EvalOptionalOr.create(expr, evaluatedArgs[0], evaluatedArgs[1])); } return Optional.empty(); case "orValue": if (overloadId.isEmpty() || overloadId.equals("optional_orValue_value")) { - return Optional.of( - EvalOptionalOrValue.create(expr.id(), evaluatedArgs[0], evaluatedArgs[1])); + return Optional.of(EvalOptionalOrValue.create(expr, evaluatedArgs[0], evaluatedArgs[1])); } return Optional.empty(); @@ -390,15 +388,14 @@ private Optional maybeInterceptOptionalCalls( attribute = (EvalAttribute) evaluatedArgs[0]; } else { attribute = - EvalAttribute.create( - expr.id(), attributeFactory.newRelativeAttribute(evaluatedArgs[0])); + EvalAttribute.create(expr, attributeFactory.newRelativeAttribute(evaluatedArgs[0])); } Qualifier qualifier = StringQualifier.create(field); - PlannedInterpretable selectAttribute = attribute.addQualifier(expr.id(), qualifier); + PlannedInterpretable selectAttribute = attribute.addQualifier(expr, qualifier); return Optional.of( EvalOptionalSelectField.create( - expr.id(), evaluatedArgs[0], field, selectAttribute, celValueConverter)); + expr, evaluatedArgs[0], field, selectAttribute, celValueConverter)); } return Optional.empty(); @@ -420,8 +417,7 @@ private PlannedInterpretable planCreateStruct(CelExpr celExpr, PlannerContext ct isOptional[i] = entry.optionalEntry(); } - return EvalCreateStruct.create( - celExpr.id(), valueProvider, structType, keys, values, isOptional); + return EvalCreateStruct.create(celExpr, valueProvider, structType, keys, values, isOptional); } private PlannedInterpretable planCreateList(CelExpr celExpr, PlannerContext ctx) { @@ -438,7 +434,7 @@ private PlannedInterpretable planCreateList(CelExpr celExpr, PlannerContext ctx) isOptional[optionalIndex] = true; } - return EvalCreateList.create(celExpr.id(), values, isOptional); + return EvalCreateList.create(celExpr, values, isOptional); } private PlannedInterpretable planCreateMap(CelExpr celExpr, PlannerContext ctx) { @@ -456,7 +452,7 @@ private PlannedInterpretable planCreateMap(CelExpr celExpr, PlannerContext ctx) isOptional[i] = entry.optionalEntry(); } - return EvalCreateMap.create(celExpr.id(), keys, values, isOptional); + return EvalCreateMap.create(celExpr, keys, values, isOptional); } private PlannedInterpretable planComprehension(CelExpr expr, PlannerContext ctx) { @@ -477,7 +473,7 @@ private PlannedInterpretable planComprehension(CelExpr expr, PlannerContext ctx) ctx.popLocalVars(comprehension.accuVar()); return EvalFold.create( - expr.id(), + expr, comprehension.accuVar(), accuInit, comprehension.iterVar(), diff --git a/runtime/src/test/java/dev/cel/runtime/BUILD.bazel b/runtime/src/test/java/dev/cel/runtime/BUILD.bazel index 7cd24f040..9461e5e6a 100644 --- a/runtime/src/test/java/dev/cel/runtime/BUILD.bazel +++ b/runtime/src/test/java/dev/cel/runtime/BUILD.bazel @@ -89,6 +89,7 @@ java_library( "//runtime/standard:not_strictly_false", "//runtime/standard:standard_overload", "//runtime/standard:subtract", + "//testing:cel_runtime_flavor", "//testing/protos:message_with_enum_cel_java_proto", "//testing/protos:message_with_enum_java_proto", "//testing/protos:multi_file_cel_java_proto", diff --git a/runtime/src/test/java/dev/cel/runtime/CelRuntimeTest.java b/runtime/src/test/java/dev/cel/runtime/CelRuntimeTest.java index c7f142602..3e29a00db 100644 --- a/runtime/src/test/java/dev/cel/runtime/CelRuntimeTest.java +++ b/runtime/src/test/java/dev/cel/runtime/CelRuntimeTest.java @@ -16,6 +16,7 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertThrows; +import static org.junit.Assume.assumeTrue; import com.google.api.expr.v1alpha1.Constant; import com.google.api.expr.v1alpha1.Expr; @@ -51,6 +52,7 @@ import dev.cel.extensions.CelExtensions; import dev.cel.parser.CelStandardMacro; import dev.cel.parser.CelUnparserFactory; +import dev.cel.testing.CelRuntimeFlavor; import java.util.List; import java.util.Map; import java.util.Optional; @@ -61,6 +63,8 @@ @RunWith(TestParameterInjector.class) public class CelRuntimeTest { + @TestParameter private CelRuntimeFlavor runtimeFlavor; + @Test public void evaluate_anyPackedEqualityUsingProtoDifferencer_success() throws Exception { Cel cel = @@ -273,7 +277,8 @@ public void trace_callExpr_identifyFalseBranch() throws Exception { } }; Cel cel = - CelFactory.standardCelBuilder() + runtimeFlavor + .builder() .addVar("a", SimpleType.INT) .addVar("b", SimpleType.INT) .addVar("c", SimpleType.INT) @@ -297,7 +302,7 @@ public void trace_constant() throws Exception { assertThat(res).isEqualTo("hello world"); assertThat(expr.constant().getKind()).isEqualTo(CelConstant.Kind.STRING_VALUE); }; - Cel cel = CelFactory.standardCelBuilder().build(); + Cel cel = runtimeFlavor.builder().build(); CelAbstractSyntaxTree ast = cel.compile("'hello world'").getAst(); String result = (String) cel.createProgram(ast).trace(listener); @@ -312,7 +317,7 @@ public void trace_ident() throws Exception { assertThat(res).isEqualTo("test"); assertThat(expr.ident().name()).isEqualTo("a"); }; - Cel cel = CelFactory.standardCelBuilder().addVar("a", SimpleType.STRING).build(); + Cel cel = runtimeFlavor.builder().addVar("a", SimpleType.STRING).build(); CelAbstractSyntaxTree ast = cel.compile("a").getAst(); String result = (String) cel.createProgram(ast).trace(ImmutableMap.of("a", "test"), listener); @@ -330,7 +335,8 @@ public void trace_select() throws Exception { } }; Cel cel = - CelFactory.standardCelBuilder() + runtimeFlavor + .builder() .addMessageTypes(TestAllTypes.getDescriptor()) .setContainer(CelContainer.ofName("cel.expr.conformance.proto3")) .build(); @@ -350,7 +356,8 @@ public void trace_struct() throws Exception { .isEqualTo("cel.expr.conformance.proto3.TestAllTypes"); }; Cel cel = - CelFactory.standardCelBuilder() + runtimeFlavor + .builder() .addMessageTypes(TestAllTypes.getDescriptor()) .setContainer(CelContainer.ofName("cel.expr.conformance.proto3")) .build(); @@ -371,7 +378,7 @@ public void trace_list() throws Exception { assertThat(expr.list().elements()).hasSize(3); } }; - Cel cel = CelFactory.standardCelBuilder().build(); + Cel cel = runtimeFlavor.builder().build(); CelAbstractSyntaxTree ast = cel.compile("[1, 2, 3]").getAst(); List result = (List) cel.createProgram(ast).trace(listener); @@ -389,7 +396,7 @@ public void trace_map() throws Exception { assertThat(expr.map().entries()).hasSize(1); } }; - Cel cel = CelFactory.standardCelBuilder().build(); + Cel cel = runtimeFlavor.builder().build(); CelAbstractSyntaxTree ast = cel.compile("{1: 'a'}").getAst(); Map result = (Map) cel.createProgram(ast).trace(listener); @@ -405,8 +412,7 @@ public void trace_comprehension() throws Exception { assertThat(expr.comprehension().iterVar()).isEqualTo("i"); } }; - Cel cel = - CelFactory.standardCelBuilder().setStandardMacros(CelStandardMacro.STANDARD_MACROS).build(); + Cel cel = runtimeFlavor.builder().setStandardMacros(CelStandardMacro.STANDARD_MACROS).build(); CelAbstractSyntaxTree ast = cel.compile("[true].exists(i, i)").getAst(); boolean result = (boolean) cel.createProgram(ast).trace(listener); @@ -422,7 +428,8 @@ public void trace_withMessageInput() throws Exception { assertThat(expr.ident().name()).isEqualTo("single_int64"); }; Cel cel = - CelFactory.standardCelBuilder() + runtimeFlavor + .builder() .addMessageTypes(TestAllTypes.getDescriptor()) .addVar("single_int64", SimpleType.INT) .build(); @@ -444,7 +451,8 @@ public void trace_withVariableResolver() throws Exception { assertThat(expr.ident().name()).isEqualTo("variable"); }; Cel cel = - CelFactory.standardCelBuilder() + runtimeFlavor + .builder() .addMessageTypes(TestAllTypes.getDescriptor()) .addVar("variable", SimpleType.STRING) .build(); @@ -461,6 +469,8 @@ public void trace_withVariableResolver() throws Exception { public void trace_shortCircuitingDisabled_logicalAndAllBranchesVisited( @TestParameter boolean first, @TestParameter boolean second, @TestParameter boolean third) throws Exception { + // TODO: Implement exhaustive eval + assumeTrue(runtimeFlavor != CelRuntimeFlavor.PLANNER); String expression = String.format("%s && %s && %s", first, second, third); ImmutableList.Builder branchResults = ImmutableList.builder(); CelEvaluationListener listener = @@ -470,12 +480,22 @@ public void trace_shortCircuitingDisabled_logicalAndAllBranchesVisited( } }; Cel celWithShortCircuit = - CelFactory.standardCelBuilder() - .setOptions(CelOptions.current().enableShortCircuiting(true).build()) + runtimeFlavor + .builder() + .setOptions( + CelOptions.current() + .enableShortCircuiting(true) + .enableHeterogeneousNumericComparisons(true) + .build()) .build(); Cel cel = - CelFactory.standardCelBuilder() - .setOptions(CelOptions.current().enableShortCircuiting(false).build()) + runtimeFlavor + .builder() + .setOptions( + CelOptions.current() + .enableShortCircuiting(false) + .enableHeterogeneousNumericComparisons(true) + .build()) .build(); CelAbstractSyntaxTree ast = cel.compile(expression).getAst(); @@ -496,6 +516,8 @@ public void trace_shortCircuitingDisabled_logicalAndAllBranchesVisited( @TestParameters("{source: 'x && false && false'}") public void trace_shortCircuitingDisabledWithUnknownsAndedToFalse_returnsFalse(String source) throws Exception { + // TODO: Implement exhaustive eval + assumeTrue(runtimeFlavor != CelRuntimeFlavor.PLANNER); ImmutableList.Builder branchResults = ImmutableList.builder(); CelEvaluationListener listener = (expr, res) -> { @@ -509,9 +531,14 @@ public void trace_shortCircuitingDisabledWithUnknownsAndedToFalse_returnsFalse(S } }; Cel cel = - CelFactory.standardCelBuilder() + runtimeFlavor + .builder() .addVar("x", SimpleType.BOOL) - .setOptions(CelOptions.current().enableShortCircuiting(false).build()) + .setOptions( + CelOptions.current() + .enableShortCircuiting(false) + .enableHeterogeneousNumericComparisons(true) + .build()) .build(); CelAbstractSyntaxTree ast = cel.compile(source).getAst(); @@ -527,6 +554,8 @@ public void trace_shortCircuitingDisabledWithUnknownsAndedToFalse_returnsFalse(S @TestParameters("{source: 'x && true && true'}") public void trace_shortCircuitingDisabledWithUnknownAndedToTrue_returnsUnknown(String source) throws Exception { + // TODO: Implement exhaustive eval + assumeTrue(runtimeFlavor != CelRuntimeFlavor.PLANNER); ImmutableList.Builder branchResults = ImmutableList.builder(); CelEvaluationListener listener = (expr, res) -> { @@ -536,9 +565,14 @@ public void trace_shortCircuitingDisabledWithUnknownAndedToTrue_returnsUnknown(S } }; Cel cel = - CelFactory.standardCelBuilder() + runtimeFlavor + .builder() .addVar("x", SimpleType.BOOL) - .setOptions(CelOptions.current().enableShortCircuiting(false).build()) + .setOptions( + CelOptions.current() + .enableShortCircuiting(false) + .enableHeterogeneousNumericComparisons(true) + .build()) .build(); CelAbstractSyntaxTree ast = cel.compile(source).getAst(); @@ -552,6 +586,8 @@ public void trace_shortCircuitingDisabledWithUnknownAndedToTrue_returnsUnknown(S public void trace_shortCircuitingDisabled_logicalOrAllBranchesVisited( @TestParameter boolean first, @TestParameter boolean second, @TestParameter boolean third) throws Exception { + // TODO: Implement exhaustive eval + assumeTrue(runtimeFlavor != CelRuntimeFlavor.PLANNER); String expression = String.format("%s || %s || %s", first, second, third); ImmutableList.Builder branchResults = ImmutableList.builder(); CelEvaluationListener listener = @@ -561,12 +597,22 @@ public void trace_shortCircuitingDisabled_logicalOrAllBranchesVisited( } }; Cel celWithShortCircuit = - CelFactory.standardCelBuilder() - .setOptions(CelOptions.current().enableShortCircuiting(true).build()) + runtimeFlavor + .builder() + .setOptions( + CelOptions.current() + .enableShortCircuiting(true) + .enableHeterogeneousNumericComparisons(true) + .build()) .build(); Cel cel = - CelFactory.standardCelBuilder() - .setOptions(CelOptions.current().enableShortCircuiting(false).build()) + runtimeFlavor + .builder() + .setOptions( + CelOptions.current() + .enableShortCircuiting(false) + .enableHeterogeneousNumericComparisons(true) + .build()) .build(); CelAbstractSyntaxTree ast = cel.compile(expression).getAst(); @@ -587,6 +633,8 @@ public void trace_shortCircuitingDisabled_logicalOrAllBranchesVisited( @TestParameters("{source: 'x || false || false'}") public void trace_shortCircuitingDisabledWithUnknownsOredToFalse_returnsUnknown(String source) throws Exception { + // TODO: Implement exhaustive eval + assumeTrue(runtimeFlavor != CelRuntimeFlavor.PLANNER); ImmutableList.Builder branchResults = ImmutableList.builder(); CelEvaluationListener listener = (expr, res) -> { @@ -596,9 +644,14 @@ public void trace_shortCircuitingDisabledWithUnknownsOredToFalse_returnsUnknown( } }; Cel cel = - CelFactory.standardCelBuilder() + runtimeFlavor + .builder() .addVar("x", SimpleType.BOOL) - .setOptions(CelOptions.current().enableShortCircuiting(false).build()) + .setOptions( + CelOptions.current() + .enableShortCircuiting(false) + .enableHeterogeneousNumericComparisons(true) + .build()) .build(); CelAbstractSyntaxTree ast = cel.compile(source).getAst(); @@ -614,6 +667,8 @@ public void trace_shortCircuitingDisabledWithUnknownsOredToFalse_returnsUnknown( @TestParameters("{source: 'x || true || true'}") public void trace_shortCircuitingDisabledWithUnknownOredToTrue_returnsTrue(String source) throws Exception { + // TODO: Implement exhaustive eval + assumeTrue(runtimeFlavor != CelRuntimeFlavor.PLANNER); ImmutableList.Builder branchResults = ImmutableList.builder(); CelEvaluationListener listener = (expr, res) -> { @@ -627,9 +682,14 @@ public void trace_shortCircuitingDisabledWithUnknownOredToTrue_returnsTrue(Strin } }; Cel cel = - CelFactory.standardCelBuilder() + runtimeFlavor + .builder() .addVar("x", SimpleType.BOOL) - .setOptions(CelOptions.current().enableShortCircuiting(false).build()) + .setOptions( + CelOptions.current() + .enableShortCircuiting(false) + .enableHeterogeneousNumericComparisons(true) + .build()) .build(); CelAbstractSyntaxTree ast = cel.compile(source).getAst(); @@ -641,6 +701,7 @@ public void trace_shortCircuitingDisabledWithUnknownOredToTrue_returnsTrue(Strin @Test public void trace_shortCircuitingDisabled_ternaryAllBranchesVisited() throws Exception { + assumeTrue(runtimeFlavor != CelRuntimeFlavor.PLANNER); ImmutableList.Builder branchResults = ImmutableList.builder(); CelEvaluationListener listener = (expr, res) -> { @@ -649,8 +710,13 @@ public void trace_shortCircuitingDisabled_ternaryAllBranchesVisited() throws Exc } }; Cel cel = - CelFactory.standardCelBuilder() - .setOptions(CelOptions.current().enableShortCircuiting(false).build()) + runtimeFlavor + .builder() + .setOptions( + CelOptions.current() + .enableShortCircuiting(false) + .enableHeterogeneousNumericComparisons(true) + .build()) .build(); CelAbstractSyntaxTree ast = cel.compile("true ? false : true").getAst(); @@ -665,6 +731,8 @@ public void trace_shortCircuitingDisabled_ternaryAllBranchesVisited() throws Exc @TestParameters("{source: 'true ? x : false'}") @TestParameters("{source: 'x ? true : false'}") public void trace_shortCircuitingDisabled_ternaryWithUnknowns(String source) throws Exception { + // TODO: Implement exhaustive eval + assumeTrue(runtimeFlavor != CelRuntimeFlavor.PLANNER); ImmutableList.Builder branchResults = ImmutableList.builder(); CelEvaluationListener listener = (expr, res) -> { @@ -674,9 +742,14 @@ public void trace_shortCircuitingDisabled_ternaryWithUnknowns(String source) thr } }; Cel cel = - CelFactory.standardCelBuilder() + runtimeFlavor + .builder() .addVar("x", SimpleType.BOOL) - .setOptions(CelOptions.current().enableShortCircuiting(false).build()) + .setOptions( + CelOptions.current() + .enableShortCircuiting(false) + .enableHeterogeneousNumericComparisons(true) + .build()) .build(); CelAbstractSyntaxTree ast = cel.compile(source).getAst(); @@ -697,6 +770,8 @@ public void trace_shortCircuitingDisabled_ternaryWithUnknowns(String source) thr "{expression: 'true ? true : (1 / 0) > 2', firstVisited: true, secondVisited: true}") public void trace_shortCircuitingDisabled_ternaryWithError( String expression, boolean firstVisited, boolean secondVisited) throws Exception { + // TODO: Implement exhaustive eval + assumeTrue(runtimeFlavor != CelRuntimeFlavor.PLANNER); ImmutableList.Builder branchResults = ImmutableList.builder(); CelEvaluationListener listener = (expr, res) -> { @@ -705,12 +780,22 @@ public void trace_shortCircuitingDisabled_ternaryWithError( } }; Cel celWithShortCircuit = - CelFactory.standardCelBuilder() - .setOptions(CelOptions.current().enableShortCircuiting(true).build()) + runtimeFlavor + .builder() + .setOptions( + CelOptions.current() + .enableShortCircuiting(true) + .enableHeterogeneousNumericComparisons(true) + .build()) .build(); Cel cel = - CelFactory.standardCelBuilder() - .setOptions(CelOptions.current().enableShortCircuiting(false).build()) + runtimeFlavor + .builder() + .setOptions( + CelOptions.current() + .enableShortCircuiting(false) + .enableHeterogeneousNumericComparisons(true) + .build()) .build(); CelAbstractSyntaxTree ast = cel.compile(expression).getAst(); From 1e1d8ea0c4db9abe689d9b8327b8ace42c496f4c Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Wed, 22 Apr 2026 20:11:32 -0700 Subject: [PATCH 48/66] Implement exhaustive eval for planner PiperOrigin-RevId: 904185953 --- .../main/java/dev/cel/runtime/CelRuntime.java | 8 + .../java/dev/cel/runtime/CelRuntimeImpl.java | 11 ++ .../java/dev/cel/runtime/ProgramImpl.java | 9 + .../java/dev/cel/runtime/planner/BUILD.bazel | 46 +++++ .../runtime/planner/EvalExhaustiveAnd.java | 92 +++++++++ .../planner/EvalExhaustiveConditional.java | 68 +++++++ .../cel/runtime/planner/EvalExhaustiveOr.java | 92 +++++++++ .../runtime/planner/PlannedInterpretable.java | 3 +- .../cel/runtime/planner/ProgramPlanner.java | 12 +- .../src/test/java/dev/cel/runtime/BUILD.bazel | 1 + .../java/dev/cel/runtime/CelRuntimeTest.java | 174 +++++++++++++++--- 11 files changed, 485 insertions(+), 31 deletions(-) create mode 100644 runtime/src/main/java/dev/cel/runtime/planner/EvalExhaustiveAnd.java create mode 100644 runtime/src/main/java/dev/cel/runtime/planner/EvalExhaustiveConditional.java create mode 100644 runtime/src/main/java/dev/cel/runtime/planner/EvalExhaustiveOr.java diff --git a/runtime/src/main/java/dev/cel/runtime/CelRuntime.java b/runtime/src/main/java/dev/cel/runtime/CelRuntime.java index 416bca132..1e7fdcac8 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelRuntime.java +++ b/runtime/src/main/java/dev/cel/runtime/CelRuntime.java @@ -90,6 +90,14 @@ Object trace( CelEvaluationListener listener) throws CelEvaluationException; + /** + * Trace evaluates a compiled program using {@code partialVars} as the source of input variables + * and unknown attribute patterns. The listener is invoked as evaluation progresses through the + * AST. + */ + Object trace(PartialVars partialVars, CelEvaluationListener listener) + throws CelEvaluationException; + /** * Advance evaluation based on the current unknown context. * diff --git a/runtime/src/main/java/dev/cel/runtime/CelRuntimeImpl.java b/runtime/src/main/java/dev/cel/runtime/CelRuntimeImpl.java index 4cc738b6d..ed203d612 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelRuntimeImpl.java +++ b/runtime/src/main/java/dev/cel/runtime/CelRuntimeImpl.java @@ -222,6 +222,17 @@ public Object trace( .trace(Activation.copyOf(mapValue), lateBoundFunctionResolver, null, listener); } + @Override + public Object trace(PartialVars partialVars, CelEvaluationListener listener) + throws CelEvaluationException { + return ((PlannedProgram) program) + .trace( + (name) -> partialVars.resolver().find(name).orElse(null), + EMPTY_FUNCTION_RESOLVER, + partialVars, + listener); + } + @Override public Object advanceEvaluation(UnknownContext context) throws CelEvaluationException { throw new UnsupportedOperationException("Unsupported operation."); diff --git a/runtime/src/main/java/dev/cel/runtime/ProgramImpl.java b/runtime/src/main/java/dev/cel/runtime/ProgramImpl.java index c9f4d083b..2543a9525 100644 --- a/runtime/src/main/java/dev/cel/runtime/ProgramImpl.java +++ b/runtime/src/main/java/dev/cel/runtime/ProgramImpl.java @@ -110,6 +110,15 @@ public Object trace( return evalInternal(Activation.copyOf(mapValue), lateBoundFunctionResolver, listener); } + @Override + public Object trace(PartialVars partialVars, CelEvaluationListener listener) + throws CelEvaluationException { + return evalInternal( + UnknownContext.create(partialVars.resolver(), partialVars.unknowns()), + /* lateBoundFunctionResolver= */ Optional.empty(), + Optional.of(listener)); + } + @Override public Object advanceEvaluation(UnknownContext context) throws CelEvaluationException { return evalInternal(context, Optional.empty(), Optional.empty()); diff --git a/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel b/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel index c13d5857f..801e56d73 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel +++ b/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel @@ -24,6 +24,9 @@ java_library( ":eval_create_list", ":eval_create_map", ":eval_create_struct", + ":eval_exhaustive_and", + ":eval_exhaustive_conditional", + ":eval_exhaustive_or", ":eval_fold", ":eval_late_bound_call", ":eval_optional_or", @@ -446,6 +449,7 @@ java_library( "//runtime:evaluation_listener", "//runtime:function_resolver", "//runtime:interpretable", + "//runtime:interpreter_util", "//runtime:partial_vars", "//runtime:resolved_overload", "@maven//:com_google_errorprone_error_prone_annotations", @@ -498,6 +502,48 @@ java_library( ], ) +java_library( + name = "eval_exhaustive_and", + srcs = ["EvalExhaustiveAnd.java"], + deps = [ + ":eval_helpers", + ":planned_interpretable", + "//common/ast", + "//common/values", + "//runtime:accumulated_unknowns", + "//runtime:interpretable", + "@maven//:com_google_errorprone_error_prone_annotations", + ], +) + +java_library( + name = "eval_exhaustive_or", + srcs = ["EvalExhaustiveOr.java"], + deps = [ + ":eval_helpers", + ":planned_interpretable", + "//common/ast", + "//common/values", + "//runtime:accumulated_unknowns", + "//runtime:interpretable", + "@maven//:com_google_errorprone_error_prone_annotations", + ], +) + +java_library( + name = "eval_exhaustive_conditional", + srcs = ["EvalExhaustiveConditional.java"], + deps = [ + ":eval_helpers", + ":planned_interpretable", + "//common/ast", + "//runtime:accumulated_unknowns", + "//runtime:evaluation_exception", + "//runtime:interpretable", + "@maven//:com_google_errorprone_error_prone_annotations", + ], +) + java_library( name = "eval_block", srcs = ["EvalBlock.java"], diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalExhaustiveAnd.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalExhaustiveAnd.java new file mode 100644 index 000000000..ac3d07200 --- /dev/null +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalExhaustiveAnd.java @@ -0,0 +1,92 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.runtime.planner; + +import static dev.cel.runtime.planner.EvalHelpers.evalNonstrictly; + +import com.google.errorprone.annotations.Immutable; +import dev.cel.common.ast.CelExpr; +import dev.cel.common.values.ErrorValue; +import dev.cel.runtime.AccumulatedUnknowns; +import dev.cel.runtime.GlobalResolver; + +/** + * Implementation of logical AND with exhaustive evaluation (non-short-circuiting). + * + *

It evaluates all arguments, but prioritizes a false result over unknowns and errors to + * maintain semantic consistency with short-circuiting evaluation. + */ +@Immutable +final class EvalExhaustiveAnd extends PlannedInterpretable { + + @SuppressWarnings("Immutable") + private final PlannedInterpretable[] args; + + @Override + Object evalInternal(GlobalResolver resolver, ExecutionFrame frame) { + AccumulatedUnknowns accumulatedUnknowns = null; + ErrorValue errorValue = null; + boolean hasFalse = false; + + for (PlannedInterpretable arg : args) { + Object argVal = evalNonstrictly(arg, resolver, frame); + if (argVal instanceof Boolean) { + if (!((boolean) argVal)) { + hasFalse = true; + } + } + + // If we already encountered a false, we do not need to accumulate unknowns or errors + // from subsequent terms because the final result will be false anyway. + if (hasFalse) { + continue; + } + + if (argVal instanceof AccumulatedUnknowns) { + accumulatedUnknowns = + accumulatedUnknowns == null + ? (AccumulatedUnknowns) argVal + : accumulatedUnknowns.merge((AccumulatedUnknowns) argVal); + } else if (argVal instanceof ErrorValue) { + if (errorValue == null) { + errorValue = (ErrorValue) argVal; + } + } + } + + if (hasFalse) { + return false; + } + + if (accumulatedUnknowns != null) { + return accumulatedUnknowns; + } + + if (errorValue != null) { + return errorValue; + } + + return true; + } + + static EvalExhaustiveAnd create(CelExpr expr, PlannedInterpretable[] args) { + return new EvalExhaustiveAnd(expr, args); + } + + private EvalExhaustiveAnd(CelExpr expr, PlannedInterpretable[] args) { + super(expr); + this.args = args; + } +} diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalExhaustiveConditional.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalExhaustiveConditional.java new file mode 100644 index 000000000..01e242c0f --- /dev/null +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalExhaustiveConditional.java @@ -0,0 +1,68 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.runtime.planner; + +import static dev.cel.runtime.planner.EvalHelpers.evalNonstrictly; + +import com.google.errorprone.annotations.Immutable; +import dev.cel.common.ast.CelExpr; +import dev.cel.runtime.AccumulatedUnknowns; +import dev.cel.runtime.CelEvaluationException; +import dev.cel.runtime.GlobalResolver; + +/** + * Implementation of conditional operator (ternary) with exhaustive evaluation + * (non-short-circuiting). + * + *

It evaluates all three arguments (condition, truthy, and falsy branches) but returns the + * result based on the condition, maintaining semantic consistency with short-circuiting evaluation. + */ +@Immutable +final class EvalExhaustiveConditional extends PlannedInterpretable { + + @SuppressWarnings("Immutable") + private final PlannedInterpretable[] args; + + @Override + Object evalInternal(GlobalResolver resolver, ExecutionFrame frame) throws CelEvaluationException { + PlannedInterpretable condition = args[0]; + PlannedInterpretable truthy = args[1]; + PlannedInterpretable falsy = args[2]; + + Object condResult = condition.eval(resolver, frame); + Object truthyVal = evalNonstrictly(truthy, resolver, frame); + Object falsyVal = evalNonstrictly(falsy, resolver, frame); + + if (condResult instanceof AccumulatedUnknowns) { + return condResult; + } + + if (!(condResult instanceof Boolean)) { + throw new IllegalArgumentException( + String.format("Expected boolean value, found :%s", condResult)); + } + + return (boolean) condResult ? truthyVal : falsyVal; + } + + static EvalExhaustiveConditional create(CelExpr expr, PlannedInterpretable[] args) { + return new EvalExhaustiveConditional(expr, args); + } + + private EvalExhaustiveConditional(CelExpr expr, PlannedInterpretable[] args) { + super(expr); + this.args = args; + } +} diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalExhaustiveOr.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalExhaustiveOr.java new file mode 100644 index 000000000..07164f8c7 --- /dev/null +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalExhaustiveOr.java @@ -0,0 +1,92 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.runtime.planner; + +import static dev.cel.runtime.planner.EvalHelpers.evalNonstrictly; + +import com.google.errorprone.annotations.Immutable; +import dev.cel.common.ast.CelExpr; +import dev.cel.common.values.ErrorValue; +import dev.cel.runtime.AccumulatedUnknowns; +import dev.cel.runtime.GlobalResolver; + +/** + * Implementation of logical OR with exhaustive evaluation (non-short-circuiting). + * + *

It evaluates all arguments, but prioritizes a true result over unknowns and errors to maintain + * semantic consistency with short-circuiting evaluation. + */ +@Immutable +final class EvalExhaustiveOr extends PlannedInterpretable { + + @SuppressWarnings("Immutable") + private final PlannedInterpretable[] args; + + @Override + Object evalInternal(GlobalResolver resolver, ExecutionFrame frame) { + AccumulatedUnknowns accumulatedUnknowns = null; + ErrorValue errorValue = null; + boolean hasTrue = false; + + for (PlannedInterpretable arg : args) { + Object argVal = evalNonstrictly(arg, resolver, frame); + if (argVal instanceof Boolean) { + if ((boolean) argVal) { + hasTrue = true; + } + } + + // If we already encountered a true, we do not need to accumulate unknowns or errors + // from subsequent terms because the final result will be true anyway. + if (hasTrue) { + continue; + } + + if (argVal instanceof AccumulatedUnknowns) { + accumulatedUnknowns = + accumulatedUnknowns == null + ? (AccumulatedUnknowns) argVal + : accumulatedUnknowns.merge((AccumulatedUnknowns) argVal); + } else if (argVal instanceof ErrorValue) { + if (errorValue == null) { + errorValue = (ErrorValue) argVal; + } + } + } + + if (hasTrue) { + return true; + } + + if (accumulatedUnknowns != null) { + return accumulatedUnknowns; + } + + if (errorValue != null) { + return errorValue; + } + + return false; + } + + static EvalExhaustiveOr create(CelExpr expr, PlannedInterpretable[] args) { + return new EvalExhaustiveOr(expr, args); + } + + private EvalExhaustiveOr(CelExpr expr, PlannedInterpretable[] args) { + super(expr); + this.args = args; + } +} diff --git a/runtime/src/main/java/dev/cel/runtime/planner/PlannedInterpretable.java b/runtime/src/main/java/dev/cel/runtime/planner/PlannedInterpretable.java index 8fa52db97..6bdeaf1df 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/PlannedInterpretable.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/PlannedInterpretable.java @@ -19,6 +19,7 @@ import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.CelEvaluationListener; import dev.cel.runtime.GlobalResolver; +import dev.cel.runtime.InterpreterUtil; @Immutable abstract class PlannedInterpretable { @@ -29,7 +30,7 @@ final Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEvalu Object result = evalInternal(resolver, frame); CelEvaluationListener listener = frame.getListener(); if (listener != null) { - listener.callback(expr, result); + listener.callback(expr, InterpreterUtil.maybeAdaptToCelUnknownSet(result)); } return result; } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java b/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java index affe64381..e38d08f8f 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java @@ -259,11 +259,17 @@ private PlannedInterpretable planCall(CelExpr expr, PlannerContext ctx) { if (operator != null) { switch (operator) { case LOGICAL_OR: - return EvalOr.create(expr, evaluatedArgs); + return options.enableShortCircuiting() + ? EvalOr.create(expr, evaluatedArgs) + : EvalExhaustiveOr.create(expr, evaluatedArgs); case LOGICAL_AND: - return EvalAnd.create(expr, evaluatedArgs); + return options.enableShortCircuiting() + ? EvalAnd.create(expr, evaluatedArgs) + : EvalExhaustiveAnd.create(expr, evaluatedArgs); case CONDITIONAL: - return EvalConditional.create(expr, evaluatedArgs); + return options.enableShortCircuiting() + ? EvalConditional.create(expr, evaluatedArgs) + : EvalExhaustiveConditional.create(expr, evaluatedArgs); default: // fall-through } diff --git a/runtime/src/test/java/dev/cel/runtime/BUILD.bazel b/runtime/src/test/java/dev/cel/runtime/BUILD.bazel index 9461e5e6a..3200a80e0 100644 --- a/runtime/src/test/java/dev/cel/runtime/BUILD.bazel +++ b/runtime/src/test/java/dev/cel/runtime/BUILD.bazel @@ -75,6 +75,7 @@ java_library( "//runtime:late_function_binding", "//runtime:lite_runtime", "//runtime:lite_runtime_factory", + "//runtime:partial_vars", "//runtime:proto_message_activation_factory", "//runtime:proto_message_runtime_equality", "//runtime:proto_message_runtime_helpers", diff --git a/runtime/src/test/java/dev/cel/runtime/CelRuntimeTest.java b/runtime/src/test/java/dev/cel/runtime/CelRuntimeTest.java index 3e29a00db..13d5dd550 100644 --- a/runtime/src/test/java/dev/cel/runtime/CelRuntimeTest.java +++ b/runtime/src/test/java/dev/cel/runtime/CelRuntimeTest.java @@ -16,8 +16,8 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertThrows; -import static org.junit.Assume.assumeTrue; +import com.google.api.expr.v1alpha1.CheckedExpr; import com.google.api.expr.v1alpha1.Constant; import com.google.api.expr.v1alpha1.Expr; import com.google.api.expr.v1alpha1.Type.PrimitiveType; @@ -36,7 +36,10 @@ import dev.cel.bundle.CelFactory; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelContainer; +import dev.cel.common.CelErrorCode; +import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelOptions; +import dev.cel.common.CelOverloadDecl; import dev.cel.common.CelProtoV1Alpha1AbstractSyntaxTree; import dev.cel.common.CelSource; import dev.cel.common.ast.CelConstant; @@ -104,8 +107,8 @@ public void evaluate_anyPackedEqualityUsingProtoDifferencer_success() throws Exc public void evaluate_v1alpha1CheckedExpr() throws Exception { // Note: v1alpha1 proto support exists only to help migrate existing consumers. // New users of CEL should use the canonical protos instead (I.E: dev.cel.expr) - com.google.api.expr.v1alpha1.CheckedExpr checkedExpr = - com.google.api.expr.v1alpha1.CheckedExpr.newBuilder() + CheckedExpr checkedExpr = + CheckedExpr.newBuilder() .setExpr( Expr.newBuilder() .setId(1) @@ -469,8 +472,6 @@ public void trace_withVariableResolver() throws Exception { public void trace_shortCircuitingDisabled_logicalAndAllBranchesVisited( @TestParameter boolean first, @TestParameter boolean second, @TestParameter boolean third) throws Exception { - // TODO: Implement exhaustive eval - assumeTrue(runtimeFlavor != CelRuntimeFlavor.PLANNER); String expression = String.format("%s && %s && %s", first, second, third); ImmutableList.Builder branchResults = ImmutableList.builder(); CelEvaluationListener listener = @@ -516,8 +517,6 @@ public void trace_shortCircuitingDisabled_logicalAndAllBranchesVisited( @TestParameters("{source: 'x && false && false'}") public void trace_shortCircuitingDisabledWithUnknownsAndedToFalse_returnsFalse(String source) throws Exception { - // TODO: Implement exhaustive eval - assumeTrue(runtimeFlavor != CelRuntimeFlavor.PLANNER); ImmutableList.Builder branchResults = ImmutableList.builder(); CelEvaluationListener listener = (expr, res) -> { @@ -542,7 +541,8 @@ public void trace_shortCircuitingDisabledWithUnknownsAndedToFalse_returnsFalse(S .build(); CelAbstractSyntaxTree ast = cel.compile(source).getAst(); - boolean result = (boolean) cel.createProgram(ast).trace(listener); + PartialVars partialVars = PartialVars.of(CelAttributePattern.create("x")); + boolean result = (boolean) cel.createProgram(ast).trace(partialVars, listener); assertThat(result).isFalse(); assertThat(branchResults.build()).containsExactly(false, false, "x"); @@ -554,8 +554,6 @@ public void trace_shortCircuitingDisabledWithUnknownsAndedToFalse_returnsFalse(S @TestParameters("{source: 'x && true && true'}") public void trace_shortCircuitingDisabledWithUnknownAndedToTrue_returnsUnknown(String source) throws Exception { - // TODO: Implement exhaustive eval - assumeTrue(runtimeFlavor != CelRuntimeFlavor.PLANNER); ImmutableList.Builder branchResults = ImmutableList.builder(); CelEvaluationListener listener = (expr, res) -> { @@ -576,7 +574,8 @@ public void trace_shortCircuitingDisabledWithUnknownAndedToTrue_returnsUnknown(S .build(); CelAbstractSyntaxTree ast = cel.compile(source).getAst(); - Object unknownResult = cel.createProgram(ast).trace(listener); + PartialVars partialVars = PartialVars.of(CelAttributePattern.create("x")); + Object unknownResult = cel.createProgram(ast).trace(partialVars, listener); assertThat(InterpreterUtil.isUnknown(unknownResult)).isTrue(); assertThat(branchResults.build()).containsExactly(true, true, unknownResult); @@ -586,8 +585,6 @@ public void trace_shortCircuitingDisabledWithUnknownAndedToTrue_returnsUnknown(S public void trace_shortCircuitingDisabled_logicalOrAllBranchesVisited( @TestParameter boolean first, @TestParameter boolean second, @TestParameter boolean third) throws Exception { - // TODO: Implement exhaustive eval - assumeTrue(runtimeFlavor != CelRuntimeFlavor.PLANNER); String expression = String.format("%s || %s || %s", first, second, third); ImmutableList.Builder branchResults = ImmutableList.builder(); CelEvaluationListener listener = @@ -633,8 +630,6 @@ public void trace_shortCircuitingDisabled_logicalOrAllBranchesVisited( @TestParameters("{source: 'x || false || false'}") public void trace_shortCircuitingDisabledWithUnknownsOredToFalse_returnsUnknown(String source) throws Exception { - // TODO: Implement exhaustive eval - assumeTrue(runtimeFlavor != CelRuntimeFlavor.PLANNER); ImmutableList.Builder branchResults = ImmutableList.builder(); CelEvaluationListener listener = (expr, res) -> { @@ -655,7 +650,8 @@ public void trace_shortCircuitingDisabledWithUnknownsOredToFalse_returnsUnknown( .build(); CelAbstractSyntaxTree ast = cel.compile(source).getAst(); - Object unknownResult = cel.createProgram(ast).trace(listener); + PartialVars partialVars = PartialVars.of(CelAttributePattern.create("x")); + Object unknownResult = cel.createProgram(ast).trace(partialVars, listener); assertThat(InterpreterUtil.isUnknown(unknownResult)).isTrue(); assertThat(branchResults.build()).containsExactly(false, false, unknownResult); @@ -667,8 +663,6 @@ public void trace_shortCircuitingDisabledWithUnknownsOredToFalse_returnsUnknown( @TestParameters("{source: 'x || true || true'}") public void trace_shortCircuitingDisabledWithUnknownOredToTrue_returnsTrue(String source) throws Exception { - // TODO: Implement exhaustive eval - assumeTrue(runtimeFlavor != CelRuntimeFlavor.PLANNER); ImmutableList.Builder branchResults = ImmutableList.builder(); CelEvaluationListener listener = (expr, res) -> { @@ -693,7 +687,8 @@ public void trace_shortCircuitingDisabledWithUnknownOredToTrue_returnsTrue(Strin .build(); CelAbstractSyntaxTree ast = cel.compile(source).getAst(); - boolean result = (boolean) cel.createProgram(ast).trace(listener); + PartialVars partialVars = PartialVars.of(CelAttributePattern.create("x")); + boolean result = (boolean) cel.createProgram(ast).trace(partialVars, listener); assertThat(result).isTrue(); assertThat(branchResults.build()).containsExactly(true, true, "x"); @@ -701,7 +696,6 @@ public void trace_shortCircuitingDisabledWithUnknownOredToTrue_returnsTrue(Strin @Test public void trace_shortCircuitingDisabled_ternaryAllBranchesVisited() throws Exception { - assumeTrue(runtimeFlavor != CelRuntimeFlavor.PLANNER); ImmutableList.Builder branchResults = ImmutableList.builder(); CelEvaluationListener listener = (expr, res) -> { @@ -731,8 +725,6 @@ public void trace_shortCircuitingDisabled_ternaryAllBranchesVisited() throws Exc @TestParameters("{source: 'true ? x : false'}") @TestParameters("{source: 'x ? true : false'}") public void trace_shortCircuitingDisabled_ternaryWithUnknowns(String source) throws Exception { - // TODO: Implement exhaustive eval - assumeTrue(runtimeFlavor != CelRuntimeFlavor.PLANNER); ImmutableList.Builder branchResults = ImmutableList.builder(); CelEvaluationListener listener = (expr, res) -> { @@ -753,7 +745,8 @@ public void trace_shortCircuitingDisabled_ternaryWithUnknowns(String source) thr .build(); CelAbstractSyntaxTree ast = cel.compile(source).getAst(); - Object unknownResult = cel.createProgram(ast).trace(listener); + PartialVars partialVars = PartialVars.of(CelAttributePattern.create("x")); + Object unknownResult = cel.createProgram(ast).trace(partialVars, listener); assertThat(InterpreterUtil.isUnknown(unknownResult)).isTrue(); assertThat(branchResults.build()).containsExactly(false, unknownResult, true); @@ -770,8 +763,6 @@ public void trace_shortCircuitingDisabled_ternaryWithUnknowns(String source) thr "{expression: 'true ? true : (1 / 0) > 2', firstVisited: true, secondVisited: true}") public void trace_shortCircuitingDisabled_ternaryWithError( String expression, boolean firstVisited, boolean secondVisited) throws Exception { - // TODO: Implement exhaustive eval - assumeTrue(runtimeFlavor != CelRuntimeFlavor.PLANNER); ImmutableList.Builder branchResults = ImmutableList.builder(); CelEvaluationListener listener = (expr, res) -> { @@ -810,6 +801,55 @@ public void trace_shortCircuitingDisabled_ternaryWithError( assertThat(branchResults.build()).containsExactly(firstVisited, secondVisited).inOrder(); } + @Test + public void trace_shortCircuitingDisabled_ternaryWithSelectedError() throws Exception { + Cel cel = + runtimeFlavor + .builder() + .setOptions( + CelOptions.current() + .enableShortCircuiting(false) + .enableHeterogeneousNumericComparisons(true) + .build()) + .build(); + CelAbstractSyntaxTree ast = cel.compile("true ? (1 / 0) : 2").getAst(); + CelRuntime.Program program = cel.createProgram(ast); + + CelEvaluationException e = assertThrows(CelEvaluationException.class, () -> program.eval()); + assertThat(e).hasMessageThat().contains("evaluation error at :10: / by zero"); + assertThat(e.getErrorCode()).isEqualTo(CelErrorCode.DIVIDE_BY_ZERO); + } + + @Test + public void trace_shortCircuitingDisabled_ternaryWithCustomError() throws Exception { + Cel cel = + runtimeFlavor + .builder() + .addFunctionDeclarations( + CelFunctionDecl.newFunctionDeclaration( + "error_func", + CelOverloadDecl.newGlobalOverload( + "error_func_overload", SimpleType.BOOL, ImmutableList.of()))) + .addFunctionBindings( + CelFunctionBinding.from( + "error_func_overload", + ImmutableList.of(), + args -> { + throw new IllegalArgumentException("custom error"); + })) + .setOptions( + CelOptions.current() + .enableShortCircuiting(false) + .enableHeterogeneousNumericComparisons(true) + .build()) + .build(); + CelAbstractSyntaxTree ast = cel.compile("true ? error_func() : false").getAst(); + CelRuntime.Program program = cel.createProgram(ast); + + CelEvaluationException e = assertThrows(CelEvaluationException.class, () -> program.eval()); + assertThat(e).hasCauseThat().hasMessageThat().contains("custom error"); + } + @Test public void standardEnvironmentDisabledForRuntime_throws() throws Exception { CelCompiler celCompiler = @@ -817,11 +857,91 @@ public void standardEnvironmentDisabledForRuntime_throws() throws Exception { CelRuntime celRuntime = CelRuntimeFactory.standardCelRuntimeBuilder().setStandardEnvironmentEnabled(false).build(); CelAbstractSyntaxTree ast = celCompiler.compile("size('hello')").getAst(); + CelRuntime.Program program = celRuntime.createProgram(ast); - CelEvaluationException e = - assertThrows(CelEvaluationException.class, () -> celRuntime.createProgram(ast).eval()); + CelEvaluationException e = assertThrows(CelEvaluationException.class, () -> program.eval()); assertThat(e) .hasMessageThat() .contains("No matching overload for function 'size'. Overload candidates: size_string"); } + + @Test + public void trace_shortCircuitingDisabled_logicalAndPrefersFirstError() throws Exception { + Cel cel = + runtimeFlavor + .builder() + .addFunctionDeclarations( + CelFunctionDecl.newFunctionDeclaration( + "error_1", + CelOverloadDecl.newGlobalOverload( + "error_1_overload", SimpleType.BOOL, ImmutableList.of())), + CelFunctionDecl.newFunctionDeclaration( + "error_2", + CelOverloadDecl.newGlobalOverload( + "error_2_overload", SimpleType.BOOL, ImmutableList.of()))) + .addFunctionBindings( + CelFunctionBinding.from( + "error_1_overload", + ImmutableList.of(), + args -> { + throw new IllegalArgumentException("error 1"); + }), + CelFunctionBinding.from( + "error_2_overload", + ImmutableList.of(), + args -> { + throw new IllegalArgumentException("error 2"); + })) + .setOptions( + CelOptions.current() + .enableShortCircuiting(false) + .enableHeterogeneousNumericComparisons(true) + .build()) + .build(); + CelAbstractSyntaxTree ast = cel.compile("error_1() && error_2()").getAst(); + CelRuntime.Program program = cel.createProgram(ast); + + CelEvaluationException e = assertThrows(CelEvaluationException.class, () -> program.eval()); + assertThat(e).hasCauseThat().hasMessageThat().contains("error 1"); + } + + @Test + public void trace_shortCircuitingDisabled_logicalOrPrefersFirstError() throws Exception { + Cel cel = + runtimeFlavor + .builder() + .addFunctionDeclarations( + CelFunctionDecl.newFunctionDeclaration( + "error_1", + CelOverloadDecl.newGlobalOverload( + "error_1_overload", SimpleType.BOOL, ImmutableList.of())), + CelFunctionDecl.newFunctionDeclaration( + "error_2", + CelOverloadDecl.newGlobalOverload( + "error_2_overload", SimpleType.BOOL, ImmutableList.of()))) + .addFunctionBindings( + CelFunctionBinding.from( + "error_1_overload", + ImmutableList.of(), + args -> { + throw new IllegalArgumentException("error 1"); + }), + CelFunctionBinding.from( + "error_2_overload", + ImmutableList.of(), + args -> { + throw new IllegalArgumentException("error 2"); + })) + .setOptions( + CelOptions.current() + .enableShortCircuiting(false) + .enableHeterogeneousNumericComparisons(true) + .build()) + .build(); + CelAbstractSyntaxTree ast = cel.compile("error_1() || error_2()").getAst(); + CelRuntime.Program program = cel.createProgram(ast); + + CelEvaluationException e = assertThrows(CelEvaluationException.class, () -> program.eval()); + assertThat(e).hasCauseThat().hasMessageThat().contains("error 1"); + } } From cd71e408cea826ff4bfd08e9d677f9ba4f100061 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Thu, 23 Apr 2026 18:48:18 -0700 Subject: [PATCH 49/66] Promote planner runtime builders to the main factories PiperOrigin-RevId: 904740634 --- bundle/BUILD.bazel | 8 --- .../src/main/java/dev/cel/bundle/BUILD.bazel | 19 ------- .../cel/bundle/CelExperimentalFactory.java | 57 ------------------- .../main/java/dev/cel/bundle/CelFactory.java | 25 ++++++++ .../test/java/dev/cel/extensions/BUILD.bazel | 1 - .../extensions/CelOptionalLibraryTest.java | 3 +- publish/BUILD.bazel | 2 - runtime/BUILD.bazel | 8 --- .../src/main/java/dev/cel/runtime/BUILD.bazel | 12 ---- .../CelRuntimeExperimentalFactory.java | 48 ---------------- .../dev/cel/runtime/CelRuntimeFactory.java | 19 +++++++ .../java/dev/cel/runtime/CelRuntimeImpl.java | 3 +- .../src/test/java/dev/cel/runtime/BUILD.bazel | 1 - .../cel/runtime/PlannerInterpreterTest.java | 2 +- .../src/main/java/dev/cel/testing/BUILD.bazel | 2 - .../dev/cel/testing/CelRuntimeFlavor.java | 3 +- 16 files changed, 49 insertions(+), 164 deletions(-) delete mode 100644 bundle/src/main/java/dev/cel/bundle/CelExperimentalFactory.java delete mode 100644 runtime/src/main/java/dev/cel/runtime/CelRuntimeExperimentalFactory.java diff --git a/bundle/BUILD.bazel b/bundle/BUILD.bazel index 11e0b8a6d..1eaf0bec8 100644 --- a/bundle/BUILD.bazel +++ b/bundle/BUILD.bazel @@ -13,14 +13,6 @@ java_library( ], ) -java_library( - name = "cel_experimental_factory", - visibility = ["//:internal"], - exports = [ - "//bundle/src/main/java/dev/cel/bundle:cel_experimental_factory", - ], -) - java_library( name = "environment", exports = ["//bundle/src/main/java/dev/cel/bundle:environment"], diff --git a/bundle/src/main/java/dev/cel/bundle/BUILD.bazel b/bundle/src/main/java/dev/cel/bundle/BUILD.bazel index 822511e4c..742f718f1 100644 --- a/bundle/src/main/java/dev/cel/bundle/BUILD.bazel +++ b/bundle/src/main/java/dev/cel/bundle/BUILD.bazel @@ -35,7 +35,6 @@ java_library( "@cel_spec//proto/cel/expr:checked_java_proto", "@maven//:com_google_code_findbugs_annotations", "@maven//:com_google_errorprone_error_prone_annotations", - "@maven//:com_google_guava_guava", "@maven//:com_google_protobuf_protobuf_java", ], ) @@ -54,22 +53,6 @@ java_library( "//compiler:compiler_builder", "//parser", "//runtime", - ], -) - -java_library( - name = "cel_experimental_factory", - srcs = ["CelExperimentalFactory.java"], - tags = [ - ], - deps = [ - ":cel", - ":cel_impl", - "//checker", - "//common:options", - "//common/annotations", - "//compiler", - "//parser", "//runtime:runtime_planner_impl", ], ) @@ -117,7 +100,6 @@ java_library( tags = [ ], deps = [ - ":cel_factory", ":environment_exception", ":required_fields_checker", "//:auto_value", @@ -190,7 +172,6 @@ java_library( "//common:options", "//common/internal:env_visitor", "//common/types:cel_proto_types", - "//common/types:cel_types", "//common/types:type_providers", "//compiler:compiler_builder", "//extensions", diff --git a/bundle/src/main/java/dev/cel/bundle/CelExperimentalFactory.java b/bundle/src/main/java/dev/cel/bundle/CelExperimentalFactory.java deleted file mode 100644 index 9a3e95dd8..000000000 --- a/bundle/src/main/java/dev/cel/bundle/CelExperimentalFactory.java +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2026 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package dev.cel.bundle; - -import dev.cel.checker.CelCheckerLegacyImpl; -import dev.cel.common.CelOptions; -import dev.cel.common.annotations.Beta; -import dev.cel.compiler.CelCompilerImpl; -import dev.cel.parser.CelParserImpl; -import dev.cel.runtime.CelRuntimeImpl; - -/** - * Experimental helper class to configure the entire CEL stack in a common interface, backed by the - * new {@code ProgramPlanner} architecture. - * - *

All APIs and behaviors surfaced here are subject to change. - */ -@Beta -public final class CelExperimentalFactory { - - /** - * Creates a builder for configuring CEL for the parsing, optional type-checking, and evaluation - * of expressions using the Program Planner. - * - *

The {@code ProgramPlanner} architecture provides key benefits over the legacy runtime: - * - *

    - *
  • Performance: Programs can be cached for improving evaluation speed. - *
  • Parsed-only expression evaluation: Unlike the traditional stack which required - * supplying type-checked expressions, this architecture handles both parsed-only and - * type-checked expressions. - *
- */ - public static CelBuilder plannerCelBuilder() { - return CelImpl.newBuilder( - CelCompilerImpl.newBuilder( - CelParserImpl.newBuilder(), - CelCheckerLegacyImpl.newBuilder().setStandardEnvironmentEnabled(true)), - CelRuntimeImpl.newBuilder()) - // CEL-Internal-2 - .setOptions(CelOptions.current().enableHeterogeneousNumericComparisons(true).build()); - } - - private CelExperimentalFactory() {} -} diff --git a/bundle/src/main/java/dev/cel/bundle/CelFactory.java b/bundle/src/main/java/dev/cel/bundle/CelFactory.java index 6cc6d8192..ac589cfe6 100644 --- a/bundle/src/main/java/dev/cel/bundle/CelFactory.java +++ b/bundle/src/main/java/dev/cel/bundle/CelFactory.java @@ -20,6 +20,7 @@ import dev.cel.compiler.CelCompilerImpl; import dev.cel.parser.CelParserImpl; import dev.cel.runtime.CelRuntime; +import dev.cel.runtime.CelRuntimeImpl; import dev.cel.runtime.CelRuntimeLegacyImpl; /** Helper class to configure the entire CEL stack in a common interface. */ @@ -44,6 +45,30 @@ public static CelBuilder standardCelBuilder() { .setStandardEnvironmentEnabled(true); } + /** + * Creates a builder for configuring CEL for the parsing, optional type-checking, and evaluation + * of expressions using the Program Planner. + * + *

The {@code ProgramPlanner} architecture provides key benefits over the {@link + * #standardCelBuilder()}: + * + *

    + *
  • Performance: Programs can be cached for improving evaluation speed. + *
  • Parsed-only expression evaluation: Unlike the traditional stack which required + * supplying type-checked expressions, this architecture handles both parsed-only and + * type-checked expressions. + *
+ */ + public static CelBuilder plannerCelBuilder() { + return CelImpl.newBuilder( + CelCompilerImpl.newBuilder( + CelParserImpl.newBuilder(), + CelCheckerLegacyImpl.newBuilder().setStandardEnvironmentEnabled(true)), + CelRuntimeImpl.newBuilder()) + // CEL-Internal-2 + .setOptions(CelOptions.current().enableHeterogeneousNumericComparisons(true).build()); + } + /** Combines a prebuilt {@link CelCompiler} and {@link CelRuntime} into {@link Cel}. */ public static Cel combine(CelCompiler celCompiler, CelRuntime celRuntime) { return CelImpl.combine(celCompiler, celRuntime); diff --git a/extensions/src/test/java/dev/cel/extensions/BUILD.bazel b/extensions/src/test/java/dev/cel/extensions/BUILD.bazel index eed240317..f926e25b6 100644 --- a/extensions/src/test/java/dev/cel/extensions/BUILD.bazel +++ b/extensions/src/test/java/dev/cel/extensions/BUILD.bazel @@ -10,7 +10,6 @@ java_library( deps = [ "//:java_truth", "//bundle:cel", - "//bundle:cel_experimental_factory", "//common:cel_ast", "//common:cel_exception", "//common:compiler_common", diff --git a/extensions/src/test/java/dev/cel/extensions/CelOptionalLibraryTest.java b/extensions/src/test/java/dev/cel/extensions/CelOptionalLibraryTest.java index 4f348c12a..34c7f89f9 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelOptionalLibraryTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelOptionalLibraryTest.java @@ -26,7 +26,6 @@ import com.google.testing.junit.testparameterinjector.TestParameters; import dev.cel.bundle.Cel; import dev.cel.bundle.CelBuilder; -import dev.cel.bundle.CelExperimentalFactory; import dev.cel.bundle.CelFactory; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelContainer; @@ -113,7 +112,7 @@ private CelBuilder newCelBuilder(int version) { switch (testMode) { case PLANNER_PARSE_ONLY: case PLANNER_CHECKED: - celBuilder = CelExperimentalFactory.plannerCelBuilder(); + celBuilder = CelFactory.plannerCelBuilder(); break; case LEGACY_CHECKED: celBuilder = CelFactory.standardCelBuilder(); diff --git a/publish/BUILD.bazel b/publish/BUILD.bazel index d905edc4b..69622aada 100644 --- a/publish/BUILD.bazel +++ b/publish/BUILD.bazel @@ -32,7 +32,6 @@ RUNTIME_TARGETS = [ "//runtime/src/main/java/dev/cel/runtime:base", "//runtime/src/main/java/dev/cel/runtime:interpreter", "//runtime/src/main/java/dev/cel/runtime:late_function_binding", - "//runtime/src/main/java/dev/cel/runtime:runtime_experimental_factory", "//runtime/src/main/java/dev/cel/runtime:runtime_factory", "//runtime/src/main/java/dev/cel/runtime:runtime_helpers", "//runtime/src/main/java/dev/cel/runtime:runtime_legacy_impl", @@ -125,7 +124,6 @@ EXTENSION_TARGETS = [ # keep sorted BUNDLE_TARGETS = [ "//bundle/src/main/java/dev/cel/bundle:cel", - "//bundle/src/main/java/dev/cel/bundle:cel_experimental_factory", "//bundle/src/main/java/dev/cel/bundle:environment", "//bundle/src/main/java/dev/cel/bundle:environment_yaml_parser", ] diff --git a/runtime/BUILD.bazel b/runtime/BUILD.bazel index 3e183d236..f4a150ff3 100644 --- a/runtime/BUILD.bazel +++ b/runtime/BUILD.bazel @@ -29,14 +29,6 @@ java_library( ], ) -java_library( - name = "runtime_experimental_factory", - visibility = ["//:internal"], - exports = [ - "//runtime/src/main/java/dev/cel/runtime:runtime_experimental_factory", - ], -) - java_library( name = "runtime_legacy_impl", visibility = ["//:internal"], diff --git a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel index 0da68d548..5178ae27c 100644 --- a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel +++ b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel @@ -916,20 +916,8 @@ java_library( deps = [ ":runtime", ":runtime_legacy_impl", - "//common:options", - ], -) - -java_library( - name = "runtime_experimental_factory", - srcs = ["CelRuntimeExperimentalFactory.java"], - tags = [ - ], - deps = [ - ":runtime", ":runtime_planner_impl", "//common:options", - "//common/annotations", ], ) diff --git a/runtime/src/main/java/dev/cel/runtime/CelRuntimeExperimentalFactory.java b/runtime/src/main/java/dev/cel/runtime/CelRuntimeExperimentalFactory.java deleted file mode 100644 index 743f90669..000000000 --- a/runtime/src/main/java/dev/cel/runtime/CelRuntimeExperimentalFactory.java +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2026 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package dev.cel.runtime; - -import dev.cel.common.CelOptions; -import dev.cel.common.annotations.Beta; - -/** - * Experimental helper class to construct new {@code CelRuntime} instances backed by the new {@code - * ProgramPlanner} architecture. - * - *

All APIs and behaviors surfaced here are subject to change. - */ -@Beta -public final class CelRuntimeExperimentalFactory { - - /** - * Create a new builder for constructing a {@code CelRuntime} instance. - * - *

The {@code ProgramPlanner} architecture provides key benefits over the legacy runtime: - * - *

    - *
  • Performance: Programs can be cached for improving evaluation speed. - *
  • Parsed-only expression evaluation: Unlike the traditional legacy runtime, which - * only supported evaluating type-checked expressions, this architecture handles both - * parsed-only and type-checked expressions. - *
- */ - public static CelRuntimeBuilder plannerRuntimeBuilder() { - return CelRuntimeImpl.newBuilder() - // CEL-Internal-2 - .setOptions(CelOptions.current().enableHeterogeneousNumericComparisons(true).build()); - } - - private CelRuntimeExperimentalFactory() {} -} diff --git a/runtime/src/main/java/dev/cel/runtime/CelRuntimeFactory.java b/runtime/src/main/java/dev/cel/runtime/CelRuntimeFactory.java index 322985b22..6615b59e0 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelRuntimeFactory.java +++ b/runtime/src/main/java/dev/cel/runtime/CelRuntimeFactory.java @@ -32,5 +32,24 @@ public static CelRuntimeBuilder standardCelRuntimeBuilder() { .setStandardEnvironmentEnabled(true); } + /** + * Create a new builder for constructing a {@code CelRuntime} instance. + * + *

The {@code ProgramPlanner} architecture provides key benefits over the {@link + * #standardCelRuntimeBuilder()}: + * + *

    + *
  • Performance: Programs can be cached for improving evaluation speed. + *
  • Parsed-only expression evaluation: Unlike the runtime returned by {@link + * #standardCelRuntimeBuilder()}, which only supported evaluating type-checked expressions, + * this architecture handles both parsed-only and type-checked expressions. + *
+ */ + public static CelRuntimeBuilder plannerRuntimeBuilder() { + return CelRuntimeImpl.newBuilder() + // CEL-Internal-2 + .setOptions(CelOptions.current().enableHeterogeneousNumericComparisons(true).build()); + } + private CelRuntimeFactory() {} } diff --git a/runtime/src/main/java/dev/cel/runtime/CelRuntimeImpl.java b/runtime/src/main/java/dev/cel/runtime/CelRuntimeImpl.java index ed203d612..dcdf3be52 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelRuntimeImpl.java +++ b/runtime/src/main/java/dev/cel/runtime/CelRuntimeImpl.java @@ -474,7 +474,8 @@ private static CelDescriptorPool newDescriptorPool( @Override public CelRuntime build() { - assertAllowedCelOptions(options()); + CelOptions options = options(); + assertAllowedCelOptions(options); CelDescriptors celDescriptors = CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(fileDescriptorsBuilder().build()); diff --git a/runtime/src/test/java/dev/cel/runtime/BUILD.bazel b/runtime/src/test/java/dev/cel/runtime/BUILD.bazel index 3200a80e0..e886c3d8a 100644 --- a/runtime/src/test/java/dev/cel/runtime/BUILD.bazel +++ b/runtime/src/test/java/dev/cel/runtime/BUILD.bazel @@ -146,7 +146,6 @@ java_library( "//extensions", "//runtime", "//runtime:function_binding", - "//runtime:runtime_experimental_factory", "//runtime:unknown_attributes", "//testing:base_interpreter_test", "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto", diff --git a/runtime/src/test/java/dev/cel/runtime/PlannerInterpreterTest.java b/runtime/src/test/java/dev/cel/runtime/PlannerInterpreterTest.java index c0b0f76c4..9ae8590d5 100644 --- a/runtime/src/test/java/dev/cel/runtime/PlannerInterpreterTest.java +++ b/runtime/src/test/java/dev/cel/runtime/PlannerInterpreterTest.java @@ -42,7 +42,7 @@ public class PlannerInterpreterTest extends BaseInterpreterTest { @Override protected CelRuntimeBuilder newBaseRuntimeBuilder(CelOptions celOptions) { - return CelRuntimeExperimentalFactory.plannerRuntimeBuilder() + return CelRuntimeFactory.plannerRuntimeBuilder() .addLateBoundFunctions("record") .setOptions(celOptions) .addLibraries(CelExtensions.optional()) diff --git a/testing/src/main/java/dev/cel/testing/BUILD.bazel b/testing/src/main/java/dev/cel/testing/BUILD.bazel index b52026ec4..69765b549 100644 --- a/testing/src/main/java/dev/cel/testing/BUILD.bazel +++ b/testing/src/main/java/dev/cel/testing/BUILD.bazel @@ -102,7 +102,6 @@ java_library( "@maven//:com_google_protobuf_protobuf_java", "@maven//:com_google_protobuf_protobuf_java_util", "@maven//:junit_junit", - "@maven_android//:com_google_protobuf_protobuf_javalite", ], ) @@ -113,6 +112,5 @@ java_library( ], deps = [ "//bundle:cel", - "//bundle:cel_experimental_factory", ], ) diff --git a/testing/src/main/java/dev/cel/testing/CelRuntimeFlavor.java b/testing/src/main/java/dev/cel/testing/CelRuntimeFlavor.java index 576e0c1d3..66ce8d802 100644 --- a/testing/src/main/java/dev/cel/testing/CelRuntimeFlavor.java +++ b/testing/src/main/java/dev/cel/testing/CelRuntimeFlavor.java @@ -15,7 +15,6 @@ package dev.cel.testing; import dev.cel.bundle.CelBuilder; -import dev.cel.bundle.CelExperimentalFactory; import dev.cel.bundle.CelFactory; /** Enumeration of supported CEL runtime environments for testing. */ @@ -29,7 +28,7 @@ public CelBuilder builder() { PLANNER { @Override public CelBuilder builder() { - return CelExperimentalFactory.plannerCelBuilder(); + return CelFactory.plannerCelBuilder(); } }; From d699e28a52a6239deb887c3fa28897e56d3feb0d Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Thu, 23 Apr 2026 20:31:22 -0700 Subject: [PATCH 50/66] Add CombinedCelValueConverter PiperOrigin-RevId: 904770100 --- .../java/dev/cel/common/values/BUILD.bazel | 44 ++++++- .../cel/common/values/CelValueConverter.java | 73 +++++------- .../values/CombinedCelValueConverter.java | 84 +++++++++++++ .../values/CombinedCelValueProvider.java | 9 ++ .../java/dev/cel/common/values/BUILD.bazel | 1 + .../values/CombinedCelValueConverterTest.java | 112 ++++++++++++++++++ common/values/BUILD.bazel | 12 ++ .../java/dev/cel/runtime/CelRuntimeImpl.java | 12 +- 8 files changed, 298 insertions(+), 49 deletions(-) create mode 100644 common/src/main/java/dev/cel/common/values/CombinedCelValueConverter.java create mode 100644 common/src/test/java/dev/cel/common/values/CombinedCelValueConverterTest.java diff --git a/common/src/main/java/dev/cel/common/values/BUILD.bazel b/common/src/main/java/dev/cel/common/values/BUILD.bazel index 0d1d5431f..d572bb2bc 100644 --- a/common/src/main/java/dev/cel/common/values/BUILD.bazel +++ b/common/src/main/java/dev/cel/common/values/BUILD.bazel @@ -78,10 +78,14 @@ cel_android_library( java_library( name = "combined_cel_value_provider", - srcs = ["CombinedCelValueProvider.java"], + srcs = [ + "CombinedCelValueProvider.java", + ], tags = [ ], deps = [ + ":combined_cel_value_converter", + ":values", "//common/values:cel_value_provider", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", @@ -90,16 +94,52 @@ java_library( cel_android_library( name = "combined_cel_value_provider_android", - srcs = ["CombinedCelValueProvider.java"], + srcs = [ + "CombinedCelValueProvider.java", + ], tags = [ ], deps = [ + ":combined_cel_value_converter_android", + ":values_android", "//common/values:cel_value_provider_android", "@maven//:com_google_errorprone_error_prone_annotations", "@maven_android//:com_google_guava_guava", ], ) +java_library( + name = "combined_cel_value_converter", + srcs = [ + "CombinedCelValueConverter.java", + ], + tags = [ + ], + deps = [ + ":values", + "//common/annotations", + "@maven//:com_google_errorprone_error_prone_annotations", + "@maven//:com_google_guava_guava", + "@maven//:org_jspecify_jspecify", + ], +) + +cel_android_library( + name = "combined_cel_value_converter_android", + srcs = [ + "CombinedCelValueConverter.java", + ], + tags = [ + ], + deps = [ + ":values_android", + "//common/annotations", + "@maven//:com_google_errorprone_error_prone_annotations", + "@maven//:org_jspecify_jspecify", + "@maven_android//:com_google_guava_guava", + ], +) + java_library( name = "values", srcs = CEL_VALUES_SOURCES, diff --git a/common/src/main/java/dev/cel/common/values/CelValueConverter.java b/common/src/main/java/dev/cel/common/values/CelValueConverter.java index 70d04acc8..89f5ab100 100644 --- a/common/src/main/java/dev/cel/common/values/CelValueConverter.java +++ b/common/src/main/java/dev/cel/common/values/CelValueConverter.java @@ -21,8 +21,8 @@ import dev.cel.common.annotations.Internal; import java.util.Collection; import java.util.Map; -import java.util.Map.Entry; import java.util.Optional; +import java.util.function.Function; /** * {@code CelValueConverter} handles bidirectional conversion between native Java objects to {@link @@ -37,6 +37,12 @@ public class CelValueConverter { private static final CelValueConverter DEFAULT_INSTANCE = new CelValueConverter(); + @SuppressWarnings("Immutable") // Method reference is immutable + private final Function maybeUnwrapFunction; + + @SuppressWarnings("Immutable") // Method reference is immutable + private final Function toRuntimeValueFunction; + public static CelValueConverter getDefaultInstance() { return DEFAULT_INSTANCE; } @@ -51,14 +57,26 @@ public Object maybeUnwrap(Object value) { return unwrap((CelValue) value); } + Object mapped = mapContainer(value, maybeUnwrapFunction); + if (mapped != value) { + return mapped; + } + + return value; + } + + /** + * Maps a container (Collection or Map) by applying the provided mapper function to its elements. + * Returns the original value if it's not a supported container. + */ + protected Object mapContainer(Object value, Function mapper) { if (value instanceof Collection) { Collection collection = (Collection) value; ImmutableList.Builder builder = ImmutableList.builderWithExpectedSize(collection.size()); for (Object element : collection) { - builder.add(maybeUnwrap(element)); + builder.add(mapper.apply(element)); } - return builder.build(); } @@ -67,19 +85,14 @@ public Object maybeUnwrap(Object value) { ImmutableMap.Builder builder = ImmutableMap.builderWithExpectedSize(map.size()); for (Map.Entry entry : map.entrySet()) { - builder.put(maybeUnwrap(entry.getKey()), maybeUnwrap(entry.getValue())); + builder.put(mapper.apply(entry.getKey()), mapper.apply(entry.getValue())); } - return builder.buildOrThrow(); } return value; } - /** - * Canonicalizes an inbound {@code value} into a suitable Java object representation for - * evaluation. - */ public Object toRuntimeValue(Object value) { Preconditions.checkNotNull(value); @@ -87,14 +100,15 @@ public Object toRuntimeValue(Object value) { return value; } - if (value instanceof Collection) { - return toListValue((Collection) value); - } else if (value instanceof Map) { - return toMapValue((Map) value); - } else if (value instanceof Optional) { + Object mapped = mapContainer(value, toRuntimeValueFunction); + if (mapped != value) { + return mapped; + } + + if (value instanceof Optional) { Optional optionalValue = (Optional) value; return optionalValue - .map(this::toRuntimeValue) + .map(toRuntimeValueFunction) .map(OptionalValue::create) .orElse(OptionalValue.EMPTY); } @@ -136,31 +150,8 @@ private Object unwrap(CelValue celValue) { return celValue.value(); } - private ImmutableList toListValue(Collection iterable) { - Preconditions.checkNotNull(iterable); - - ImmutableList.Builder listBuilder = - ImmutableList.builderWithExpectedSize(iterable.size()); - for (Object entry : iterable) { - listBuilder.add(toRuntimeValue(entry)); - } - - return listBuilder.build(); - } - - private ImmutableMap toMapValue(Map map) { - Preconditions.checkNotNull(map); - - ImmutableMap.Builder mapBuilder = - ImmutableMap.builderWithExpectedSize(map.size()); - for (Entry entry : map.entrySet()) { - Object mapKey = toRuntimeValue(entry.getKey()); - Object mapValue = toRuntimeValue(entry.getValue()); - mapBuilder.put(mapKey, mapValue); - } - - return mapBuilder.buildOrThrow(); + protected CelValueConverter() { + this.maybeUnwrapFunction = this::maybeUnwrap; + this.toRuntimeValueFunction = this::toRuntimeValue; } - - protected CelValueConverter() {} } diff --git a/common/src/main/java/dev/cel/common/values/CombinedCelValueConverter.java b/common/src/main/java/dev/cel/common/values/CombinedCelValueConverter.java new file mode 100644 index 000000000..46e5fc3f1 --- /dev/null +++ b/common/src/main/java/dev/cel/common/values/CombinedCelValueConverter.java @@ -0,0 +1,84 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.common.values; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.collect.ImmutableList; +import dev.cel.common.annotations.Internal; +import org.jspecify.annotations.Nullable; + +/** + * {@code CombinedCelValueConverter} delegates value conversion to a list of underlying {@link + * CelValueConverter}s. + */ +@Internal +public final class CombinedCelValueConverter extends CelValueConverter { + private final ImmutableList converters; + + public static CombinedCelValueConverter combine(ImmutableList converters) { + return new CombinedCelValueConverter(converters); + } + + private CombinedCelValueConverter(ImmutableList converters) { + this.converters = checkNotNull(converters); + } + + @Override + public @Nullable Object toRuntimeValue(Object value) { + if (value == null) { + return null; + } + + // Let the base class handle CelValues, Optionals, Collections, Maps, and primitives. + Object baseResult = super.toRuntimeValue(value); + if (baseResult != value) { + return baseResult; + } + + // If the base class left the object unchanged (e.g. a raw POJO), try the delegates. + for (CelValueConverter converter : converters) { + Object result = converter.toRuntimeValue(value); + if (result != value) { + return result; + } + } + + return value; + } + + @Override + public @Nullable Object maybeUnwrap(Object value) { + if (value == null) { + return null; + } + + // Let the base class handle standard unwrapping and container unrolling. + Object baseResult = super.maybeUnwrap(value); + if (baseResult != value) { + return baseResult; + } + + // Try delegates for specialized unwrapping. + for (CelValueConverter converter : converters) { + Object result = converter.maybeUnwrap(value); + if (result != value) { + return result; + } + } + + return value; + } +} diff --git a/common/src/main/java/dev/cel/common/values/CombinedCelValueProvider.java b/common/src/main/java/dev/cel/common/values/CombinedCelValueProvider.java index 8fe62cb7b..d51c3afce 100644 --- a/common/src/main/java/dev/cel/common/values/CombinedCelValueProvider.java +++ b/common/src/main/java/dev/cel/common/values/CombinedCelValueProvider.java @@ -16,6 +16,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.collect.ImmutableList.toImmutableList; import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.Immutable; @@ -49,6 +50,14 @@ public Optional newValue(String structType, Map fields) return Optional.empty(); } + @Override + public CelValueConverter celValueConverter() { + return CombinedCelValueConverter.combine( + celValueProviders.stream() + .map(CelValueProvider::celValueConverter) + .collect(toImmutableList())); + } + /** Returns the underlying {@link CelValueProvider}s in order. */ public ImmutableList valueProviders() { return celValueProviders; diff --git a/common/src/test/java/dev/cel/common/values/BUILD.bazel b/common/src/test/java/dev/cel/common/values/BUILD.bazel index ab7eae8dd..bf151fcb7 100644 --- a/common/src/test/java/dev/cel/common/values/BUILD.bazel +++ b/common/src/test/java/dev/cel/common/values/BUILD.bazel @@ -24,6 +24,7 @@ java_library( "//common/values", "//common/values:cel_byte_string", "//common/values:cel_value_provider", + "//common/values:combined_cel_value_converter", "//common/values:combined_cel_value_provider", "//common/values:proto_message_lite_value", "//common/values:proto_message_lite_value_provider", diff --git a/common/src/test/java/dev/cel/common/values/CombinedCelValueConverterTest.java b/common/src/test/java/dev/cel/common/values/CombinedCelValueConverterTest.java new file mode 100644 index 000000000..8574587bc --- /dev/null +++ b/common/src/test/java/dev/cel/common/values/CombinedCelValueConverterTest.java @@ -0,0 +1,112 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.common.values; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.common.collect.ImmutableList; +import java.util.Map; +import java.util.Optional; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class CombinedCelValueConverterTest { + + @Test + public void toRuntimeValue_delegatesToUnderlyingConverters() { + CustomConverter converter1 = new CustomConverter("target1", "replacement1"); + CustomConverter converter2 = new CustomConverter("target2", "replacement2"); + CelValueConverter combined = + CombinedCelValueConverter.combine(ImmutableList.of(converter1, converter2)); + + assertThat(combined.toRuntimeValue("target1")).isEqualTo("replacement1"); + assertThat(combined.toRuntimeValue("target2")).isEqualTo("replacement2"); + assertThat(combined.toRuntimeValue("unhandled")).isEqualTo("unhandled"); + } + + @Test + public void maybeUnwrap_delegatesToUnderlyingConverters() { + CustomConverter converter1 = new CustomConverter("target1", "replacement1"); + CustomConverter converter2 = new CustomConverter("target2", "replacement2"); + CelValueConverter combined = + CombinedCelValueConverter.combine(ImmutableList.of(converter1, converter2)); + + assertThat(combined.maybeUnwrap("replacement1")).isEqualTo("target1"); + assertThat(combined.maybeUnwrap("replacement2")).isEqualTo("target2"); + assertThat(combined.maybeUnwrap("unhandled")).isEqualTo("unhandled"); + } + + @Test + public void combinedCelValueProvider_returnsCombinedConverter() { + CustomConverter converter1 = new CustomConverter("target1", "replacement1"); + CustomConverter converter2 = new CustomConverter("target2", "replacement2"); + CustomProvider provider1 = new CustomProvider(converter1); + CustomProvider provider2 = new CustomProvider(converter2); + + CombinedCelValueProvider combinedProvider = + CombinedCelValueProvider.combine(provider1, provider2); + CelValueConverter combinedConverter = combinedProvider.celValueConverter(); + + assertThat(combinedConverter).isInstanceOf(CombinedCelValueConverter.class); + assertThat(combinedConverter.toRuntimeValue("target1")).isEqualTo("replacement1"); + assertThat(combinedConverter.toRuntimeValue("target2")).isEqualTo("replacement2"); + } + + private static class CustomConverter extends CelValueConverter { + private final String target; + private final String replacement; + + private CustomConverter(String target, String replacement) { + this.target = target; + this.replacement = replacement; + } + + @Override + public Object toRuntimeValue(Object value) { + if (value.equals(target)) { + return replacement; + } + return value; + } + + @Override + public Object maybeUnwrap(Object value) { + if (value.equals(replacement)) { + return target; + } + return value; + } + } + + private static class CustomProvider implements CelValueProvider { + private final CelValueConverter converter; + + private CustomProvider(CelValueConverter converter) { + this.converter = converter; + } + + @Override + public Optional newValue(String structType, Map fields) { + return Optional.empty(); + } + + @Override + public CelValueConverter celValueConverter() { + return converter; + } + } +} diff --git a/common/values/BUILD.bazel b/common/values/BUILD.bazel index 74bfa9e0f..9853289a9 100644 --- a/common/values/BUILD.bazel +++ b/common/values/BUILD.bazel @@ -37,6 +37,18 @@ cel_android_library( exports = ["//common/src/main/java/dev/cel/common/values:combined_cel_value_provider_android"], ) +java_library( + name = "combined_cel_value_converter", + visibility = ["//:internal"], + exports = ["//common/src/main/java/dev/cel/common/values:combined_cel_value_converter"], +) + +cel_android_library( + name = "combined_cel_value_converter_android", + visibility = ["//:internal"], + exports = ["//common/src/main/java/dev/cel/common/values:combined_cel_value_converter_android"], +) + java_library( name = "values", exports = ["//common/src/main/java/dev/cel/common/values"], diff --git a/runtime/src/main/java/dev/cel/runtime/CelRuntimeImpl.java b/runtime/src/main/java/dev/cel/runtime/CelRuntimeImpl.java index dcdf3be52..adfba967b 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelRuntimeImpl.java +++ b/runtime/src/main/java/dev/cel/runtime/CelRuntimeImpl.java @@ -487,12 +487,6 @@ public CelRuntime build() { DynamicProto dynamicProto = DynamicProto.create(defaultMessageFactory); CelValueProvider protoMessageValueProvider = ProtoMessageValueProvider.newInstance(options(), dynamicProto); - CelValueConverter celValueConverter = protoMessageValueProvider.celValueConverter(); - if (valueProvider() != null) { - protoMessageValueProvider = - CombinedCelValueProvider.combine(protoMessageValueProvider, valueProvider()); - } - RuntimeEquality runtimeEquality = ProtoMessageRuntimeEquality.create(dynamicProto, options()); ImmutableSet runtimeLibraries = runtimeLibrariesBuilder().build(); // Add libraries, such as extensions @@ -505,6 +499,12 @@ public CelRuntime build() { } } + if (valueProvider() != null) { + protoMessageValueProvider = + CombinedCelValueProvider.combine(protoMessageValueProvider, valueProvider()); + } + CelValueConverter celValueConverter = protoMessageValueProvider.celValueConverter(); + CelTypeProvider messageTypeProvider = ProtoMessageTypeProvider.newBuilder() .setCelDescriptors(celDescriptors) From f5664501d9401930481276523fcb9681cf6bda1f Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Thu, 23 Apr 2026 22:53:50 -0700 Subject: [PATCH 51/66] Add value type parameter in StructValue PiperOrigin-RevId: 904815361 --- .../common/values/ProtoMessageLiteValue.java | 2 +- .../cel/common/values/ProtoMessageValue.java | 2 +- .../dev/cel/common/values/StructValue.java | 16 +++- .../cel/common/values/OptionalValueTest.java | 2 +- .../cel/common/values/StructValueTest.java | 88 +++++++++++++------ .../cel/runtime/planner/EvalCreateStruct.java | 3 +- 6 files changed, 76 insertions(+), 37 deletions(-) diff --git a/common/src/main/java/dev/cel/common/values/ProtoMessageLiteValue.java b/common/src/main/java/dev/cel/common/values/ProtoMessageLiteValue.java index 52f0f1594..2e4d980c7 100644 --- a/common/src/main/java/dev/cel/common/values/ProtoMessageLiteValue.java +++ b/common/src/main/java/dev/cel/common/values/ProtoMessageLiteValue.java @@ -35,7 +35,7 @@ */ @AutoValue @Immutable -public abstract class ProtoMessageLiteValue extends StructValue { +public abstract class ProtoMessageLiteValue extends StructValue { @Override public abstract MessageLite value(); diff --git a/common/src/main/java/dev/cel/common/values/ProtoMessageValue.java b/common/src/main/java/dev/cel/common/values/ProtoMessageValue.java index 12d47c253..627bd2c1d 100644 --- a/common/src/main/java/dev/cel/common/values/ProtoMessageValue.java +++ b/common/src/main/java/dev/cel/common/values/ProtoMessageValue.java @@ -28,7 +28,7 @@ /** ProtoMessageValue is a struct value with protobuf support. */ @AutoValue @Immutable -public abstract class ProtoMessageValue extends StructValue { +public abstract class ProtoMessageValue extends StructValue { @Override public abstract Message value(); diff --git a/common/src/main/java/dev/cel/common/values/StructValue.java b/common/src/main/java/dev/cel/common/values/StructValue.java index 8775ef5c4..aa44ec420 100644 --- a/common/src/main/java/dev/cel/common/values/StructValue.java +++ b/common/src/main/java/dev/cel/common/values/StructValue.java @@ -19,13 +19,21 @@ /** * StructValue is a representation of a structured object with typed properties. * - *

Users may extend from this class to provide a custom struct that CEL can understand (ex: - * POJOs). Custom struct implementations must provide all functionalities denoted in the CEL - * specification, such as field selection, presence testing and new object creation. + *

Users may extend from this class to provide a custom struct that CEL can understand by + * wrapping a native Java object (e.g., a POJO or a Map). Custom struct implementations must provide + * all functionalities denoted in the CEL specification, such as field selection, presence testing + * and new object creation. * *

For an expression `e` selecting a field `f`, `e.f` must throw an exception if `f` does not * exist in the struct (i.e: hasField returns false). If the field exists but is not set, the * implementation should return an appropriate default value based on the struct's semantics. + * + * @param The type of the field identifier. Only {@code String} is supported for now, but we may + * extend support to other types in the future. + * @param The type of the wrapped native object. */ @Immutable -public abstract class StructValue extends CelValue implements SelectableValue {} +public abstract class StructValue extends CelValue implements SelectableValue { + @Override + public abstract V value(); +} diff --git a/common/src/test/java/dev/cel/common/values/OptionalValueTest.java b/common/src/test/java/dev/cel/common/values/OptionalValueTest.java index 24b3ea30b..f00954e3d 100644 --- a/common/src/test/java/dev/cel/common/values/OptionalValueTest.java +++ b/common/src/test/java/dev/cel/common/values/OptionalValueTest.java @@ -141,7 +141,7 @@ public void celTypeTest() { } @SuppressWarnings("Immutable") // Test only - private static class CelCustomStruct extends StructValue { + private static class CelCustomStruct extends StructValue { private final long data; @Override diff --git a/common/src/test/java/dev/cel/common/values/StructValueTest.java b/common/src/test/java/dev/cel/common/values/StructValueTest.java index b8d6371a8..f25db8e87 100644 --- a/common/src/test/java/dev/cel/common/values/StructValueTest.java +++ b/common/src/test/java/dev/cel/common/values/StructValueTest.java @@ -59,18 +59,34 @@ public Optional findType(String typeName) { }; private static final CelValueProvider CUSTOM_STRUCT_VALUE_PROVIDER = - (structType, fields) -> { - if (structType.equals(CUSTOM_STRUCT_TYPE.name())) { - return Optional.of(new CelCustomStructValue(fields)); + new CelValueProvider() { + @Override + public Optional newValue(String structType, Map fields) { + if (structType.equals(CUSTOM_STRUCT_TYPE.name())) { + return Optional.of(new CelCustomStructValue(fields)); + } + return Optional.empty(); + } + + @Override + public CelValueConverter celValueConverter() { + return new CelValueConverter() { + @Override + public Object toRuntimeValue(Object value) { + if (value instanceof CustomPojo) { + return new CelCustomStructValue((CustomPojo) value); + } + return super.toRuntimeValue(value); + } + }; } - return Optional.empty(); }; @Test public void emptyStruct() { CelCustomStructValue celCustomStruct = new CelCustomStructValue(0); - assertThat(celCustomStruct.value()).isEqualTo(celCustomStruct); + assertThat(celCustomStruct.value().getData()).isEqualTo(0L); assertThat(celCustomStruct.isZeroValue()).isTrue(); } @@ -78,7 +94,7 @@ public void emptyStruct() { public void constructStruct() { CelCustomStructValue celCustomStruct = new CelCustomStructValue(5); - assertThat(celCustomStruct.value()).isEqualTo(celCustomStruct); + assertThat(celCustomStruct.value().getData()).isEqualTo(5L); assertThat(celCustomStruct.isZeroValue()).isFalse(); } @@ -115,41 +131,41 @@ public void celTypeTest() { @Test public void evaluate_usingCustomClass_createNewStruct() throws Exception { Cel cel = - CelFactory.standardCelBuilder() - .setOptions(CelOptions.current().enableCelValue(true).build()) + CelFactory.plannerCelBuilder() + .setOptions(CelOptions.current().enableHeterogeneousNumericComparisons(true).build()) .setTypeProvider(CUSTOM_STRUCT_TYPE_PROVIDER) .setValueProvider(CUSTOM_STRUCT_VALUE_PROVIDER) .build(); CelAbstractSyntaxTree ast = cel.compile("custom_struct{data: 50}").getAst(); - CelCustomStructValue result = (CelCustomStructValue) cel.createProgram(ast).eval(); + CustomPojo result = (CustomPojo) cel.createProgram(ast).eval(); - assertThat(result.data).isEqualTo(50); + assertThat(result.getData()).isEqualTo(50); } @Test public void evaluate_usingCustomClass_asVariable() throws Exception { Cel cel = - CelFactory.standardCelBuilder() - .setOptions(CelOptions.current().enableCelValue(true).build()) + CelFactory.plannerCelBuilder() + .setOptions(CelOptions.current().enableHeterogeneousNumericComparisons(true).build()) .addVar("a", CUSTOM_STRUCT_TYPE) .setTypeProvider(CUSTOM_STRUCT_TYPE_PROVIDER) .setValueProvider(CUSTOM_STRUCT_VALUE_PROVIDER) .build(); CelAbstractSyntaxTree ast = cel.compile("a").getAst(); - CelCustomStructValue result = - (CelCustomStructValue) + CustomPojo result = + (CustomPojo) cel.createProgram(ast).eval(ImmutableMap.of("a", new CelCustomStructValue(10))); - assertThat(result.data).isEqualTo(10); + assertThat(result.getData()).isEqualTo(10); } @Test public void evaluate_usingCustomClass_asVariableSelectField() throws Exception { Cel cel = - CelFactory.standardCelBuilder() - .setOptions(CelOptions.current().enableCelValue(true).build()) + CelFactory.plannerCelBuilder() + .setOptions(CelOptions.current().enableHeterogeneousNumericComparisons(true).build()) .addVar("a", CUSTOM_STRUCT_TYPE) .setTypeProvider(CUSTOM_STRUCT_TYPE_PROVIDER) .setValueProvider(CUSTOM_STRUCT_VALUE_PROVIDER) @@ -163,8 +179,8 @@ public void evaluate_usingCustomClass_asVariableSelectField() throws Exception { @Test public void evaluate_usingCustomClass_selectField() throws Exception { Cel cel = - CelFactory.standardCelBuilder() - .setOptions(CelOptions.current().enableCelValue(true).build()) + CelFactory.plannerCelBuilder() + .setOptions(CelOptions.current().enableHeterogeneousNumericComparisons(true).build()) .setTypeProvider(CUSTOM_STRUCT_TYPE_PROVIDER) .setValueProvider(CUSTOM_STRUCT_VALUE_PROVIDER) .build(); @@ -178,8 +194,8 @@ public void evaluate_usingCustomClass_selectField() throws Exception { @Test public void evaluate_usingMultipleProviders_selectFieldFromCustomClass() throws Exception { Cel cel = - CelFactory.standardCelBuilder() - .setOptions(CelOptions.current().enableCelValue(true).build()) + CelFactory.plannerCelBuilder() + .setOptions(CelOptions.current().enableHeterogeneousNumericComparisons(true).build()) .setTypeProvider(CUSTOM_STRUCT_TYPE_PROVIDER) .setValueProvider( CombinedCelValueProvider.combine( @@ -197,19 +213,31 @@ public void evaluate_usingMultipleProviders_selectFieldFromCustomClass() throws // TODO: Bring back evaluate_usingMultipleProviders_selectFieldFromProtobufMessage // once planner is exposed from factory + private static class CustomPojo { + private final long data; + + CustomPojo(long data) { + this.data = data; + } + + long getData() { + return data; + } + } + @SuppressWarnings("Immutable") // Test only - private static class CelCustomStructValue extends StructValue { + private static class CelCustomStructValue extends StructValue { - private final long data; + private final CustomPojo pojo; @Override - public CelCustomStructValue value() { - return this; + public CustomPojo value() { + return pojo; } @Override public boolean isZeroValue() { - return data == 0; + return pojo.getData() == 0; } @Override @@ -226,7 +254,7 @@ public Object select(String field) { @Override public Optional find(String field) { if (field.equals("data")) { - return Optional.of(value().data); + return Optional.of(pojo.getData()); } return Optional.empty(); @@ -237,7 +265,11 @@ private CelCustomStructValue(Map fields) { } private CelCustomStructValue(long data) { - this.data = data; + this.pojo = new CustomPojo(data); + } + + private CelCustomStructValue(CustomPojo pojo) { + this.pojo = pojo; } } } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalCreateStruct.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalCreateStruct.java index 36485d5be..a2e8a9da6 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalCreateStruct.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalCreateStruct.java @@ -87,9 +87,8 @@ Object evalInternal(GlobalResolver resolver, ExecutionFrame frame) { .newValue(structType.name(), Collections.unmodifiableMap(fieldValues)) .orElseThrow( () -> new IllegalArgumentException("Type name not found: " + structType.name())); - if (value instanceof StructValue) { - return ((StructValue) value).value(); + return ((StructValue) value).value(); } return value; From 21c3318538393114b9237bcc1c806a999b07f79c Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Tue, 28 Apr 2026 16:32:14 -0700 Subject: [PATCH 52/66] Policy nested rule fix PiperOrigin-RevId: 907239115 --- .../src/main/java/dev/cel/policy/BUILD.bazel | 4 +- .../java/dev/cel/policy/RuleComposer.java | 284 +++++++++++++----- .../cel/policy/CelPolicyCompilerImplTest.java | 29 +- .../java/dev/cel/policy/PolicyTestHelper.java | 28 +- .../src/test/resources/policy/k8s/tests.yaml | 31 +- .../test/resources/policy/limits/tests.yaml | 48 +-- .../resources/policy/nested_rule/tests.yaml | 39 +-- .../resources/policy/nested_rule2/tests.yaml | 68 +++-- .../resources/policy/nested_rule3/tests.yaml | 68 +++-- .../resources/policy/nested_rule4/config.yaml | 19 ++ .../resources/policy/nested_rule4/policy.yaml | 24 ++ .../resources/policy/nested_rule4/tests.yaml | 30 ++ .../resources/policy/nested_rule5/config.yaml | 19 ++ .../resources/policy/nested_rule5/policy.yaml | 30 ++ .../resources/policy/nested_rule5/tests.yaml | 42 +++ .../resources/policy/nested_rule6/config.yaml | 19 ++ .../resources/policy/nested_rule6/policy.yaml | 28 ++ .../resources/policy/nested_rule6/tests.yaml | 24 ++ .../resources/policy/nested_rule7/config.yaml | 19 ++ .../resources/policy/nested_rule7/policy.yaml | 29 ++ .../resources/policy/nested_rule7/tests.yaml | 42 +++ .../src/test/resources/policy/pb/tests.yaml | 35 +-- .../policy/required_labels/tests.yaml | 115 +++---- .../policy/restricted_destinations/tests.yaml | 200 ++++++------ 24 files changed, 893 insertions(+), 381 deletions(-) create mode 100644 testing/src/test/resources/policy/nested_rule4/config.yaml create mode 100644 testing/src/test/resources/policy/nested_rule4/policy.yaml create mode 100644 testing/src/test/resources/policy/nested_rule4/tests.yaml create mode 100644 testing/src/test/resources/policy/nested_rule5/config.yaml create mode 100644 testing/src/test/resources/policy/nested_rule5/policy.yaml create mode 100644 testing/src/test/resources/policy/nested_rule5/tests.yaml create mode 100644 testing/src/test/resources/policy/nested_rule6/config.yaml create mode 100644 testing/src/test/resources/policy/nested_rule6/policy.yaml create mode 100644 testing/src/test/resources/policy/nested_rule6/tests.yaml create mode 100644 testing/src/test/resources/policy/nested_rule7/config.yaml create mode 100644 testing/src/test/resources/policy/nested_rule7/policy.yaml create mode 100644 testing/src/test/resources/policy/nested_rule7/tests.yaml diff --git a/policy/src/main/java/dev/cel/policy/BUILD.bazel b/policy/src/main/java/dev/cel/policy/BUILD.bazel index 916f16f9b..a7bb90ffe 100644 --- a/policy/src/main/java/dev/cel/policy/BUILD.bazel +++ b/policy/src/main/java/dev/cel/policy/BUILD.bazel @@ -247,19 +247,19 @@ java_library( visibility = ["//visibility:private"], deps = [ ":compiled_rule", - "//:auto_value", "//bundle:cel", "//common:cel_ast", "//common:compiler_common", "//common:mutable_ast", + "//common:mutable_source", "//common:operator", "//common/ast", + "//common/ast:mutable_expr", "//common/formats:value_string", "//common/navigation:mutable_navigation", "//extensions:optional_library", "//optimizer:ast_optimizer", "//optimizer:mutable_ast", - "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", ], ) diff --git a/policy/src/main/java/dev/cel/policy/RuleComposer.java b/policy/src/main/java/dev/cel/policy/RuleComposer.java index 7bbde7685..cafef4b7c 100644 --- a/policy/src/main/java/dev/cel/policy/RuleComposer.java +++ b/policy/src/main/java/dev/cel/policy/RuleComposer.java @@ -18,15 +18,17 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.stream.Collectors.toCollection; -import com.google.auto.value.AutoValue; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import dev.cel.bundle.Cel; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelMutableAst; +import dev.cel.common.CelMutableSource; import dev.cel.common.CelValidationException; import dev.cel.common.Operator; +import dev.cel.common.ast.CelConstant; import dev.cel.common.ast.CelExpr.ExprKind.Kind; +import dev.cel.common.ast.CelMutableExpr; import dev.cel.common.formats.ValueString; import dev.cel.common.navigation.CelNavigableMutableAst; import dev.cel.common.navigation.CelNavigableMutableExpr; @@ -48,22 +50,11 @@ final class RuleComposer implements CelAstOptimizer { @Override public OptimizationResult optimize(CelAbstractSyntaxTree ast, Cel cel) { - RuleOptimizationResult result = optimizeRule(cel, compiledRule); - return OptimizationResult.create(result.ast().toParsedAst()); + Step result = optimizeRule(cel, compiledRule); + return OptimizationResult.create(result.expr.toParsedAst()); } - @AutoValue - abstract static class RuleOptimizationResult { - abstract CelMutableAst ast(); - - abstract boolean isOptionalResult(); - - static RuleOptimizationResult create(CelMutableAst ast, boolean isOptionalResult) { - return new AutoValue_RuleComposer_RuleOptimizationResult(ast, isOptionalResult); - } - } - - private RuleOptimizationResult optimizeRule(Cel cel, CelCompiledRule compiledRule) { + private Step optimizeRule(Cel cel, CelCompiledRule compiledRule) { cel = cel.toCelBuilder() .addVarDeclarations( @@ -72,81 +63,57 @@ private RuleOptimizationResult optimizeRule(Cel cel, CelCompiledRule compiledRul .collect(toImmutableList())) .build(); - CelMutableAst matchAst = astMutator.newGlobalCall(Function.OPTIONAL_NONE.getFunction()); - boolean isOptionalResult = true; - // Keep track of the last output ID that might cause type-check failure while attempting to - // compose the subgraphs. + Step output = null; + // If the rule has an optional output, the last result in the ternary should return + // `optional.none`. This output is implicit and created here to reflect the desired + // last possible output of this type of rule. + if (compiledRule.hasOptionalOutput()) { + output = + Step.newUnconditionalOptionalStep( + newTrueLiteral(), astMutator.newGlobalCall(Function.OPTIONAL_NONE.getFunction())); + } + long lastOutputId = 0; for (CelCompiledMatch match : Lists.reverse(compiledRule.matches())) { CelAbstractSyntaxTree conditionAst = match.condition(); - // If the condition is trivially true, none of the matches in the rule causes the result - // to become optional, and the rule is not the last match, then this will introduce - // unreachable outputs or rules. boolean isTriviallyTrue = match.isConditionTriviallyTrue(); + CelMutableAst condAst = CelMutableAst.fromCelAst(conditionAst); switch (match.result().kind()) { - // For the match's output, determine whether the output should be wrapped - // into an optional value, a conditional, or both. case OUTPUT: + // If the match has an output, then it is considered a non-optional output since + // it is explicitly stated. If the rule itself is optional, then the base case value + // of output being optional.none() will convert the non-optional value to an optional + // one. OutputValue matchOutput = match.result().output(); CelMutableAst outAst = CelMutableAst.fromCelAst(matchOutput.ast()); - if (isTriviallyTrue) { - matchAst = outAst; - isOptionalResult = false; - lastOutputId = matchOutput.sourceId(); - continue; - } - if (isOptionalResult) { - outAst = astMutator.newGlobalCall(Function.OPTIONAL_OF.getFunction(), outAst); - } - - matchAst = - astMutator.newGlobalCall( - Operator.CONDITIONAL.getFunction(), - CelMutableAst.fromCelAst(conditionAst), - outAst, - matchAst); + Step step = Step.newNonOptionalStep(!isTriviallyTrue, condAst, outAst); + output = combine(astMutator, step, output); + assertComposedAstIsValid( cel, - matchAst, + output.expr, "conflicting output types found.", matchOutput.sourceId(), lastOutputId); lastOutputId = matchOutput.sourceId(); - continue; + break; case RULE: // If the match has a nested rule, then compute the rule and whether it has // an optional return value. CelCompiledRule matchNestedRule = match.result().rule(); - RuleOptimizationResult nestedRule = optimizeRule(cel, matchNestedRule); + Step nestedRule = optimizeRule(cel, matchNestedRule); boolean nestedHasOptional = matchNestedRule.hasOptionalOutput(); - CelMutableAst nestedRuleAst = nestedRule.ast(); - if (isOptionalResult && !nestedHasOptional) { - nestedRuleAst = - astMutator.newGlobalCall(Function.OPTIONAL_OF.getFunction(), nestedRuleAst); - } - if (!isOptionalResult && nestedHasOptional) { - matchAst = astMutator.newGlobalCall(Function.OPTIONAL_OF.getFunction(), matchAst); - isOptionalResult = true; - } - // If either the nested rule or current condition output are optional then - // use optional.or() to specify the combination of the first and second results - // Note, the argument order is reversed due to the traversal of matches in - // reverse order. - if (isOptionalResult && isTriviallyTrue) { - matchAst = astMutator.newMemberCall(nestedRuleAst, Function.OR.getFunction(), matchAst); - } else { - matchAst = - astMutator.newGlobalCall( - Operator.CONDITIONAL.getFunction(), - CelMutableAst.fromCelAst(conditionAst), - nestedRuleAst, - matchAst); - } + + Step ruleStep = + nestedHasOptional + ? Step.newOptionalStep(!isTriviallyTrue, condAst, nestedRule.expr) + : Step.newNonOptionalStep(!isTriviallyTrue, condAst, nestedRule.expr); + output = combine(astMutator, ruleStep, output); assertComposedAstIsValid( cel, - matchAst, + output.expr, String.format( "failed composing the subrule '%s' due to conflicting output types.", matchNestedRule.ruleId().map(ValueString::value).orElse("")), @@ -155,11 +122,124 @@ private RuleOptimizationResult optimizeRule(Cel cel, CelCompiledRule compiledRul } } - CelMutableAst result = inlineCompiledVariables(matchAst, compiledRule.variables()); + CelMutableAst resultExpr = output.expr; + resultExpr = inlineCompiledVariables(resultExpr, compiledRule.variables()); + resultExpr = astMutator.renumberIdsConsecutively(resultExpr); + + return output.isOptional + ? Step.newUnconditionalOptionalStep(newTrueLiteral(), resultExpr) + : Step.newUnconditionalNonOptionalStep(newTrueLiteral(), resultExpr); + } + + static RuleComposer newInstance( + CelCompiledRule compiledRule, String variablePrefix, int iterationLimit) { + return new RuleComposer(compiledRule, variablePrefix, iterationLimit); + } + + // Assembles two output expressions into a single output step. + private Step combine(AstMutator astMutator, Step currentStep, Step accumulatedStep) { + if (accumulatedStep == null) { + return currentStep; + } + CelMutableAst trueCondition = newTrueLiteral(); + + if (currentStep.isOptional) { + return combineWhenCurrentIsOptional(currentStep, accumulatedStep, astMutator, trueCondition); + } else { + return combineWhenCurrentIsNonOptional( + currentStep, accumulatedStep, astMutator, trueCondition); + } + } + + private Step combineWhenCurrentIsOptional( + Step currentStep, Step accumulatedStep, AstMutator astMutator, CelMutableAst trueCondition) { + // optional.combine(optional) // optional + // (optional && conditional).combine(non-optional) // optional + // (optional && unconditional).combine(non-optional) // non-optional + if (accumulatedStep.isOptional) { + if (currentStep.isConditional) { + return Step.newUnconditionalOptionalStep( + trueCondition, + astMutator.newGlobalCall( + Operator.CONDITIONAL.getFunction(), + currentStep.cond, + currentStep.expr, + accumulatedStep.expr)); + } else { + if (!isOptionalNone(accumulatedStep.expr)) { + // If either the nested rule or current condition output are optional then + // use optional.or() to specify the combination of the first and second results + // Note, the argument order is reversed due to the traversal of matches in + // reverse order. + return Step.newUnconditionalOptionalStep( + trueCondition, + astMutator.newMemberCall(currentStep.expr, "or", accumulatedStep.expr)); + } + return currentStep; + } + } else { // accumulatedStep is non-optional + if (currentStep.isConditional) { + return Step.newUnconditionalOptionalStep( + trueCondition, + astMutator.newGlobalCall( + Operator.CONDITIONAL.getFunction(), + currentStep.cond, + currentStep.expr, + astMutator.newGlobalCall( + Function.OPTIONAL_OF.getFunction(), accumulatedStep.expr))); + } else { + return Step.newUnconditionalNonOptionalStep( + trueCondition, + astMutator.newMemberCall(currentStep.expr, "orValue", accumulatedStep.expr)); + } + } + } + + private Step combineWhenCurrentIsNonOptional( + Step currentStep, Step accumulatedStep, AstMutator astMutator, CelMutableAst trueCondition) { + // non-optional.combine(non-optional) // non-optional + // (non-optional && conditional).combine(optional) // optional + // (non-optional && unconditional).combine(optional) // non-optional + // + // The last combination case is unusual, but effectively it means that the non-optional value + // prunes away + // the potential optional output. + if (accumulatedStep.isOptional) { + if (currentStep.isConditional) { + return Step.newUnconditionalOptionalStep( + trueCondition, + astMutator.newGlobalCall( + Operator.CONDITIONAL.getFunction(), + currentStep.cond, + astMutator.newGlobalCall(Function.OPTIONAL_OF.getFunction(), currentStep.expr), + accumulatedStep.expr)); + } else { + // If the condition is trivially true, none of the matches in the rule causes the result + // to become optional, and the rule is not the last match, then this will introduce + // unreachable outputs or rules (pruning away 'accumulatedStep'). + return currentStep; + } + } else { // accumulatedStep is non-optional + return Step.newUnconditionalNonOptionalStep( + trueCondition, + astMutator.newGlobalCall( + Operator.CONDITIONAL.getFunction(), + currentStep.cond, + currentStep.expr, + accumulatedStep.expr)); + } + } - result = astMutator.renumberIdsConsecutively(result); + private static boolean isOptionalNone(CelMutableAst ast) { + CelMutableExpr expr = ast.expr(); + return expr.getKind().equals(Kind.CALL) + && expr.call().function().equals("optional.none") + && expr.call().args().isEmpty(); + } - return RuleOptimizationResult.create(result, isOptionalResult); + private static CelMutableAst newTrueLiteral() { + return CelMutableAst.of( + CelMutableExpr.ofConstant(CelConstant.ofValue(true)), CelMutableSource.newInstance()); } private CelMutableAst inlineCompiledVariables( @@ -186,11 +266,6 @@ private CelMutableAst inlineCompiledVariables( return mutatedAst; } - static RuleComposer newInstance( - CelCompiledRule compiledRule, String variablePrefix, int iterationLimit) { - return new RuleComposer(compiledRule, variablePrefix, iterationLimit); - } - private void assertComposedAstIsValid( Cel cel, CelMutableAst composedAst, String failureMessage, Long... ids) { assertComposedAstIsValid(cel, composedAst, failureMessage, Arrays.asList(ids)); @@ -206,10 +281,55 @@ private void assertComposedAstIsValid( } } - private RuleComposer(CelCompiledRule compiledRule, String variablePrefix, int iterationLimit) { - this.compiledRule = checkNotNull(compiledRule); - this.variablePrefix = variablePrefix; - this.astMutator = AstMutator.newInstance(iterationLimit); + // Step represents an intermediate stage of rule and match expression composition. + // + // The CelCompiledRule and CelCompiledMatch types are meant to represent standalone tuples of + // condition and output expressions, and have no notion of how the order of combination would + // impact composition since composition rules may vary based on the policy execution semantic, + // e.g. first-match versus logical-or, logical-and, or accumulation. + private static class Step { + /** + * Indicates whether the output step has an optional result. Individual conditional attributes + * are not optional; however, rules and subrules can have optional output. + */ + private final boolean isOptional; + + /** True if the condition expression is not trivially true. */ + private final boolean isConditional; + + /** The condition associated with the output. */ + private final CelMutableAst cond; + + /** The output expression for the step. */ + private final CelMutableAst expr; + + private Step( + boolean isOptional, boolean isConditional, CelMutableAst cond, CelMutableAst expr) { + this.isOptional = isOptional; + this.isConditional = isConditional; + this.cond = cond; + this.expr = expr; + } + + private static Step newOptionalStep( + boolean isConditional, CelMutableAst cond, CelMutableAst expr) { + return new Step(/* isOptional= */ true, isConditional, cond, expr); + } + + private static Step newNonOptionalStep( + boolean isConditional, CelMutableAst cond, CelMutableAst expr) { + return new Step(/* isOptional= */ false, isConditional, cond, expr); + } + + private static Step newUnconditionalOptionalStep( + CelMutableAst trueCondition, CelMutableAst expr) { + return new Step(/* isOptional= */ true, /* isConditional= */ false, trueCondition, expr); + } + + private static Step newUnconditionalNonOptionalStep( + CelMutableAst trueCondition, CelMutableAst expr) { + return new Step(/* isOptional= */ false, /* isConditional= */ false, trueCondition, expr); + } } static final class RuleCompositionException extends RuntimeException { @@ -225,4 +345,10 @@ private RuleCompositionException( this.compileException = e; } } + + private RuleComposer(CelCompiledRule compiledRule, String variablePrefix, int iterationLimit) { + this.compiledRule = checkNotNull(compiledRule); + this.variablePrefix = variablePrefix; + this.astMutator = AstMutator.newInstance(iterationLimit); + } } diff --git a/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java b/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java index d4ca76324..35e249407 100644 --- a/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java +++ b/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java @@ -214,7 +214,30 @@ public void evaluateYamlPolicy_withCanonicalTestData( // Read the policy source String policySource = testData.yamlPolicy.readPolicyYamlContent(); CelPolicy policy = POLICY_PARSER.parse(policySource); - CelAbstractSyntaxTree expectedOutputAst = cel.compile(testData.testCase.getOutput()).getAst(); + Object outputObj = testData.testCase.getOutput(); + String exprToCompile; + if (outputObj instanceof String) { + exprToCompile = (String) outputObj; + } else if (outputObj instanceof Map) { + @SuppressWarnings("unchecked") // Test only + Map outputMap = (Map) outputObj; + if (outputMap.containsKey("value")) { + Object value = outputMap.get("value"); + if (value instanceof String) { + String escapedValue = ((String) value).replace("\"", "\\\""); + exprToCompile = "\"" + escapedValue + "\""; // Quote string literals + } else { + exprToCompile = String.valueOf(value); + } + } else if (outputMap.containsKey("expr")) { + exprToCompile = (String) outputMap.get("expr"); + } else { + throw new IllegalArgumentException("Invalid output format: " + outputObj); + } + } else { + throw new IllegalArgumentException("Invalid output format: " + outputObj); + } + CelAbstractSyntaxTree expectedOutputAst = cel.compile(exprToCompile).getAst(); Object expectedOutput = cel.createProgram(expectedOutputAst).eval(); // Act @@ -266,8 +289,8 @@ public void evaluateYamlPolicy_nestedRuleProducesOptionalOutput() throws Excepti CelPolicyCompilerFactory.newPolicyCompiler(cel).build().compile(policy); Optional evalResult = (Optional) cel.createProgram(compiledPolicyAst).eval(); - // Result is Optional> - assertThat(evalResult).hasValue(Optional.of(true)); + // Result is Optional containing true + assertThat(evalResult).hasValue(true); } @Test diff --git a/policy/src/test/java/dev/cel/policy/PolicyTestHelper.java b/policy/src/test/java/dev/cel/policy/PolicyTestHelper.java index 18d5ffc69..59647f4d9 100644 --- a/policy/src/test/java/dev/cel/policy/PolicyTestHelper.java +++ b/policy/src/test/java/dev/cel/policy/PolicyTestHelper.java @@ -40,11 +40,11 @@ final class PolicyTestHelper { enum TestYamlPolicy { NESTED_RULE( "nested_rule", - true, + false, "cel.@block([resource.origin, @index0 in [\"us\", \"uk\", \"es\"], {\"banned\": true}]," + " ((@index0 in {\"us\": false, \"ru\": false, \"ir\": false} && !@index1) ?" - + " optional.of(@index2) : optional.none()).or(optional.of(@index1 ? {\"banned\":" - + " false} : @index2)))"), + + " optional.of(@index2) : optional.none()).orValue(@index1 ? {\"banned\":" + + " false} : @index2))"), NESTED_RULE2( "nested_rule2", false, @@ -61,6 +61,22 @@ enum TestYamlPolicy { + " false, \"ru\": false, \"ir\": false} && @index1) ? {\"banned\":" + " \"restricted_region\"} : {\"banned\": \"bad_actor\"}) : (@index1 ?" + " optional.of({\"banned\": \"unconfigured_region\"}) : optional.none()))"), + NESTED_RULE4("nested_rule4", false, "(x > 0) ? true : false"), + NESTED_RULE5( + "nested_rule5", + true, + "cel.@block([optional.of(true), optional.none()], (x > 0) ? ((x > 2) ? @index0 : @index1) :" + + " ((x > 1) ? ((x >= 2) ? @index0 : @index1) : optional.of(false)))"), + NESTED_RULE6( + "nested_rule6", + false, + "cel.@block([optional.of(true), optional.none()], ((x > 2) ? @index0 : @index1).orValue(((x" + + " > 3) ? @index0 : @index1).orValue(false)))"), + NESTED_RULE7( + "nested_rule7", + true, + "cel.@block([optional.of(true), optional.none()], ((x > 2) ? @index0 : @index1).or(((x > 3)" + + " ? @index0 : @index1).or((x > 1) ? optional.of(false) : @index1)))"), REQUIRED_LABELS( "required_labels", true, @@ -198,7 +214,7 @@ public List getTests() { public static final class PolicyTestCase { private String name; private Map input; - private String output; + private Object output; public void setName(String name) { this.name = name; @@ -208,7 +224,7 @@ public void setInput(Map input) { this.input = input; } - public void setOutput(String output) { + public void setOutput(Object output) { this.output = output; } @@ -220,7 +236,7 @@ public Map getInput() { return input; } - public String getOutput() { + public Object getOutput() { return output; } diff --git a/testing/src/test/resources/policy/k8s/tests.yaml b/testing/src/test/resources/policy/k8s/tests.yaml index 8585c5efb..f3e7de790 100644 --- a/testing/src/test/resources/policy/k8s/tests.yaml +++ b/testing/src/test/resources/policy/k8s/tests.yaml @@ -14,18 +14,19 @@ description: K8s admission control tests section: -- name: "invalid" - tests: - - name: "restricted_container" - input: - resource.namespace: - value: "dev.cel" - resource.labels: - value: - environment: "staging" - resource.containers: - value: - - staging.dev.cel.container1 - - staging.dev.cel.container2 - - preprod.dev.cel.container3 - output: "'only staging containers are allowed in namespace dev.cel'" + - name: "invalid" + tests: + - name: "restricted_container" + input: + resource.namespace: + value: "dev.cel" + resource.labels: + value: + environment: "staging" + resource.containers: + value: + - staging.dev.cel.container1 + - staging.dev.cel.container2 + - preprod.dev.cel.container3 + output: + value: "only staging containers are allowed in namespace dev.cel" diff --git a/testing/src/test/resources/policy/limits/tests.yaml b/testing/src/test/resources/policy/limits/tests.yaml index fe6daa61d..88772e075 100644 --- a/testing/src/test/resources/policy/limits/tests.yaml +++ b/testing/src/test/resources/policy/limits/tests.yaml @@ -14,25 +14,29 @@ description: Limits related tests section: -- name: "now_after_hours" - tests: - - name: "7pm" - input: - now: - expr: "timestamp('2024-07-30T00:30:00Z')" - output: "'hello, me'" - - name: "8pm" - input: - now: - expr: "timestamp('2024-07-30T20:30:00Z')" - output: "'goodbye, me!'" - - name: "9pm" - input: - now: - expr: "timestamp('2024-07-30T21:30:00Z')" - output: "'goodbye, me!!'" - - name: "11pm" - input: - now: - expr: "timestamp('2024-07-30T23:30:00Z')" - output: "'goodbye, me!!!'" \ No newline at end of file + - name: "now_after_hours" + tests: + - name: "7pm" + input: + now: + expr: "timestamp('2024-07-30T00:30:00Z')" + output: + value: "hello, me" + - name: "8pm" + input: + now: + expr: "timestamp('2024-07-30T20:30:00Z')" + output: + value: "goodbye, me!" + - name: "9pm" + input: + now: + expr: "timestamp('2024-07-30T21:30:00Z')" + output: + value: "goodbye, me!!" + - name: "11pm" + input: + now: + expr: "timestamp('2024-07-30T23:30:00Z')" + output: + value: "goodbye, me!!!" diff --git a/testing/src/test/resources/policy/nested_rule/tests.yaml b/testing/src/test/resources/policy/nested_rule/tests.yaml index a9807c376..3f9f63437 100644 --- a/testing/src/test/resources/policy/nested_rule/tests.yaml +++ b/testing/src/test/resources/policy/nested_rule/tests.yaml @@ -16,23 +16,26 @@ description: Nested rule conformance tests section: - name: "banned" tests: - - name: "restricted_origin" - input: - resource: - value: - origin: "ir" - output: "{'banned': true}" - - name: "by_default" - input: - resource: - value: - origin: "de" - output: "{'banned': true}" + - name: "restricted_origin" + input: + resource: + value: + origin: "ir" + output: + expr: "{'banned': true}" + - name: "by_default" + input: + resource: + value: + origin: "de" + output: + expr: "{'banned': true}" - name: "permitted" tests: - - name: "valid_origin" - input: - resource: - value: - origin: "uk" - output: "{'banned': false}" + - name: "valid_origin" + input: + resource: + value: + origin: "uk" + output: + expr: "{'banned': false}" diff --git a/testing/src/test/resources/policy/nested_rule2/tests.yaml b/testing/src/test/resources/policy/nested_rule2/tests.yaml index b5fbba745..0e1a9ca69 100644 --- a/testing/src/test/resources/policy/nested_rule2/tests.yaml +++ b/testing/src/test/resources/policy/nested_rule2/tests.yaml @@ -14,35 +14,39 @@ description: Nested rule conformance tests section: -- name: "banned" - tests: - - name: "restricted_origin" - input: - resource: - value: - user: "bad-user" - origin: "ir" - output: "{'banned': 'restricted_region'}" - - name: "by_default" - input: - resource: - value: - user: "bad-user" - origin: "de" - output: "{'banned': 'bad_actor'}" - - name: "unconfigured_region" - input: - resource: - value: - user: "good-user" - origin: "de" - output: "{'banned': 'unconfigured_region'}" -- name: "permitted" - tests: - - name: "valid_origin" - input: - resource: - value: - user: "good-user" - origin: "uk" - output: "{}" \ No newline at end of file + - name: "banned" + tests: + - name: "restricted_origin" + input: + resource: + value: + user: "bad-user" + origin: "ir" + output: + expr: "{'banned': 'restricted_region'}" + - name: "by_default" + input: + resource: + value: + user: "bad-user" + origin: "de" + output: + expr: "{'banned': 'bad_actor'}" + - name: "unconfigured_region" + input: + resource: + value: + user: "good-user" + origin: "de" + output: + expr: "{'banned': 'unconfigured_region'}" + - name: "permitted" + tests: + - name: "valid_origin" + input: + resource: + value: + user: "good-user" + origin: "uk" + output: + expr: "{}" diff --git a/testing/src/test/resources/policy/nested_rule3/tests.yaml b/testing/src/test/resources/policy/nested_rule3/tests.yaml index b10785d0c..9d993c65f 100644 --- a/testing/src/test/resources/policy/nested_rule3/tests.yaml +++ b/testing/src/test/resources/policy/nested_rule3/tests.yaml @@ -14,35 +14,39 @@ description: Nested rule conformance tests section: -- name: "banned" - tests: - - name: "restricted_origin" - input: - resource: - value: - user: "bad-user" - origin: "ir" - output: "{'banned': 'restricted_region'}" - - name: "by_default" - input: - resource: - value: - user: "bad-user" - origin: "de" - output: "{'banned': 'bad_actor'}" - - name: "unconfigured_region" - input: - resource: - value: - user: "good-user" - origin: "de" - output: "{'banned': 'unconfigured_region'}" -- name: "permitted" - tests: - - name: "valid_origin" - input: - resource: - value: - user: "good-user" - origin: "uk" - output: "optional.none()" \ No newline at end of file + - name: "banned" + tests: + - name: "restricted_origin" + input: + resource: + value: + user: "bad-user" + origin: "ir" + output: + expr: "{'banned': 'restricted_region'}" + - name: "by_default" + input: + resource: + value: + user: "bad-user" + origin: "de" + output: + expr: "{'banned': 'bad_actor'}" + - name: "unconfigured_region" + input: + resource: + value: + user: "good-user" + origin: "de" + output: + expr: "{'banned': 'unconfigured_region'}" + - name: "permitted" + tests: + - name: "valid_origin" + input: + resource: + value: + user: "good-user" + origin: "uk" + output: + expr: "optional.none()" diff --git a/testing/src/test/resources/policy/nested_rule4/config.yaml b/testing/src/test/resources/policy/nested_rule4/config.yaml new file mode 100644 index 000000000..5afb8c587 --- /dev/null +++ b/testing/src/test/resources/policy/nested_rule4/config.yaml @@ -0,0 +1,19 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: "nested_rule4" +variables: + - name: x + type: + type_name: int diff --git a/testing/src/test/resources/policy/nested_rule4/policy.yaml b/testing/src/test/resources/policy/nested_rule4/policy.yaml new file mode 100644 index 000000000..ea53bfb25 --- /dev/null +++ b/testing/src/test/resources/policy/nested_rule4/policy.yaml @@ -0,0 +1,24 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: nested_rule4 +rule: + match: + - condition: x > 0 + rule: + match: + - rule: + match: + - output: "true" + - output: "false" diff --git a/testing/src/test/resources/policy/nested_rule4/tests.yaml b/testing/src/test/resources/policy/nested_rule4/tests.yaml new file mode 100644 index 000000000..006eddb88 --- /dev/null +++ b/testing/src/test/resources/policy/nested_rule4/tests.yaml @@ -0,0 +1,30 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +description: "Nested rule tests which explore optional vs non-optional returns" +section: + - name: "valid" + tests: + - name: "x=0" + input: + x: + value: 0 + output: + value: false + - name: "x=2" + input: + x: + value: 2 + output: + value: true diff --git a/testing/src/test/resources/policy/nested_rule5/config.yaml b/testing/src/test/resources/policy/nested_rule5/config.yaml new file mode 100644 index 000000000..499450090 --- /dev/null +++ b/testing/src/test/resources/policy/nested_rule5/config.yaml @@ -0,0 +1,19 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: "nested_rule5" +variables: + - name: x + type: + type_name: int diff --git a/testing/src/test/resources/policy/nested_rule5/policy.yaml b/testing/src/test/resources/policy/nested_rule5/policy.yaml new file mode 100644 index 000000000..e43dce188 --- /dev/null +++ b/testing/src/test/resources/policy/nested_rule5/policy.yaml @@ -0,0 +1,30 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: nested_rule5 +rule: + match: + - condition: x > 0 + rule: + match: + - rule: + match: + - condition: "x > 2" + output: "true" + - condition: x > 1 + rule: + match: + - condition: "x >= 2" + output: "true" + - output: "false" diff --git a/testing/src/test/resources/policy/nested_rule5/tests.yaml b/testing/src/test/resources/policy/nested_rule5/tests.yaml new file mode 100644 index 000000000..8cd794051 --- /dev/null +++ b/testing/src/test/resources/policy/nested_rule5/tests.yaml @@ -0,0 +1,42 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +description: "Nested rule tests which explore optional vs non-optional returns" +section: + - name: "valid" + tests: + - name: "x=0" + input: + x: + value: 0 + output: + value: false + - name: "x=1" + input: + x: + value: 1 + output: + expr: "optional.none()" + - name: "x=2" + input: + x: + value: 2 + output: + expr: "optional.none()" + - name: "x=3" + input: + x: + value: 3 + output: + value: true diff --git a/testing/src/test/resources/policy/nested_rule6/config.yaml b/testing/src/test/resources/policy/nested_rule6/config.yaml new file mode 100644 index 000000000..a5b1ee16b --- /dev/null +++ b/testing/src/test/resources/policy/nested_rule6/config.yaml @@ -0,0 +1,19 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: "nested_rule6" +variables: + - name: x + type: + type_name: int diff --git a/testing/src/test/resources/policy/nested_rule6/policy.yaml b/testing/src/test/resources/policy/nested_rule6/policy.yaml new file mode 100644 index 000000000..a3360e7c1 --- /dev/null +++ b/testing/src/test/resources/policy/nested_rule6/policy.yaml @@ -0,0 +1,28 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: nested_rule6 +rule: + match: + - rule: + match: + - rule: + match: + - condition: "x > 2" + output: "true" + - rule: + match: + - condition: "x > 3" + output: "true" + - output: "false" diff --git a/testing/src/test/resources/policy/nested_rule6/tests.yaml b/testing/src/test/resources/policy/nested_rule6/tests.yaml new file mode 100644 index 000000000..fef586df0 --- /dev/null +++ b/testing/src/test/resources/policy/nested_rule6/tests.yaml @@ -0,0 +1,24 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +description: "Nested rule tests which explore optional vs non-optional returns" +section: + - name: "valid" + tests: + - name: "x=0" + input: + x: + value: 0 + output: + value: false diff --git a/testing/src/test/resources/policy/nested_rule7/config.yaml b/testing/src/test/resources/policy/nested_rule7/config.yaml new file mode 100644 index 000000000..74d4d8c2d --- /dev/null +++ b/testing/src/test/resources/policy/nested_rule7/config.yaml @@ -0,0 +1,19 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: "nested_rule7" +variables: + - name: x + type: + type_name: int diff --git a/testing/src/test/resources/policy/nested_rule7/policy.yaml b/testing/src/test/resources/policy/nested_rule7/policy.yaml new file mode 100644 index 000000000..fcacd017e --- /dev/null +++ b/testing/src/test/resources/policy/nested_rule7/policy.yaml @@ -0,0 +1,29 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: nested_rule7 +rule: + match: + - rule: + match: + - rule: + match: + - condition: "x > 2" + output: "true" + - rule: + match: + - condition: "x > 3" + output: "true" + - condition: "x > 1" + output: "false" diff --git a/testing/src/test/resources/policy/nested_rule7/tests.yaml b/testing/src/test/resources/policy/nested_rule7/tests.yaml new file mode 100644 index 000000000..ec2896878 --- /dev/null +++ b/testing/src/test/resources/policy/nested_rule7/tests.yaml @@ -0,0 +1,42 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +description: "Nested rule tests which explore optional vs non-optional returns" +section: + - name: "valid" + tests: + - name: "x=1" + input: + x: + value: 1 + output: + expr: "optional.none()" + - name: "x=2" + input: + x: + value: 2 + output: + value: false + - name: "x=3" + input: + x: + value: 3 + output: + value: true + - name: "x=4" + input: + x: + value: 4 + output: + value: true diff --git a/testing/src/test/resources/policy/pb/tests.yaml b/testing/src/test/resources/policy/pb/tests.yaml index 82dd6b11b..71cd56b57 100644 --- a/testing/src/test/resources/policy/pb/tests.yaml +++ b/testing/src/test/resources/policy/pb/tests.yaml @@ -14,20 +14,21 @@ description: "Protobuf input tests" section: -- name: "valid" - tests: - - name: "good spec" - input: - spec: - expr: > - TestAllTypes{single_int32: 10} - output: "optional.none()" -- name: "invalid" - tests: - - name: "bad spec" - input: - spec: - expr: > - TestAllTypes{single_int32: 11} - output: > - "invalid spec, got single_int32=11, wanted <= 10" + - name: "valid" + tests: + - name: "good spec" + input: + spec: + expr: > + TestAllTypes{single_int32: 10} + output: + expr: "optional.none()" + - name: "invalid" + tests: + - name: "bad spec" + input: + spec: + expr: > + TestAllTypes{single_int32: 11} + output: + value: "invalid spec, got single_int32=11, wanted <= 10" diff --git a/testing/src/test/resources/policy/required_labels/tests.yaml b/testing/src/test/resources/policy/required_labels/tests.yaml index 67681ef46..4296c6914 100644 --- a/testing/src/test/resources/policy/required_labels/tests.yaml +++ b/testing/src/test/resources/policy/required_labels/tests.yaml @@ -16,64 +16,65 @@ description: "Required labels conformance tests" section: - name: "valid" tests: - - name: "matching" - input: - spec: - value: - labels: - env: prod - experiment: "group b" - resource: - value: - labels: - env: prod - experiment: "group b" - release: "v0.1.0" - output: "optional.none()" + - name: "matching" + input: + spec: + value: + labels: + env: prod + experiment: "group b" + resource: + value: + labels: + env: prod + experiment: "group b" + release: "v0.1.0" + output: + expr: "optional.none()" - name: "missing" tests: - - name: "env" - input: - spec: - value: - labels: - env: prod - experiment: "group b" - resource: - value: - labels: - experiment: "group b" - release: "v0.1.0" - output: > - "missing one or more required labels: [\"env\"]" - - name: "experiment" - input: - spec: - value: - labels: - env: prod - experiment: "group b" - resource: - value: - labels: - env: staging - release: "v0.1.0" - output: > - "missing one or more required labels: [\"experiment\"]" + - name: "env" + input: + spec: + value: + labels: + env: prod + experiment: "group b" + resource: + value: + labels: + experiment: "group b" + release: "v0.1.0" + output: + value: "missing one or more required labels: [\"env\"]" + - name: "experiment" + input: + spec: + value: + labels: + env: prod + experiment: "group b" + resource: + value: + labels: + env: staging + release: "v0.1.0" + output: + value: "missing one or more required labels: [\"experiment\"]" - name: "invalid" tests: - - name: "env" - input: - spec: - value: - labels: - env: prod - experiment: "group b" - resource: - value: - labels: - env: staging - experiment: "group b" - release: "v0.1.0" - output: > - "invalid values provided on one or more labels: [\"env\"]" + - name: "env" + input: + spec: + value: + labels: + env: prod + experiment: "group b" + resource: + value: + labels: + env: staging + experiment: "group b" + release: "v0.1.0" + output: + value: "invalid values provided on one or more labels: [\"env\"]" diff --git a/testing/src/test/resources/policy/restricted_destinations/tests.yaml b/testing/src/test/resources/policy/restricted_destinations/tests.yaml index c0feeb202..f7ae36550 100644 --- a/testing/src/test/resources/policy/restricted_destinations/tests.yaml +++ b/testing/src/test/resources/policy/restricted_destinations/tests.yaml @@ -16,103 +16,107 @@ description: Restricted destinations conformance tests. section: - name: "valid" tests: - - name: "ip_allowed" - input: - "spec.origin": - value: "us" - "spec.restricted_destinations": - value: - - "cu" - - "ir" - - "kp" - - "sd" - - "sy" - "destination.ip": - value: "10.0.0.1" - "origin.ip": - value: "10.0.0.1" - request: - value: - auth: - claims: {} - resource: - value: - name: "/company/acme/secrets/doomsday-device" - labels: - location: "us" - output: "false" # false means unrestricted - - name: "nationality_allowed" - input: - "spec.origin": - value: "us" - "spec.restricted_destinations": - value: - - "cu" - - "ir" - - "kp" - - "sd" - - "sy" - "destination.ip": - value: "10.0.0.1" - request: - value: - auth: - claims: - nationality: "us" - resource: - value: - name: "/company/acme/secrets/doomsday-device" - labels: - location: "us" - output: "false" + - name: "ip_allowed" + input: + spec.origin: + value: "us" + spec.restricted_destinations: + value: + - "cu" + - "ir" + - "kp" + - "sd" + - "sy" + destination.ip: + value: "10.0.0.1" + origin.ip: + value: "10.0.0.1" + request: + value: + auth: + claims: {} + resource: + value: + name: "/company/acme/secrets/doomsday-device" + labels: + location: "us" + output: + value: false # false means unrestricted + - name: "nationality_allowed" + input: + spec.origin: + value: "us" + spec.restricted_destinations: + value: + - "cu" + - "ir" + - "kp" + - "sd" + - "sy" + destination.ip: + value: "10.0.0.1" + request: + value: + auth: + claims: + nationality: "us" + resource: + value: + name: "/company/acme/secrets/doomsday-device" + labels: + location: "us" + output: + value: false - name: "invalid" tests: - - name: "destination_ip_prohibited" - input: - "spec.origin": - value: "us" - "spec.restricted_destinations": - value: - - "cu" - - "ir" - - "kp" - - "sd" - - "sy" - "destination.ip": - value: "123.123.123.123" - "origin.ip": - value: "10.0.0.1" - request: - value: - auth: - claims: {} - resource: - value: - name: "/company/acme/secrets/doomsday-device" - labels: - location: "us" - output: "true" # true means restricted - - name: "resource_nationality_prohibited" - input: - "spec.origin": - value: "us" - "spec.restricted_destinations": - value: - - "cu" - - "ir" - - "kp" - - "sd" - - "sy" - "destination.ip": - value: "10.0.0.1" - request: - value: - auth: - claims: - nationality: "us" - resource: - value: - name: "/company/acme/secrets/doomsday-device" - labels: - location: "cu" - output: "true" + - name: "destination_ip_prohibited" + input: + spec.origin: + value: "us" + spec.restricted_destinations: + value: + - "cu" + - "ir" + - "kp" + - "sd" + - "sy" + destination.ip: + value: "123.123.123.123" + origin.ip: + value: "10.0.0.1" + request: + value: + auth: + claims: {} + resource: + value: + name: "/company/acme/secrets/doomsday-device" + labels: + location: "us" + output: + value: true # true means restricted + - name: "resource_nationality_prohibited" + input: + spec.origin: + value: "us" + spec.restricted_destinations: + value: + - "cu" + - "ir" + - "kp" + - "sd" + - "sy" + destination.ip: + value: "10.0.0.1" + request: + value: + auth: + claims: + nationality: "us" + resource: + value: + name: "/company/acme/secrets/doomsday-device" + labels: + location: "cu" + output: + value: true From 9ff1f233efd1a9e2f4d8b22e24e13e35cc83fb06 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Fri, 1 May 2026 11:54:57 -0700 Subject: [PATCH 53/66] Introduce CEL Policy Conformance Test Runner for Java PiperOrigin-RevId: 908834020 --- .../dev/cel/conformance/policy/BUILD.bazel | 25 +++ .../policy/PolicyConformanceTest.java | 72 +++++++ .../policy/PolicyConformanceTestRunner.java | 190 ++++++++++++++++++ .../policy/PolicyConformanceTests.java | 21 ++ .../policy/cel_policy_conformance_test.bzl | 50 +++++ 5 files changed, 358 insertions(+) create mode 100644 conformance/src/test/java/dev/cel/conformance/policy/BUILD.bazel create mode 100644 conformance/src/test/java/dev/cel/conformance/policy/PolicyConformanceTest.java create mode 100644 conformance/src/test/java/dev/cel/conformance/policy/PolicyConformanceTestRunner.java create mode 100644 conformance/src/test/java/dev/cel/conformance/policy/PolicyConformanceTests.java create mode 100644 conformance/src/test/java/dev/cel/conformance/policy/cel_policy_conformance_test.bzl diff --git a/conformance/src/test/java/dev/cel/conformance/policy/BUILD.bazel b/conformance/src/test/java/dev/cel/conformance/policy/BUILD.bazel new file mode 100644 index 000000000..9a9b14f74 --- /dev/null +++ b/conformance/src/test/java/dev/cel/conformance/policy/BUILD.bazel @@ -0,0 +1,25 @@ +load("@rules_java//java:defs.bzl", "java_library") +load(":cel_policy_conformance_test.bzl", "cel_policy_conformance_test_java") + +package( + default_applicable_licenses = ["//:license"], + default_testonly = True, +) + +java_library( + name = "run", + srcs = glob(["*.java"]), + deps = [ + "//:auto_value", + "//bundle:cel", + "//testing/testrunner:cel_expression_source", + "//testing/testrunner:cel_test_context", + "//testing/testrunner:cel_test_suite", + "//testing/testrunner:cel_test_suite_text_proto_parser", + "//testing/testrunner:cel_test_suite_yaml_parser", + "//testing/testrunner:test_runner_library", + "@maven//:com_google_guava_guava", + "@maven//:com_google_protobuf_protobuf_java", + "@maven//:junit_junit", + ], +) diff --git a/conformance/src/test/java/dev/cel/conformance/policy/PolicyConformanceTest.java b/conformance/src/test/java/dev/cel/conformance/policy/PolicyConformanceTest.java new file mode 100644 index 000000000..700539927 --- /dev/null +++ b/conformance/src/test/java/dev/cel/conformance/policy/PolicyConformanceTest.java @@ -0,0 +1,72 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.conformance.policy; + +import com.google.protobuf.ListValue; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import dev.cel.bundle.Cel; +import dev.cel.bundle.CelFactory; +import dev.cel.testing.testrunner.CelExpressionSource; +import dev.cel.testing.testrunner.CelTestContext; +import dev.cel.testing.testrunner.CelTestSuite.CelTestSection.CelTestCase; +import dev.cel.testing.testrunner.TestRunnerLibrary; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import org.junit.runners.model.Statement; + +/** Statement representing a single CEL policy conformance test case. */ +public final class PolicyConformanceTest extends Statement { + + private static final Cel CEL = CelFactory.standardCelBuilder().build(); + + private final String name; + private final CelTestCase testCase; + private final String dirPath; + + public PolicyConformanceTest(String name, CelTestCase testCase, String dirPath) { + this.name = name; + this.testCase = testCase; + this.dirPath = dirPath; + } + + public String getName() { + return name; + } + + @Override + public void evaluate() throws Throwable { + String policyFile = Paths.get(dirPath, "policy.yaml").toString(); + + CelTestContext.Builder contextBuilder = + CelTestContext.newBuilder() + .setCelExpression(CelExpressionSource.fromSource(policyFile)) + .setCel(CEL) + .addMessageTypes( + Struct.getDescriptor(), Value.getDescriptor(), ListValue.getDescriptor()); + + Path yamlConfigPath = Paths.get(dirPath, "config.yaml"); + Path textprotoConfigPath = Paths.get(dirPath, "config.textproto"); + + if (Files.exists(yamlConfigPath)) { + contextBuilder.setConfigFile(yamlConfigPath.toString()); + } else if (Files.exists(textprotoConfigPath)) { + contextBuilder.setConfigFile(textprotoConfigPath.toString()); + } + + TestRunnerLibrary.runTest(testCase, contextBuilder.build()); + } +} diff --git a/conformance/src/test/java/dev/cel/conformance/policy/PolicyConformanceTestRunner.java b/conformance/src/test/java/dev/cel/conformance/policy/PolicyConformanceTestRunner.java new file mode 100644 index 000000000..6d1d86e47 --- /dev/null +++ b/conformance/src/test/java/dev/cel/conformance/policy/PolicyConformanceTestRunner.java @@ -0,0 +1,190 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.conformance.policy; + +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.auto.value.AutoValue; +import com.google.common.base.Splitter; +import com.google.common.base.Strings; +import com.google.common.collect.ImmutableList; +import com.google.common.io.Files; +import com.google.protobuf.ListValue; +import com.google.protobuf.Struct; +import com.google.protobuf.TypeRegistry; +import com.google.protobuf.Value; +import dev.cel.testing.testrunner.CelTestSuite; +import dev.cel.testing.testrunner.CelTestSuite.CelTestSection; +import dev.cel.testing.testrunner.CelTestSuite.CelTestSection.CelTestCase; +import dev.cel.testing.testrunner.CelTestSuiteTextProtoParser; +import dev.cel.testing.testrunner.CelTestSuiteYamlParser; +import java.io.File; +import java.util.Arrays; +import java.util.List; +import org.junit.runner.Description; +import org.junit.runner.notification.RunNotifier; +import org.junit.runners.ParentRunner; +import org.junit.runners.model.InitializationError; + +/** Custom JUnit runner for CEL policy conformance tests. */ +public final class PolicyConformanceTestRunner extends ParentRunner { + + private static final Splitter SPLITTER = Splitter.on(",").omitEmptyStrings(); + private static final String TESTS_YAML_FILE_NAME = "tests.yaml"; + private static final String TESTS_TEXTPROTO_FILE_NAME = "tests.textproto"; + private static final TypeRegistry TYPE_REGISTRY = + TypeRegistry.newBuilder() + .add(Struct.getDescriptor()) + .add(Value.getDescriptor()) + .add(ListValue.getDescriptor()) + .build(); + + private static final String TEST_DIRS_PROP = + System.getProperty("dev.cel.policy.conformance.tests"); + private static final String TESTDATA_DIR = + System.getProperty("dev.cel.policy.conformance.testdata_dir", "testdata"); + private static final String SKIP_TESTS_PROP = + System.getProperty("dev.cel.policy.conformance.skip_tests"); + + private static final ImmutableList TESTS_TO_SKIP = + Strings.isNullOrEmpty(SKIP_TESTS_PROP) + ? ImmutableList.of() + : ImmutableList.copyOf(SPLITTER.splitToList(SKIP_TESTS_PROP)); + + private static final ImmutableList TEST_DIRS = + Strings.isNullOrEmpty(TEST_DIRS_PROP) + ? discoverTestDirs(TESTDATA_DIR) + : ImmutableList.copyOf(SPLITTER.splitToList(TEST_DIRS_PROP)); + + private static ImmutableList discoverTestDirs(String testdataDir) { + File dir = new File(testdataDir); + if (!dir.exists() || !dir.isDirectory()) { + return ImmutableList.of(); + } + String[] directories = dir.list((current, name) -> new File(current, name).isDirectory()); + if (directories == null) { + return ImmutableList.of(); + } + Arrays.sort(directories); + return ImmutableList.copyOf(directories); + } + + private final ImmutableList tests; + + private ImmutableList loadTests() { + if (TEST_DIRS.isEmpty()) { + return ImmutableList.of(); + } + + ImmutableList.Builder testsBuilder = ImmutableList.builder(); + + for (String dir : TEST_DIRS) { + String fullDirPath = TESTDATA_DIR + "/" + dir; + try { + ImmutableList suites = readTestSuites(fullDirPath); + for (CelTestSuiteContext namedSuite : suites) { + for (CelTestSection section : namedSuite.testSuite().sections()) { + for (CelTestCase testCase : section.tests()) { + String baseName = String.format("%s/%s/%s", dir, section.name(), testCase.name()); + String displayName = baseName + namedSuite.formatSuffix(); + if (!shouldSkipTest(baseName, TESTS_TO_SKIP)) { + testsBuilder.add(new PolicyConformanceTest(displayName, testCase, fullDirPath)); + } + } + } + } + } catch (Exception e) { + throw new RuntimeException("Failed to load test suite in " + fullDirPath, e); + } + } + return testsBuilder.build(); + } + + private static boolean shouldSkipTest(String name, List testsToSkip) { + for (String testToSkip : testsToSkip) { + if (name.startsWith(testToSkip)) { + String consumedName = name.substring(testToSkip.length()); + if (consumedName.isEmpty() || consumedName.startsWith("/")) { + return true; + } + } + } + return false; + } + + private static ImmutableList readTestSuites(String dirPath) + throws Exception { + File dir = new File(dirPath); + File yamlFile = new File(dir, TESTS_YAML_FILE_NAME); + File textprotoFile = new File(dir, TESTS_TEXTPROTO_FILE_NAME); + + boolean bothExist = yamlFile.exists() && textprotoFile.exists(); + ImmutableList.Builder suitesBuilder = ImmutableList.builder(); + + if (yamlFile.exists()) { + suitesBuilder.add( + CelTestSuiteContext.create( + CelTestSuiteYamlParser.newInstance() + .parse(Files.asCharSource(yamlFile, UTF_8).read()), + bothExist ? " (yaml)" : "")); + } + if (textprotoFile.exists()) { + suitesBuilder.add( + CelTestSuiteContext.create( + CelTestSuiteTextProtoParser.newInstance() + .parse(Files.asCharSource(textprotoFile, UTF_8).read(), TYPE_REGISTRY), + bothExist ? " (textproto)" : "")); + } + + ImmutableList suites = suitesBuilder.build(); + if (suites.isEmpty()) { + throw new IllegalArgumentException( + String.format( + "No %s or %s found in %s", TESTS_YAML_FILE_NAME, TESTS_TEXTPROTO_FILE_NAME, dirPath)); + } + return suites; + } + + @Override + protected ImmutableList getChildren() { + return tests; + } + + @Override + protected Description describeChild(PolicyConformanceTest child) { + return Description.createTestDescription(getTestClass().getJavaClass(), child.getName()); + } + + @Override + protected void runChild(PolicyConformanceTest child, RunNotifier notifier) { + runLeaf(child, describeChild(child), notifier); + } + + public PolicyConformanceTestRunner(Class clazz) throws InitializationError { + super(clazz); + this.tests = loadTests(); + } + + @AutoValue + abstract static class CelTestSuiteContext { + abstract CelTestSuite testSuite(); + + abstract String formatSuffix(); + + static CelTestSuiteContext create(CelTestSuite testSuite, String formatSuffix) { + return new AutoValue_PolicyConformanceTestRunner_CelTestSuiteContext(testSuite, formatSuffix); + } + } +} diff --git a/conformance/src/test/java/dev/cel/conformance/policy/PolicyConformanceTests.java b/conformance/src/test/java/dev/cel/conformance/policy/PolicyConformanceTests.java new file mode 100644 index 000000000..46596763e --- /dev/null +++ b/conformance/src/test/java/dev/cel/conformance/policy/PolicyConformanceTests.java @@ -0,0 +1,21 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.conformance.policy; + +import org.junit.runner.RunWith; + +/** Main test class for CEL policy conformance tests. */ +@RunWith(PolicyConformanceTestRunner.class) +public class PolicyConformanceTests {} diff --git a/conformance/src/test/java/dev/cel/conformance/policy/cel_policy_conformance_test.bzl b/conformance/src/test/java/dev/cel/conformance/policy/cel_policy_conformance_test.bzl new file mode 100644 index 000000000..3e3720ec5 --- /dev/null +++ b/conformance/src/test/java/dev/cel/conformance/policy/cel_policy_conformance_test.bzl @@ -0,0 +1,50 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Macro to run CEL policy conformance tests.""" + +load("@rules_java//java:defs.bzl", "java_test") + +def cel_policy_conformance_test_java( + name, + testdata, + test_cases = [], + skip_tests = [], + **kwargs): + """Macro to run CEL policy conformance tests for Java. + + Args: + name: The name of the test target. + testdata: Testdata filegroup target. + test_cases: (optional) List of test case names (directory names) to run. + skip_tests: (optional) List of test case names (directory names) to skip. + **kwargs: Other standard Bazel target attributes. + """ + + lbl = native.package_relative_label(testdata) + testdata_dir = lbl.package + "/" + lbl.name + + java_test( + name = name, + jvm_flags = [ + "-Ddev.cel.policy.conformance.tests=" + ",".join(test_cases), + "-Ddev.cel.policy.conformance.testdata_dir=" + testdata_dir, + "-Ddev.cel.policy.conformance.skip_tests=" + ",".join(skip_tests), + ], + data = [testdata], + size = "small", + test_class = "dev.cel.conformance.policy.PolicyConformanceTests", + runtime_deps = [Label(":run")], + **kwargs + ) From fc24bfde5a9a6f82e1830abdec98f169a761576f Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Fri, 1 May 2026 15:07:48 -0700 Subject: [PATCH 54/66] Internal Changes PiperOrigin-RevId: 908917795 --- .../dev/cel/conformance/policy/BUILD.bazel | 1 + .../policy/PolicyConformanceTest.java | 21 ++++++++++++++++++- .../testrunner/CelTestSuiteYamlParser.java | 2 +- 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/conformance/src/test/java/dev/cel/conformance/policy/BUILD.bazel b/conformance/src/test/java/dev/cel/conformance/policy/BUILD.bazel index 9a9b14f74..236bacd93 100644 --- a/conformance/src/test/java/dev/cel/conformance/policy/BUILD.bazel +++ b/conformance/src/test/java/dev/cel/conformance/policy/BUILD.bazel @@ -12,6 +12,7 @@ java_library( deps = [ "//:auto_value", "//bundle:cel", + "//runtime:function_binding", "//testing/testrunner:cel_expression_source", "//testing/testrunner:cel_test_context", "//testing/testrunner:cel_test_suite", diff --git a/conformance/src/test/java/dev/cel/conformance/policy/PolicyConformanceTest.java b/conformance/src/test/java/dev/cel/conformance/policy/PolicyConformanceTest.java index 700539927..5d6c84dcd 100644 --- a/conformance/src/test/java/dev/cel/conformance/policy/PolicyConformanceTest.java +++ b/conformance/src/test/java/dev/cel/conformance/policy/PolicyConformanceTest.java @@ -19,6 +19,7 @@ import com.google.protobuf.Value; import dev.cel.bundle.Cel; import dev.cel.bundle.CelFactory; +import dev.cel.runtime.CelFunctionBinding; import dev.cel.testing.testrunner.CelExpressionSource; import dev.cel.testing.testrunner.CelTestContext; import dev.cel.testing.testrunner.CelTestSuite.CelTestSection.CelTestCase; @@ -31,7 +32,25 @@ /** Statement representing a single CEL policy conformance test case. */ public final class PolicyConformanceTest extends Statement { - private static final Cel CEL = CelFactory.standardCelBuilder().build(); + private static final Cel CEL = + CelFactory.standardCelBuilder() + .addFunctionBindings( + CelFunctionBinding.fromOverloads( + "locationCode", + CelFunctionBinding.from( + "locationCode_string", + String.class, + (ip) -> { + switch (ip) { + case "10.0.0.1": + return "us"; + case "10.0.0.2": + return "de"; + default: + return "ir"; + } + }))) + .build(); private final String name; private final CelTestCase testCase; diff --git a/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuiteYamlParser.java b/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuiteYamlParser.java index 71c4b9231..2340bf229 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuiteYamlParser.java +++ b/testing/src/main/java/dev/cel/testing/testrunner/CelTestSuiteYamlParser.java @@ -90,7 +90,7 @@ private CelTestSuite parseYaml(String celTestSuiteYamlContent, String descriptio } private CelTestSuite.Builder parseTestSuite(ParserContext ctx, Node node) { - CelTestSuite.Builder builder = CelTestSuite.newBuilder(); + CelTestSuite.Builder builder = CelTestSuite.newBuilder().setName("").setDescription(""); long id = ctx.collectMetadata(node); if (!assertYamlType(ctx, id, node, YamlNodeType.MAP)) { ctx.reportError(id, "Unknown test suite type: " + node.getTag()); From d08c4243a85b4394d0e929c76b912763f170eec2 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Mon, 4 May 2026 15:29:11 -0700 Subject: [PATCH 55/66] Internal Changes PiperOrigin-RevId: 910276432 --- .../src/test/java/dev/cel/conformance/policy/BUILD.bazel | 1 + .../dev/cel/conformance/policy/PolicyConformanceTest.java | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/conformance/src/test/java/dev/cel/conformance/policy/BUILD.bazel b/conformance/src/test/java/dev/cel/conformance/policy/BUILD.bazel index 236bacd93..ca34530fb 100644 --- a/conformance/src/test/java/dev/cel/conformance/policy/BUILD.bazel +++ b/conformance/src/test/java/dev/cel/conformance/policy/BUILD.bazel @@ -19,6 +19,7 @@ java_library( "//testing/testrunner:cel_test_suite_text_proto_parser", "//testing/testrunner:cel_test_suite_yaml_parser", "//testing/testrunner:test_runner_library", + "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto", "@maven//:com_google_guava_guava", "@maven//:com_google_protobuf_protobuf_java", "@maven//:junit_junit", diff --git a/conformance/src/test/java/dev/cel/conformance/policy/PolicyConformanceTest.java b/conformance/src/test/java/dev/cel/conformance/policy/PolicyConformanceTest.java index 5d6c84dcd..cd24339c0 100644 --- a/conformance/src/test/java/dev/cel/conformance/policy/PolicyConformanceTest.java +++ b/conformance/src/test/java/dev/cel/conformance/policy/PolicyConformanceTest.java @@ -14,11 +14,10 @@ package dev.cel.conformance.policy; -import com.google.protobuf.ListValue; import com.google.protobuf.Struct; -import com.google.protobuf.Value; import dev.cel.bundle.Cel; import dev.cel.bundle.CelFactory; +import dev.cel.expr.conformance.proto3.TestAllTypes; import dev.cel.runtime.CelFunctionBinding; import dev.cel.testing.testrunner.CelExpressionSource; import dev.cel.testing.testrunner.CelTestContext; @@ -74,8 +73,9 @@ public void evaluate() throws Throwable { CelTestContext.newBuilder() .setCelExpression(CelExpressionSource.fromSource(policyFile)) .setCel(CEL) - .addMessageTypes( - Struct.getDescriptor(), Value.getDescriptor(), ListValue.getDescriptor()); + .addFileTypes( + TestAllTypes.getDescriptor().getFile(), + Struct.getDescriptor().getFile()); Path yamlConfigPath = Paths.get(dirPath, "config.yaml"); Path textprotoConfigPath = Paths.get(dirPath, "config.textproto"); From 5fc176681ecf09b360f5b503228d916e9e89d6fc Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Tue, 5 May 2026 14:45:32 -0700 Subject: [PATCH 56/66] Implement context_variable support in YAML environment PiperOrigin-RevId: 910924913 --- .../src/main/java/dev/cel/bundle/BUILD.bazel | 1 + .../java/dev/cel/bundle/CelEnvironment.java | 23 +++++++++++++ .../cel/bundle/CelEnvironmentYamlParser.java | 34 +++++++++++++++++++ .../dev/cel/testing/testrunner/BUILD.bazel | 1 + .../testing/testrunner/CelTestContext.java | 22 ++++++------ .../testing/testrunner/TestRunnerLibrary.java | 33 +++++++++++++----- .../testrunner/TestRunnerLibraryTest.java | 2 +- 7 files changed, 95 insertions(+), 21 deletions(-) diff --git a/bundle/src/main/java/dev/cel/bundle/BUILD.bazel b/bundle/src/main/java/dev/cel/bundle/BUILD.bazel index 742f718f1..716442849 100644 --- a/bundle/src/main/java/dev/cel/bundle/BUILD.bazel +++ b/bundle/src/main/java/dev/cel/bundle/BUILD.bazel @@ -104,6 +104,7 @@ java_library( ":required_fields_checker", "//:auto_value", "//bundle:cel", + "//checker:proto_type_mask", "//checker:standard_decl", "//common:compiler_common", "//common:container", diff --git a/bundle/src/main/java/dev/cel/bundle/CelEnvironment.java b/bundle/src/main/java/dev/cel/bundle/CelEnvironment.java index b85f16cb1..ccbaef61b 100644 --- a/bundle/src/main/java/dev/cel/bundle/CelEnvironment.java +++ b/bundle/src/main/java/dev/cel/bundle/CelEnvironment.java @@ -30,6 +30,7 @@ import dev.cel.checker.CelStandardDeclarations; import dev.cel.checker.CelStandardDeclarations.StandardFunction; import dev.cel.checker.CelStandardDeclarations.StandardOverload; +import dev.cel.checker.ProtoTypeMask; import dev.cel.common.CelContainer; import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelOptions; @@ -134,6 +135,9 @@ public abstract class CelEnvironment { /** Limits to set in the environment. */ public abstract ImmutableSet limits(); + /** Context variable to enable in the environment. */ + public abstract Optional contextVariable(); + /** Builder for {@link CelEnvironment}. */ @AutoValue.Builder public abstract static class Builder { @@ -199,6 +203,8 @@ public Builder setLimits(Limit... limits) { public abstract Builder setLimits(ImmutableSet limits); + public abstract Builder setContextVariable(ContextVariable contextVariable); + abstract CelEnvironment autoBuild(); @CheckReturnValue @@ -258,6 +264,12 @@ public CelCompiler extend(CelCompiler celCompiler, CelOptions celOptions) applyStandardLibrarySubset(compilerBuilder); + contextVariable() + .ifPresent( + cv -> + compilerBuilder.addProtoTypeMasks( + ProtoTypeMask.ofAllFields(cv.typeName()).withFieldsAsVariableDeclarations())); + return compilerBuilder.build(); } catch (RuntimeException e) { throw new CelEnvironmentException(e.getMessage(), e); @@ -406,6 +418,17 @@ private static CanonicalCelExtension getExtensionOrThrow(String extensionName) { return extension; } + /** Represents a context variable declaration. */ + @AutoValue + public abstract static class ContextVariable { + /** Fully qualified type name of the context variable. */ + public abstract String typeName(); + + public static ContextVariable create(String typeName) { + return new AutoValue_CelEnvironment_ContextVariable(typeName); + } + } + /** Represents a policy variable declaration. */ @AutoValue public abstract static class VariableDecl { diff --git a/bundle/src/main/java/dev/cel/bundle/CelEnvironmentYamlParser.java b/bundle/src/main/java/dev/cel/bundle/CelEnvironmentYamlParser.java index f129d9f5d..14f1c93d8 100644 --- a/bundle/src/main/java/dev/cel/bundle/CelEnvironmentYamlParser.java +++ b/bundle/src/main/java/dev/cel/bundle/CelEnvironmentYamlParser.java @@ -28,6 +28,7 @@ import com.google.common.collect.ImmutableSet; import com.google.errorprone.annotations.CanIgnoreReturnValue; import dev.cel.bundle.CelEnvironment.Alias; +import dev.cel.bundle.CelEnvironment.ContextVariable; import dev.cel.bundle.CelEnvironment.ExtensionConfig; import dev.cel.bundle.CelEnvironment.FunctionDecl; import dev.cel.bundle.CelEnvironment.LibrarySubset; @@ -320,6 +321,36 @@ private ImmutableSet parseAbbreviations(ParserContext ctx, Node no return builder.build(); } + private ContextVariable parseContextVariable(ParserContext ctx, Node node) { + long id = ctx.collectMetadata(node); + if (!assertYamlType(ctx, id, node, YamlNodeType.MAP)) { + return ContextVariable.create(""); + } + + MappingNode mapNode = (MappingNode) node; + String typeName = ""; + for (NodeTuple nodeTuple : mapNode.getValue()) { + Node keyNode = nodeTuple.getKeyNode(); + long keyId = ctx.collectMetadata(keyNode); + Node valueNode = nodeTuple.getValueNode(); + String keyName = ((ScalarNode) keyNode).getValue(); + switch (keyName) { + case "type_name": + typeName = newString(ctx, valueNode); + break; + default: + ctx.reportError(keyId, String.format("Unsupported context_variable tag: %s", keyName)); + break; + } + } + + if (typeName.isEmpty()) { + ctx.reportError(id, "Missing required attribute(s): type_name"); + } + + return ContextVariable.create(typeName); + } + private ImmutableSet parseVariables(ParserContext ctx, Node node) { long valueId = ctx.collectMetadata(node); ImmutableSet.Builder variableSetBuilder = ImmutableSet.builder(); @@ -900,6 +931,9 @@ private CelEnvironment.Builder parseConfig(ParserContext ctx, Node node) { case "limits": builder.setLimits(parseLimits(ctx, valueNode)); break; + case "context_variable": + builder.setContextVariable(parseContextVariable(ctx, valueNode)); + break; default: ctx.reportError(id, "Unknown config tag: " + fieldName); // continue handling the rest of the nodes diff --git a/testing/src/main/java/dev/cel/testing/testrunner/BUILD.bazel b/testing/src/main/java/dev/cel/testing/testrunner/BUILD.bazel index d0fed9bea..677884a8a 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/BUILD.bazel +++ b/testing/src/main/java/dev/cel/testing/testrunner/BUILD.bazel @@ -161,6 +161,7 @@ java_library( deps = [ ":cel_expression_source", ":default_result_matcher", + ":registry_utils", ":result_matcher", "//:auto_value", "//bundle:cel", diff --git a/testing/src/main/java/dev/cel/testing/testrunner/CelTestContext.java b/testing/src/main/java/dev/cel/testing/testrunner/CelTestContext.java index 1be0bab25..6ef988a44 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/CelTestContext.java +++ b/testing/src/main/java/dev/cel/testing/testrunner/CelTestContext.java @@ -140,21 +140,21 @@ public Optional celDescriptors() { return Optional.empty(); } + /** Returns a unified set of {@link CelDescriptors} combined from all descriptor sources. */ @Memoized - public Optional typeRegistry() { + public Optional mergedDescriptors() { if (fileTypes().isEmpty() && !fileDescriptorSetPath().isPresent()) { return Optional.empty(); } - TypeRegistry.Builder builder = TypeRegistry.newBuilder(); - if (!fileTypes().isEmpty()) { - builder.add( - CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(fileTypes()) - .messageTypeDescriptors()); - } - if (celDescriptors().isPresent()) { - builder.add(celDescriptors().get().messageTypeDescriptors()); - } - return Optional.of(builder.build()); + ImmutableSet.Builder allFiles = + ImmutableSet.builder().addAll(fileTypes()); + celDescriptors().ifPresent(d -> allFiles.addAll(d.fileDescriptors())); + return Optional.of(CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(allFiles.build())); + } + + @Memoized + public Optional typeRegistry() { + return mergedDescriptors().map(RegistryUtils::getTypeRegistry); } public abstract Optional extensionRegistry(); diff --git a/testing/src/main/java/dev/cel/testing/testrunner/TestRunnerLibrary.java b/testing/src/main/java/dev/cel/testing/testrunner/TestRunnerLibrary.java index 69c365972..1d3e49fbe 100644 --- a/testing/src/main/java/dev/cel/testing/testrunner/TestRunnerLibrary.java +++ b/testing/src/main/java/dev/cel/testing/testrunner/TestRunnerLibrary.java @@ -360,18 +360,33 @@ private static Object getEvaluationResultWithMessage( } private static Message unpackAny(Any any, CelTestContext celTestContext) throws IOException { - if (!celTestContext.fileDescriptorSetPath().isPresent()) { - throw new IllegalArgumentException( - "Proto descriptors are required for unpacking Any messages."); + TypeRegistry typeRegistry = + celTestContext + .typeRegistry() + .orElseThrow( + () -> + new IllegalArgumentException( + "Proto descriptors or type registry are required for unpacking Any" + + " messages.")); + + Descriptor descriptor = typeRegistry.getDescriptorForTypeUrl(any.getTypeUrl()); + if (descriptor == null) { + throw new IllegalArgumentException("Descriptor not found for type URL: " + any.getTypeUrl()); } - Descriptor descriptor = - RegistryUtils.getTypeRegistry(celTestContext.celDescriptors().get()) - .getDescriptorForTypeUrl(any.getTypeUrl()); + + ExtensionRegistry extensionRegistry = + celTestContext + .extensionRegistry() + .orElseGet( + () -> + celTestContext + .mergedDescriptors() + .map(RegistryUtils::getExtensionRegistry) + .orElseGet(ExtensionRegistry::getEmptyRegistry)); + return DynamicMessage.getDefaultInstance(descriptor) .getParserForType() - .parseFrom( - any.getValue(), - RegistryUtils.getExtensionRegistry(celTestContext.celDescriptors().get())); + .parseFrom(any.getValue(), extensionRegistry); } private static Message getEvaluatedContextExpr( diff --git a/testing/src/test/java/dev/cel/testing/testrunner/TestRunnerLibraryTest.java b/testing/src/test/java/dev/cel/testing/testrunner/TestRunnerLibraryTest.java index b83375b35..112ef1f82 100644 --- a/testing/src/test/java/dev/cel/testing/testrunner/TestRunnerLibraryTest.java +++ b/testing/src/test/java/dev/cel/testing/testrunner/TestRunnerLibraryTest.java @@ -262,7 +262,7 @@ public void runTest_missingProtoDescriptors_failure() throws Exception { assertThat(thrown) .hasMessageThat() - .contains("Proto descriptors are required for unpacking Any messages."); + .contains("Proto descriptors or type registry are required for unpacking Any messages"); } @Test From f89d7e5100ed62b94640697e87245cee5ff6c703 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Wed, 6 May 2026 12:44:24 -0700 Subject: [PATCH 57/66] Internal Changes PiperOrigin-RevId: 911493152 --- .../dev/cel/conformance/policy/BUILD.bazel | 2 + .../policy/PolicyConformanceTest.java | 9 ++ .../main/java/dev/cel/policy/CelPolicy.java | 2 +- .../java/dev/cel/policy/testing/BUILD.bazel | 27 ++++ .../dev/cel/policy/testing/K8sTagHandler.java | 117 ++++++++++++++++++ .../src/test/java/dev/cel/policy/BUILD.bazel | 2 +- .../cel/policy/CelPolicyCompilerImplTest.java | 2 +- .../cel/policy/CelPolicyYamlParserTest.java | 2 +- .../java/dev/cel/policy/PolicyTestHelper.java | 108 ---------------- policy/testing/BUILD.bazel | 12 ++ 10 files changed, 171 insertions(+), 112 deletions(-) create mode 100644 policy/src/main/java/dev/cel/policy/testing/BUILD.bazel create mode 100644 policy/src/main/java/dev/cel/policy/testing/K8sTagHandler.java create mode 100644 policy/testing/BUILD.bazel diff --git a/conformance/src/test/java/dev/cel/conformance/policy/BUILD.bazel b/conformance/src/test/java/dev/cel/conformance/policy/BUILD.bazel index ca34530fb..27853f29b 100644 --- a/conformance/src/test/java/dev/cel/conformance/policy/BUILD.bazel +++ b/conformance/src/test/java/dev/cel/conformance/policy/BUILD.bazel @@ -12,6 +12,8 @@ java_library( deps = [ "//:auto_value", "//bundle:cel", + "//policy:parser_factory", + "//policy/testing:k8s_test_tag_handler", "//runtime:function_binding", "//testing/testrunner:cel_expression_source", "//testing/testrunner:cel_test_context", diff --git a/conformance/src/test/java/dev/cel/conformance/policy/PolicyConformanceTest.java b/conformance/src/test/java/dev/cel/conformance/policy/PolicyConformanceTest.java index cd24339c0..4f1c643c2 100644 --- a/conformance/src/test/java/dev/cel/conformance/policy/PolicyConformanceTest.java +++ b/conformance/src/test/java/dev/cel/conformance/policy/PolicyConformanceTest.java @@ -18,6 +18,8 @@ import dev.cel.bundle.Cel; import dev.cel.bundle.CelFactory; import dev.cel.expr.conformance.proto3.TestAllTypes; +import dev.cel.policy.CelPolicyParserFactory; +import dev.cel.policy.testing.K8sTagHandler; import dev.cel.runtime.CelFunctionBinding; import dev.cel.testing.testrunner.CelExpressionSource; import dev.cel.testing.testrunner.CelTestContext; @@ -77,6 +79,13 @@ public void evaluate() throws Throwable { TestAllTypes.getDescriptor().getFile(), Struct.getDescriptor().getFile()); + // Scopes the custom Kubernetes tag visitor exclusively to k8s tests to prevent non-standard + // grammar leakage. + if (name.startsWith("k8s/")) { + contextBuilder.setCelPolicyParser( + CelPolicyParserFactory.newYamlParserBuilder().addTagVisitor(new K8sTagHandler()).build()); + } + Path yamlConfigPath = Paths.get(dirPath, "config.yaml"); Path textprotoConfigPath = Paths.get(dirPath, "config.textproto"); diff --git a/policy/src/main/java/dev/cel/policy/CelPolicy.java b/policy/src/main/java/dev/cel/policy/CelPolicy.java index 9e442a2e7..19f6631d0 100644 --- a/policy/src/main/java/dev/cel/policy/CelPolicy.java +++ b/policy/src/main/java/dev/cel/policy/CelPolicy.java @@ -252,7 +252,7 @@ public abstract static class Builder implements RequiredFieldsChecker { abstract Optional id(); - abstract Optional result(); + public abstract Optional result(); abstract Optional explanation(); diff --git a/policy/src/main/java/dev/cel/policy/testing/BUILD.bazel b/policy/src/main/java/dev/cel/policy/testing/BUILD.bazel new file mode 100644 index 000000000..3a8a4950b --- /dev/null +++ b/policy/src/main/java/dev/cel/policy/testing/BUILD.bazel @@ -0,0 +1,27 @@ +load("@rules_java//java:defs.bzl", "java_library") + +package( + default_applicable_licenses = [ + "//:license", + ], + default_testonly = True, + default_visibility = [ + "//policy/testing:__pkg__", + ], +) + +java_library( + name = "k8s_tag_handler", + srcs = ["K8sTagHandler.java"], + tags = [ + ], + deps = [ + "//common/formats:value_string", + "//common/formats:yaml_helper", + "//policy", + "//policy:parser", + "//policy:policy_parser_context", + "@maven//:com_google_guava_guava", + "@maven//:org_yaml_snakeyaml", + ], +) diff --git a/policy/src/main/java/dev/cel/policy/testing/K8sTagHandler.java b/policy/src/main/java/dev/cel/policy/testing/K8sTagHandler.java new file mode 100644 index 000000000..04635e054 --- /dev/null +++ b/policy/src/main/java/dev/cel/policy/testing/K8sTagHandler.java @@ -0,0 +1,117 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.policy.testing; + +import com.google.common.annotations.VisibleForTesting; +import dev.cel.common.formats.ValueString; +import dev.cel.common.formats.YamlHelper; +import dev.cel.common.formats.YamlHelper.YamlNodeType; +import dev.cel.policy.CelPolicy; +import dev.cel.policy.CelPolicy.Match; +import dev.cel.policy.CelPolicyParser.TagVisitor; +import dev.cel.policy.PolicyParserContext; +import org.yaml.snakeyaml.nodes.Node; +import org.yaml.snakeyaml.nodes.SequenceNode; + +/** + * K8sTagHandler is a {@link TagVisitor} implementation to support parsing Kubernetes + * ValidatingAdmissionPolicy structures in testing and conformance environments. + */ +@VisibleForTesting +public final class K8sTagHandler implements TagVisitor { + + @Override + public void visitPolicyTag( + PolicyParserContext ctx, + long id, + String tagName, + Node node, + CelPolicy.Builder policyBuilder) { + switch (tagName) { + case "kind": + policyBuilder.putMetadata("kind", ctx.newYamlString(node).value()); + break; + case "metadata": + YamlHelper.assertYamlType(ctx, id, node, YamlNodeType.MAP); + break; + case "spec": + CelPolicy.Rule spec = ctx.parseRule(ctx, policyBuilder, node); + policyBuilder.setRule(spec); + break; + default: + TagVisitor.super.visitPolicyTag(ctx, id, tagName, node, policyBuilder); + break; + } + } + + @Override + public void visitRuleTag( + PolicyParserContext ctx, + long id, + String tagName, + Node node, + CelPolicy.Builder policyBuilder, + CelPolicy.Rule.Builder ruleBuilder) { + switch (tagName) { + case "failurePolicy": + policyBuilder.putMetadata(tagName, ctx.newYamlString(node).value()); + break; + case "matchConstraints": + YamlHelper.assertYamlType(ctx, id, node, YamlNodeType.MAP); + break; + case "validations": + if (!YamlHelper.assertYamlType(ctx, id, node, YamlNodeType.LIST)) { + return; + } + SequenceNode seqNode = (SequenceNode) node; + for (Node valNode : seqNode.getValue()) { + ruleBuilder.addMatches(ctx.parseMatch(ctx, policyBuilder, valNode)); + } + break; + default: + TagVisitor.super.visitRuleTag(ctx, id, tagName, node, policyBuilder, ruleBuilder); + break; + } + } + + @Override + public void visitMatchTag( + PolicyParserContext ctx, + long id, + String tagName, + Node node, + CelPolicy.Builder policyBuilder, + CelPolicy.Match.Builder matchBuilder) { + if (!matchBuilder.result().isPresent()) { + matchBuilder.setResult( + Match.Result.ofOutput(ValueString.of(ctx.nextId(), "'invalid admission request'"))); + } + switch (tagName) { + case "expression": + // The K8s expression to validate must return false in order to generate a violation + // message. + ValueString condition = ctx.newSourceString(node); + String invertedCondition = "!(" + condition.value() + ")"; + matchBuilder.setCondition(ValueString.of(condition.id(), invertedCondition)); + break; + case "messageExpression": + matchBuilder.setResult(Match.Result.ofOutput(ctx.newSourceString(node))); + break; + default: + TagVisitor.super.visitMatchTag(ctx, id, tagName, node, policyBuilder, matchBuilder); + break; + } + } +} diff --git a/policy/src/test/java/dev/cel/policy/BUILD.bazel b/policy/src/test/java/dev/cel/policy/BUILD.bazel index 8a28caee1..3089a3849 100644 --- a/policy/src/test/java/dev/cel/policy/BUILD.bazel +++ b/policy/src/test/java/dev/cel/policy/BUILD.bazel @@ -30,9 +30,9 @@ java_library( "//policy:compiler_factory", "//policy:parser", "//policy:parser_factory", - "//policy:policy_parser_context", "//policy:source", "//policy:validation_exception", + "//policy/testing:k8s_test_tag_handler", "//runtime", "//runtime:function_binding", "//testing:cel_runtime_flavor", diff --git a/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java b/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java index 35e249407..416e3b95f 100644 --- a/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java +++ b/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java @@ -37,12 +37,12 @@ import dev.cel.extensions.CelOptionalLibrary; import dev.cel.parser.CelStandardMacro; import dev.cel.parser.CelUnparserFactory; -import dev.cel.policy.PolicyTestHelper.K8sTagHandler; import dev.cel.policy.PolicyTestHelper.PolicyTestSuite; import dev.cel.policy.PolicyTestHelper.PolicyTestSuite.PolicyTestSection; import dev.cel.policy.PolicyTestHelper.PolicyTestSuite.PolicyTestSection.PolicyTestCase; import dev.cel.policy.PolicyTestHelper.PolicyTestSuite.PolicyTestSection.PolicyTestCase.PolicyTestInput; import dev.cel.policy.PolicyTestHelper.TestYamlPolicy; +import dev.cel.policy.testing.K8sTagHandler; import dev.cel.runtime.CelFunctionBinding; import dev.cel.runtime.CelLateFunctionBindings; import dev.cel.testing.CelRuntimeFlavor; diff --git a/policy/src/test/java/dev/cel/policy/CelPolicyYamlParserTest.java b/policy/src/test/java/dev/cel/policy/CelPolicyYamlParserTest.java index 22aec6746..2a2c47a98 100644 --- a/policy/src/test/java/dev/cel/policy/CelPolicyYamlParserTest.java +++ b/policy/src/test/java/dev/cel/policy/CelPolicyYamlParserTest.java @@ -22,8 +22,8 @@ import com.google.testing.junit.testparameterinjector.TestParameterInjector; import dev.cel.common.formats.ValueString; import dev.cel.policy.CelPolicy.Import; -import dev.cel.policy.PolicyTestHelper.K8sTagHandler; import dev.cel.policy.PolicyTestHelper.TestYamlPolicy; +import dev.cel.policy.testing.K8sTagHandler; import org.junit.Test; import org.junit.runner.RunWith; diff --git a/policy/src/test/java/dev/cel/policy/PolicyTestHelper.java b/policy/src/test/java/dev/cel/policy/PolicyTestHelper.java index 59647f4d9..6e918286b 100644 --- a/policy/src/test/java/dev/cel/policy/PolicyTestHelper.java +++ b/policy/src/test/java/dev/cel/policy/PolicyTestHelper.java @@ -19,11 +19,6 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Ascii; import com.google.common.io.Resources; -import dev.cel.common.formats.ValueString; -import dev.cel.policy.CelPolicy.Match; -import dev.cel.policy.CelPolicy.Match.Result; -import dev.cel.policy.CelPolicy.Rule; -import dev.cel.policy.CelPolicyParser.TagVisitor; import java.io.IOException; import java.net.URL; import java.util.List; @@ -31,8 +26,6 @@ import org.yaml.snakeyaml.LoaderOptions; import org.yaml.snakeyaml.Yaml; import org.yaml.snakeyaml.constructor.Constructor; -import org.yaml.snakeyaml.nodes.Node; -import org.yaml.snakeyaml.nodes.SequenceNode; /** Package-private class to assist with policy testing. */ final class PolicyTestHelper { @@ -273,106 +266,5 @@ private static String readFile(String path) throws IOException { return Resources.toString(getResource(path), UTF_8); } - static class K8sTagHandler implements TagVisitor { - - @Override - public void visitPolicyTag( - PolicyParserContext ctx, - long id, - String tagName, - Node node, - CelPolicy.Builder policyBuilder) { - switch (tagName) { - case "kind": - policyBuilder.putMetadata("kind", ctx.newYamlString(node)); - break; - case "metadata": - long metadataId = ctx.collectMetadata(node); - if (!node.getTag().getValue().equals("tag:yaml.org,2002:map")) { - ctx.reportError( - metadataId, - String.format( - "invalid 'metadata' type, expected map got: %s", node.getTag().getValue())); - } - break; - case "spec": - Rule rule = ctx.parseRule(ctx, policyBuilder, node); - policyBuilder.setRule(rule); - break; - default: - TagVisitor.super.visitPolicyTag(ctx, id, tagName, node, policyBuilder); - break; - } - } - - @Override - public void visitRuleTag( - PolicyParserContext ctx, - long id, - String tagName, - Node node, - CelPolicy.Builder policyBuilder, - Rule.Builder ruleBuilder) { - switch (tagName) { - case "failurePolicy": - policyBuilder.putMetadata(tagName, ctx.newYamlString(node)); - break; - case "matchConstraints": - long matchConstraintsId = ctx.collectMetadata(node); - if (!node.getTag().getValue().equals("tag:yaml.org,2002:map")) { - ctx.reportError( - matchConstraintsId, - String.format( - "invalid 'matchConstraints' type, expected map got: %s", - node.getTag().getValue())); - } - break; - case "validations": - long validationId = ctx.collectMetadata(node); - if (!node.getTag().getValue().equals("tag:yaml.org,2002:seq")) { - ctx.reportError( - validationId, - String.format( - "invalid 'validations' type, expected list got: %s", node.getTag().getValue())); - } - - SequenceNode validationNodes = (SequenceNode) node; - for (Node element : validationNodes.getValue()) { - ruleBuilder.addMatches(ctx.parseMatch(ctx, policyBuilder, element)); - } - break; - default: - TagVisitor.super.visitRuleTag(ctx, id, tagName, node, policyBuilder, ruleBuilder); - break; - } - } - - @Override - public void visitMatchTag( - PolicyParserContext ctx, - long id, - String tagName, - Node node, - CelPolicy.Builder policyBuilder, - Match.Builder matchBuilder) { - switch (tagName) { - case "expression": - // The K8s expression to validate must return false in order to generate a violation - // message. - ValueString conditionValue = ctx.newYamlString(node); - conditionValue = - conditionValue.toBuilder().setValue("!(" + conditionValue.value() + ")").build(); - matchBuilder.setCondition(conditionValue); - break; - case "messageExpression": - matchBuilder.setResult(Result.ofOutput(ctx.newYamlString(node))); - break; - default: - TagVisitor.super.visitMatchTag(ctx, id, tagName, node, policyBuilder, matchBuilder); - break; - } - } - } - private PolicyTestHelper() {} } diff --git a/policy/testing/BUILD.bazel b/policy/testing/BUILD.bazel new file mode 100644 index 000000000..898368c3c --- /dev/null +++ b/policy/testing/BUILD.bazel @@ -0,0 +1,12 @@ +load("@rules_java//java:defs.bzl", "java_library") + +package( + default_applicable_licenses = ["//:license"], + default_testonly = True, + default_visibility = ["//:internal"], +) + +java_library( + name = "k8s_test_tag_handler", + exports = ["//policy/src/main/java/dev/cel/policy/testing:k8s_tag_handler"], +) From 966b0cf4238191c91c291aa0a9b60571231f2c07 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Wed, 6 May 2026 15:47:37 -0700 Subject: [PATCH 58/66] Add support for running policy error cases in conformance tests PiperOrigin-RevId: 911591805 --- .../dev/cel/conformance/policy/BUILD.bazel | 2 + .../policy/PolicyConformanceTest.java | 16 +++++++- .../policy/PolicyConformanceTestRunner.java | 37 +++++++++++++++++-- .../java/dev/cel/policy/RuleComposer.java | 4 +- .../expected_errors.baseline | 4 +- .../expected_errors.baseline | 2 +- 6 files changed, 55 insertions(+), 10 deletions(-) diff --git a/conformance/src/test/java/dev/cel/conformance/policy/BUILD.bazel b/conformance/src/test/java/dev/cel/conformance/policy/BUILD.bazel index 27853f29b..0326b6f15 100644 --- a/conformance/src/test/java/dev/cel/conformance/policy/BUILD.bazel +++ b/conformance/src/test/java/dev/cel/conformance/policy/BUILD.bazel @@ -11,8 +11,10 @@ java_library( srcs = glob(["*.java"]), deps = [ "//:auto_value", + "//:java_truth", "//bundle:cel", "//policy:parser_factory", + "//policy:validation_exception", "//policy/testing:k8s_test_tag_handler", "//runtime:function_binding", "//testing/testrunner:cel_expression_source", diff --git a/conformance/src/test/java/dev/cel/conformance/policy/PolicyConformanceTest.java b/conformance/src/test/java/dev/cel/conformance/policy/PolicyConformanceTest.java index 4f1c643c2..d7851bb72 100644 --- a/conformance/src/test/java/dev/cel/conformance/policy/PolicyConformanceTest.java +++ b/conformance/src/test/java/dev/cel/conformance/policy/PolicyConformanceTest.java @@ -14,11 +14,14 @@ package dev.cel.conformance.policy; +import static com.google.common.truth.Truth.assertThat; + import com.google.protobuf.Struct; import dev.cel.bundle.Cel; import dev.cel.bundle.CelFactory; import dev.cel.expr.conformance.proto3.TestAllTypes; import dev.cel.policy.CelPolicyParserFactory; +import dev.cel.policy.CelPolicyValidationException; import dev.cel.policy.testing.K8sTagHandler; import dev.cel.runtime.CelFunctionBinding; import dev.cel.testing.testrunner.CelExpressionSource; @@ -28,6 +31,7 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import java.util.Locale; import org.junit.runners.model.Statement; /** Statement representing a single CEL policy conformance test case. */ @@ -95,6 +99,16 @@ public void evaluate() throws Throwable { contextBuilder.setConfigFile(textprotoConfigPath.toString()); } - TestRunnerLibrary.runTest(testCase, contextBuilder.build()); + try { + TestRunnerLibrary.runTest(testCase, contextBuilder.build()); + } catch (CelPolicyValidationException e) { + if (testCase.output().kind() == CelTestCase.Output.Kind.EVAL_ERROR) { + String expectedError = testCase.output().evalError().get(0).toString(); + assertThat(e.getMessage().toLowerCase(Locale.US)) + .contains(expectedError.toLowerCase(Locale.US)); + } else { + throw e; + } + } } } diff --git a/conformance/src/test/java/dev/cel/conformance/policy/PolicyConformanceTestRunner.java b/conformance/src/test/java/dev/cel/conformance/policy/PolicyConformanceTestRunner.java index 6d1d86e47..62812b124 100644 --- a/conformance/src/test/java/dev/cel/conformance/policy/PolicyConformanceTestRunner.java +++ b/conformance/src/test/java/dev/cel/conformance/policy/PolicyConformanceTestRunner.java @@ -44,6 +44,7 @@ public final class PolicyConformanceTestRunner extends ParentRunner discoverTestDirs(String testdataDir) { if (!dir.exists() || !dir.isDirectory()) { return ImmutableList.of(); } - String[] directories = dir.list((current, name) -> new File(current, name).isDirectory()); - if (directories == null) { + File[] topLevelDirs = dir.listFiles(File::isDirectory); + if (topLevelDirs == null) { return ImmutableList.of(); } - Arrays.sort(directories); - return ImmutableList.copyOf(directories); + + ImmutableList.Builder testDirsBuilder = ImmutableList.builder(); + Arrays.sort(topLevelDirs); + for (File topLevelDir : topLevelDirs) { + if (hasTestSuite(topLevelDir)) { + testDirsBuilder.add(topLevelDir.getName()); + continue; + } + + // Check one level deeper to support nested tests like compile_errors/unreachable + File[] subDirs = topLevelDir.listFiles(File::isDirectory); + if (subDirs == null) { + continue; + } + + Arrays.sort(subDirs); + for (File subDir : subDirs) { + if (hasTestSuite(subDir)) { + testDirsBuilder.add(topLevelDir.getName() + "/" + subDir.getName()); + } + } + } + + return testDirsBuilder.build(); + } + + private static boolean hasTestSuite(File dir) { + return (new File(dir, TESTS_YAML_FILE_NAME).exists() + || new File(dir, TESTS_TEXTPROTO_FILE_NAME).exists()) + && new File(dir, POLICY_YAML_FILE_NAME).exists(); } private final ImmutableList tests; diff --git a/policy/src/main/java/dev/cel/policy/RuleComposer.java b/policy/src/main/java/dev/cel/policy/RuleComposer.java index cafef4b7c..5fa0957f5 100644 --- a/policy/src/main/java/dev/cel/policy/RuleComposer.java +++ b/policy/src/main/java/dev/cel/policy/RuleComposer.java @@ -93,7 +93,7 @@ private Step optimizeRule(Cel cel, CelCompiledRule compiledRule) { assertComposedAstIsValid( cel, output.expr, - "conflicting output types found.", + "incompatible output types found.", matchOutput.sourceId(), lastOutputId); lastOutputId = matchOutput.sourceId(); @@ -115,7 +115,7 @@ private Step optimizeRule(Cel cel, CelCompiledRule compiledRule) { cel, output.expr, String.format( - "failed composing the subrule '%s' due to conflicting output types.", + "failed composing the subrule '%s' due to incompatible output types.", matchNestedRule.ruleId().map(ValueString::value).orElse("")), lastOutputId); break; diff --git a/testing/src/test/resources/policy/compose_errors_conflicting_output/expected_errors.baseline b/testing/src/test/resources/policy/compose_errors_conflicting_output/expected_errors.baseline index 3e2624b64..0facbbe2e 100644 --- a/testing/src/test/resources/policy/compose_errors_conflicting_output/expected_errors.baseline +++ b/testing/src/test/resources/policy/compose_errors_conflicting_output/expected_errors.baseline @@ -1,6 +1,6 @@ -ERROR: compose_errors_conflicting_output/policy.yaml:22:14: conflicting output types found. +ERROR: compose_errors_conflicting_output/policy.yaml:22:14: incompatible output types found. | output: "false" | .............^ -ERROR: compose_errors_conflicting_output/policy.yaml:23:14: conflicting output types found. +ERROR: compose_errors_conflicting_output/policy.yaml:23:14: incompatible output types found. | - output: "{'banned': true}" | .............^ \ No newline at end of file diff --git a/testing/src/test/resources/policy/compose_errors_conflicting_subrule/expected_errors.baseline b/testing/src/test/resources/policy/compose_errors_conflicting_subrule/expected_errors.baseline index 559d62e1d..92ddff311 100644 --- a/testing/src/test/resources/policy/compose_errors_conflicting_subrule/expected_errors.baseline +++ b/testing/src/test/resources/policy/compose_errors_conflicting_subrule/expected_errors.baseline @@ -1,3 +1,3 @@ -ERROR: compose_errors_conflicting_subrule/policy.yaml:36:14: failed composing the subrule 'banned regions' due to conflicting output types. +ERROR: compose_errors_conflicting_subrule/policy.yaml:36:14: failed composing the subrule 'banned regions' due to incompatible output types. | output: "{'banned': false}" | .............^ \ No newline at end of file From 14d4c2e39151f2e99e36f9818a9118b01c1d9ed3 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Wed, 6 May 2026 17:16:33 -0700 Subject: [PATCH 59/66] Track incompatible output types for accurate error reporting PiperOrigin-RevId: 911634269 --- policy/BUILD.bazel | 6 ++ .../src/main/java/dev/cel/policy/BUILD.bazel | 3 +- .../java/dev/cel/policy/RuleComposer.java | 88 ++++++++++++------- .../src/test/java/dev/cel/policy/BUILD.bazel | 2 + .../cel/policy/CelPolicyCompilerImplTest.java | 19 ++++ .../expected_errors.baseline | 4 +- .../expected_errors.baseline | 3 + 7 files changed, 92 insertions(+), 33 deletions(-) diff --git a/policy/BUILD.bazel b/policy/BUILD.bazel index 5979f1ba7..bce68f001 100644 --- a/policy/BUILD.bazel +++ b/policy/BUILD.bazel @@ -59,3 +59,9 @@ java_library( name = "compiler_builder", exports = ["//policy/src/main/java/dev/cel/policy:compiler_builder"], ) + +java_library( + name = "rule_composer", + visibility = ["//:internal"], + exports = ["//policy/src/main/java/dev/cel/policy:rule_composer"], +) diff --git a/policy/src/main/java/dev/cel/policy/BUILD.bazel b/policy/src/main/java/dev/cel/policy/BUILD.bazel index a7bb90ffe..e0d6af461 100644 --- a/policy/src/main/java/dev/cel/policy/BUILD.bazel +++ b/policy/src/main/java/dev/cel/policy/BUILD.bazel @@ -244,7 +244,6 @@ java_library( java_library( name = "rule_composer", srcs = ["RuleComposer.java"], - visibility = ["//visibility:private"], deps = [ ":compiled_rule", "//bundle:cel", @@ -257,6 +256,8 @@ java_library( "//common/ast:mutable_expr", "//common/formats:value_string", "//common/navigation:mutable_navigation", + "//common/types:cel_types", + "//common/types:type_providers", "//extensions:optional_library", "//optimizer:ast_optimizer", "//optimizer:mutable_ast", diff --git a/policy/src/main/java/dev/cel/policy/RuleComposer.java b/policy/src/main/java/dev/cel/policy/RuleComposer.java index 5fa0957f5..73d31a4ee 100644 --- a/policy/src/main/java/dev/cel/policy/RuleComposer.java +++ b/policy/src/main/java/dev/cel/policy/RuleComposer.java @@ -18,6 +18,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.stream.Collectors.toCollection; +import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import dev.cel.bundle.Cel; @@ -32,11 +33,14 @@ import dev.cel.common.formats.ValueString; import dev.cel.common.navigation.CelNavigableMutableAst; import dev.cel.common.navigation.CelNavigableMutableExpr; +import dev.cel.common.types.CelType; +import dev.cel.common.types.CelTypes; import dev.cel.extensions.CelOptionalLibrary.Function; import dev.cel.optimizer.AstMutator; import dev.cel.optimizer.CelAstOptimizer; import dev.cel.policy.CelCompiledRule.CelCompiledMatch; import dev.cel.policy.CelCompiledRule.CelCompiledMatch.OutputValue; +import dev.cel.policy.CelCompiledRule.CelCompiledMatch.Result; import dev.cel.policy.CelCompiledRule.CelCompiledVariable; import java.util.ArrayList; import java.util.Arrays; @@ -74,11 +78,15 @@ private Step optimizeRule(Cel cel, CelCompiledRule compiledRule) { } long lastOutputId = 0; + // The expected output type of the rule, used to verify that all branches agree on the type. + CelType lastOutputType = null; for (CelCompiledMatch match : Lists.reverse(compiledRule.matches())) { CelAbstractSyntaxTree conditionAst = match.condition(); boolean isTriviallyTrue = match.isConditionTriviallyTrue(); CelMutableAst condAst = CelMutableAst.fromCelAst(conditionAst); + long currentSourceId = lastOutputId; + switch (match.result().kind()) { case OUTPUT: // If the match has an output, then it is considered a non-optional output since @@ -86,42 +94,54 @@ private Step optimizeRule(Cel cel, CelCompiledRule compiledRule) { // of output being optional.none() will convert the non-optional value to an optional // one. OutputValue matchOutput = match.result().output(); - CelMutableAst outAst = CelMutableAst.fromCelAst(matchOutput.ast()); - Step step = Step.newNonOptionalStep(!isTriviallyTrue, condAst, outAst); + Step step = + Step.newNonOptionalStep( + !isTriviallyTrue, condAst, CelMutableAst.fromCelAst(matchOutput.ast())); + currentSourceId = matchOutput.sourceId(); + output = combine(astMutator, step, output); - assertComposedAstIsValid( - cel, - output.expr, - "incompatible output types found.", - matchOutput.sourceId(), - lastOutputId); - lastOutputId = matchOutput.sourceId(); + String outputFailureMessage = + String.format( + "incompatible output types: block has output type %s, but previous outputs have" + + " type %s", + lastOutputType == null ? "" : CelTypes.format(lastOutputType), + CelTypes.format(matchOutput.ast().getResultType())); + lastOutputType = + assertComposedAstIsValid( + cel, output.expr, outputFailureMessage, currentSourceId, lastOutputId) + .getResultType(); + break; case RULE: // If the match has a nested rule, then compute the rule and whether it has // an optional return value. CelCompiledRule matchNestedRule = match.result().rule(); Step nestedRule = optimizeRule(cel, matchNestedRule); - boolean nestedHasOptional = matchNestedRule.hasOptionalOutput(); - Step ruleStep = - nestedHasOptional - ? Step.newOptionalStep(!isTriviallyTrue, condAst, nestedRule.expr) - : Step.newNonOptionalStep(!isTriviallyTrue, condAst, nestedRule.expr); + new Step( + matchNestedRule.hasOptionalOutput(), !isTriviallyTrue, condAst, nestedRule.expr); + currentSourceId = getFirstOutputSourceId(matchNestedRule); + output = combine(astMutator, ruleStep, output); - assertComposedAstIsValid( - cel, - output.expr, - String.format( - "failed composing the subrule '%s' due to incompatible output types.", - matchNestedRule.ruleId().map(ValueString::value).orElse("")), - lastOutputId); + lastOutputType = + assertComposedAstIsValid( + cel, + output.expr, + String.format( + "failed composing the subrule '%s' due to incompatible output types.", + matchNestedRule.ruleId().map(ValueString::value).orElse("")), + currentSourceId, + lastOutputId) + .getResultType(); break; } + + lastOutputId = currentSourceId; } + Preconditions.checkState(output != null, "Policy contains no outputs."); CelMutableAst resultExpr = output.expr; resultExpr = inlineCompiledVariables(resultExpr, compiledRule.variables()); resultExpr = astMutator.renumberIdsConsecutively(resultExpr); @@ -266,21 +286,34 @@ private CelMutableAst inlineCompiledVariables( return mutatedAst; } - private void assertComposedAstIsValid( + private CelAbstractSyntaxTree assertComposedAstIsValid( Cel cel, CelMutableAst composedAst, String failureMessage, Long... ids) { - assertComposedAstIsValid(cel, composedAst, failureMessage, Arrays.asList(ids)); + return assertComposedAstIsValid(cel, composedAst, failureMessage, Arrays.asList(ids)); } - private void assertComposedAstIsValid( + private CelAbstractSyntaxTree assertComposedAstIsValid( Cel cel, CelMutableAst composedAst, String failureMessage, List ids) { try { - cel.check(composedAst.toParsedAst()).getAst(); + return cel.check(composedAst.toParsedAst()).getAst(); } catch (CelValidationException e) { ids = ids.stream().filter(id -> id > 0).collect(toCollection(ArrayList::new)); throw new RuleCompositionException(failureMessage, e, ids); } } + private static long getFirstOutputSourceId(CelCompiledRule rule) { + for (CelCompiledMatch match : rule.matches()) { + if (match.result().kind() == Result.Kind.OUTPUT) { + return match.result().output().sourceId(); + } else if (match.result().kind() == Result.Kind.RULE) { + return getFirstOutputSourceId(match.result().rule()); + } + } + + // Fallback to the nested rule ID if the policy is invalid and contains no output + return rule.sourceId(); + } + // Step represents an intermediate stage of rule and match expression composition. // // The CelCompiledRule and CelCompiledMatch types are meant to represent standalone tuples of @@ -311,11 +344,6 @@ private Step( this.expr = expr; } - private static Step newOptionalStep( - boolean isConditional, CelMutableAst cond, CelMutableAst expr) { - return new Step(/* isOptional= */ true, isConditional, cond, expr); - } - private static Step newNonOptionalStep( boolean isConditional, CelMutableAst cond, CelMutableAst expr) { return new Step(/* isOptional= */ false, isConditional, cond, expr); diff --git a/policy/src/test/java/dev/cel/policy/BUILD.bazel b/policy/src/test/java/dev/cel/policy/BUILD.bazel index 3089a3849..bc8a5d4b4 100644 --- a/policy/src/test/java/dev/cel/policy/BUILD.bazel +++ b/policy/src/test/java/dev/cel/policy/BUILD.bazel @@ -27,9 +27,11 @@ java_library( "//parser:parser_factory", "//parser:unparser", "//policy", + "//policy:compiled_rule", "//policy:compiler_factory", "//policy:parser", "//policy:parser_factory", + "//policy:rule_composer", "//policy:source", "//policy:validation_exception", "//policy/testing:k8s_test_tag_handler", diff --git a/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java b/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java index 416e3b95f..b4065b60c 100644 --- a/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java +++ b/policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java @@ -31,6 +31,7 @@ import dev.cel.bundle.CelEnvironmentYamlParser; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelOptions; +import dev.cel.common.formats.ValueString; import dev.cel.common.types.OptionalType; import dev.cel.common.types.SimpleType; import dev.cel.expr.conformance.proto3.TestAllTypes; @@ -356,6 +357,24 @@ public void evaluateYamlPolicy_withSimpleVariable() throws Exception { assertThat(evalResult).isFalse(); } + @Test + public void compose_ruleWithNoOutputs_throws() throws Exception { + Cel cel = newCel(); + CelCompiledRule emptyRule = + CelCompiledRule.create( + 1L, + Optional.of(ValueString.of(2L, "empty_rule")), + ImmutableList.of(), + ImmutableList.of(), + cel); + RuleComposer composer = RuleComposer.newInstance(emptyRule, "variables.", 1000); + CelAbstractSyntaxTree ast = cel.compile("true").getAst(); + + IllegalStateException e = + assertThrows(IllegalStateException.class, () -> composer.optimize(ast, cel)); + assertThat(e).hasMessageThat().isEqualTo("Policy contains no outputs."); + } + private static final class EvaluablePolicyTestData { private final TestYamlPolicy yamlPolicy; private final PolicyTestCase testCase; diff --git a/testing/src/test/resources/policy/compose_errors_conflicting_output/expected_errors.baseline b/testing/src/test/resources/policy/compose_errors_conflicting_output/expected_errors.baseline index 0facbbe2e..bc205c2ab 100644 --- a/testing/src/test/resources/policy/compose_errors_conflicting_output/expected_errors.baseline +++ b/testing/src/test/resources/policy/compose_errors_conflicting_output/expected_errors.baseline @@ -1,6 +1,6 @@ -ERROR: compose_errors_conflicting_output/policy.yaml:22:14: incompatible output types found. +ERROR: compose_errors_conflicting_output/policy.yaml:22:14: incompatible output types: block has output type map(string, bool), but previous outputs have type bool | output: "false" | .............^ -ERROR: compose_errors_conflicting_output/policy.yaml:23:14: incompatible output types found. +ERROR: compose_errors_conflicting_output/policy.yaml:23:14: incompatible output types: block has output type map(string, bool), but previous outputs have type bool | - output: "{'banned': true}" | .............^ \ No newline at end of file diff --git a/testing/src/test/resources/policy/compose_errors_conflicting_subrule/expected_errors.baseline b/testing/src/test/resources/policy/compose_errors_conflicting_subrule/expected_errors.baseline index 92ddff311..66e48ea57 100644 --- a/testing/src/test/resources/policy/compose_errors_conflicting_subrule/expected_errors.baseline +++ b/testing/src/test/resources/policy/compose_errors_conflicting_subrule/expected_errors.baseline @@ -1,3 +1,6 @@ +ERROR: compose_errors_conflicting_subrule/policy.yaml:34:18: failed composing the subrule 'banned regions' due to incompatible output types. + | output: "true" + | .................^ ERROR: compose_errors_conflicting_subrule/policy.yaml:36:14: failed composing the subrule 'banned regions' due to incompatible output types. | output: "{'banned': false}" | .............^ \ No newline at end of file From bdfe82340650d66e5fc8446780bb539a0c31e144 Mon Sep 17 00:00:00 2001 From: Robert Yokota Date: Thu, 14 May 2026 11:54:11 -0700 Subject: [PATCH 60/66] Fix math.round to use HALF_UP to match doc, cel-go/cel-cpp --- .../main/java/dev/cel/extensions/CelMathExtensions.java | 2 +- .../java/dev/cel/extensions/CelMathExtensionsTest.java | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/extensions/src/main/java/dev/cel/extensions/CelMathExtensions.java b/extensions/src/main/java/dev/cel/extensions/CelMathExtensions.java index 78a0fd51c..63108aa0c 100644 --- a/extensions/src/main/java/dev/cel/extensions/CelMathExtensions.java +++ b/extensions/src/main/java/dev/cel/extensions/CelMathExtensions.java @@ -874,7 +874,7 @@ private static double round(double x) { if (isNaN(x) || isInfinite(x)) { return x; } - return DoubleMath.roundToLong(x, RoundingMode.HALF_EVEN); + return DoubleMath.roundToLong(x, RoundingMode.HALF_UP); } private static Number sign(Number x) { diff --git a/extensions/src/test/java/dev/cel/extensions/CelMathExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelMathExtensionsTest.java index 16d5c4c83..68c80dedb 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelMathExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelMathExtensionsTest.java @@ -735,6 +735,13 @@ public void floor_invalidArgs_throwsException(String expr) { @TestParameters("{expr: 'math.round(-1.5)' , expectedResult: -2.0}") @TestParameters("{expr: 'math.round(-1.2)' , expectedResult: -1.0}") @TestParameters("{expr: 'math.round(-1.6)' , expectedResult: -2.0}") + // Discriminating tie cases: confirm "ties round away from zero" (HALF_UP), not + // banker's rounding (HALF_EVEN). 1.5/-1.5 above don't distingish the two because + // their nearest-even neighbor (2/-2) is also the away-from-zero neighbor. + @TestParameters("{expr: 'math.round(0.5)' , expectedResult: 1.0}") + @TestParameters("{expr: 'math.round(2.5)' , expectedResult: 3.0}") + @TestParameters("{expr: 'math.round(-0.5)' , expectedResult: -1.0}") + @TestParameters("{expr: 'math.round(-2.5)' , expectedResult: -3.0}") @TestParameters("{expr: 'math.round(0.0/0.0)' , expectedResult: NaN}") @TestParameters("{expr: 'math.round(1.0/0.0)' , expectedResult: Infinity}") @TestParameters("{expr: 'math.round(-1.0/0.0)' , expectedResult: -Infinity}") From 9988a347afbe817614c7817e92f23d52c3791872 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Thu, 14 May 2026 00:48:33 +0000 Subject: [PATCH 61/66] Enable policy conformance test suite in OSS PiperOrigin-RevId: 915149734 --- MODULE.bazel | 1 + .../test/java/dev/cel/conformance/policy/BUILD.bazel | 5 +++++ .../policy/cel_policy_conformance_test.bzl | 6 +++++- repositories.bzl | 11 +++++++++++ 4 files changed, 22 insertions(+), 1 deletion(-) diff --git a/MODULE.bazel b/MODULE.bazel index 0b67c825c..895715a5f 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -137,3 +137,4 @@ use_repo(maven, "maven", "maven_android", "maven_conformance") non_module_dependencies = use_extension("//:repositories.bzl", "non_module_dependencies") use_repo(non_module_dependencies, "antlr4_jar") use_repo(non_module_dependencies, "bazel_common") +use_repo(non_module_dependencies, "cel_policy") diff --git a/conformance/src/test/java/dev/cel/conformance/policy/BUILD.bazel b/conformance/src/test/java/dev/cel/conformance/policy/BUILD.bazel index 0326b6f15..e4d80eccf 100644 --- a/conformance/src/test/java/dev/cel/conformance/policy/BUILD.bazel +++ b/conformance/src/test/java/dev/cel/conformance/policy/BUILD.bazel @@ -29,3 +29,8 @@ java_library( "@maven//:junit_junit", ], ) + +cel_policy_conformance_test_java( + name = "policy_conformance_tests", + testdata = "@cel_policy//conformance:testdata", +) diff --git a/conformance/src/test/java/dev/cel/conformance/policy/cel_policy_conformance_test.bzl b/conformance/src/test/java/dev/cel/conformance/policy/cel_policy_conformance_test.bzl index 3e3720ec5..b53d982bb 100644 --- a/conformance/src/test/java/dev/cel/conformance/policy/cel_policy_conformance_test.bzl +++ b/conformance/src/test/java/dev/cel/conformance/policy/cel_policy_conformance_test.bzl @@ -33,7 +33,11 @@ def cel_policy_conformance_test_java( """ lbl = native.package_relative_label(testdata) - testdata_dir = lbl.package + "/" + lbl.name + + # Under Bzlmod, external repository runfiles are located in sibling directories + # named after their canonical repository name. + repo_prefix = "../" + lbl.workspace_name + "/" if lbl.workspace_name else "" + testdata_dir = repo_prefix + lbl.package + "/" + lbl.name java_test( name = name, diff --git a/repositories.bzl b/repositories.bzl index 8e9a9ba47..88f01019a 100644 --- a/repositories.bzl +++ b/repositories.bzl @@ -33,9 +33,20 @@ def bazel_common_dependency(): url = "https://github.com/google/bazel-common/archive/%s.tar.gz" % bazel_common_tag, ) +def cel_policy_dependency(): + cel_policy_tag = "569292f1c4eaa41894c1e37ee94eb146e284bcfa" + cel_policy_sha = "5a68318d906f6ce18492ad6f82b5f8bb083fd9d694cf567d399216c11da03157" + http_archive( + name = "cel_policy", + sha256 = cel_policy_sha, + strip_prefix = "cel-policy-%s" % cel_policy_tag, + url = "https://github.com/cel-expr/cel-policy/archive/%s.tar.gz" % cel_policy_tag, + ) + def _non_module_dependencies_impl(_ctx): antlr4_jar_dependency() bazel_common_dependency() + cel_policy_dependency() non_module_dependencies = module_extension( implementation = _non_module_dependencies_impl, From 558f09e38d9d754de4286c736bdf555b612bd827 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Thu, 14 May 2026 15:55:33 -0700 Subject: [PATCH 62/66] Implement Native Extensions Future iterations may introduce an annotation based field mapping (ex: `@CelName('foo')`. PiperOrigin-RevId: 915660684 --- common/internal/BUILD.bazel | 5 + .../java/dev/cel/common/internal/BUILD.bazel | 3 + .../cel/common/internal/ReflectionUtil.java | 15 + extensions/BUILD.bazel | 5 + .../main/java/dev/cel/extensions/BUILD.bazel | 25 + .../dev/cel/extensions/CelExtensions.java | 29 +- .../extensions/CelNativeTypesExtensions.java | 1041 +++++++++++++ .../cel/extensions/CelOptionalLibrary.java | 6 +- .../main/java/dev/cel/extensions/README.md | 55 + .../test/java/dev/cel/extensions/BUILD.bazel | 2 + .../CelNativeTypesExtensionsTest.java | 1349 +++++++++++++++++ .../runtime/planner/NamespacedAttribute.java | 4 +- .../runtime/planner/RelativeAttribute.java | 4 +- 13 files changed, 2529 insertions(+), 14 deletions(-) create mode 100644 extensions/src/main/java/dev/cel/extensions/CelNativeTypesExtensions.java create mode 100644 extensions/src/test/java/dev/cel/extensions/CelNativeTypesExtensionsTest.java diff --git a/common/internal/BUILD.bazel b/common/internal/BUILD.bazel index 781566713..7c33e56b9 100644 --- a/common/internal/BUILD.bazel +++ b/common/internal/BUILD.bazel @@ -147,3 +147,8 @@ cel_android_library( name = "date_time_helpers_android", exports = ["//common/src/main/java/dev/cel/common/internal:date_time_helpers_android"], ) + +java_library( + name = "reflection_util", + exports = ["//common/src/main/java/dev/cel/common/internal:reflection_util"], +) diff --git a/common/src/main/java/dev/cel/common/internal/BUILD.bazel b/common/src/main/java/dev/cel/common/internal/BUILD.bazel index 6b470d98c..58b15b103 100644 --- a/common/src/main/java/dev/cel/common/internal/BUILD.bazel +++ b/common/src/main/java/dev/cel/common/internal/BUILD.bazel @@ -398,8 +398,11 @@ java_library( java_library( name = "reflection_util", srcs = ["ReflectionUtil.java"], + tags = [ + ], deps = [ "//common/annotations", + "@maven//:com_google_guava_guava", ], ) diff --git a/common/src/main/java/dev/cel/common/internal/ReflectionUtil.java b/common/src/main/java/dev/cel/common/internal/ReflectionUtil.java index e513a446b..97bed650f 100644 --- a/common/src/main/java/dev/cel/common/internal/ReflectionUtil.java +++ b/common/src/main/java/dev/cel/common/internal/ReflectionUtil.java @@ -14,9 +14,11 @@ package dev.cel.common.internal; +import com.google.common.reflect.TypeToken; import dev.cel.common.annotations.Internal; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; +import java.lang.reflect.Type; /** * Utility class for invoking Java reflection. @@ -48,5 +50,18 @@ public static Object invoke(Method method, Object object, Object... params) { } } + /** Resolves a generic parameter of a base class from a type token. */ + public static Type resolveGenericParameter(TypeToken token, Class baseClass, int index) { + return token.resolveType(baseClass.getTypeParameters()[index]).getType(); + } + + /** + * Extracts the raw Class from a Type. Handles Class, ParameterizedType, and WildcardType (returns + * upper bound). Returns Object.class as fallback. + */ + public static Class getRawType(Type type) { + return TypeToken.of(type).getRawType(); + } + private ReflectionUtil() {} } diff --git a/extensions/BUILD.bazel b/extensions/BUILD.bazel index c6a029106..dea4cd760 100644 --- a/extensions/BUILD.bazel +++ b/extensions/BUILD.bazel @@ -56,3 +56,8 @@ java_library( name = "comprehensions", exports = ["//extensions/src/main/java/dev/cel/extensions:comprehensions"], ) + +java_library( + name = "native", + exports = ["//extensions/src/main/java/dev/cel/extensions:native"], +) diff --git a/extensions/src/main/java/dev/cel/extensions/BUILD.bazel b/extensions/src/main/java/dev/cel/extensions/BUILD.bazel index f8e4bfc8c..73bab08c9 100644 --- a/extensions/src/main/java/dev/cel/extensions/BUILD.bazel +++ b/extensions/src/main/java/dev/cel/extensions/BUILD.bazel @@ -34,6 +34,7 @@ java_library( ":encoders", ":lists", ":math", + ":native", ":optional_library", ":protos", ":regex", @@ -185,6 +186,7 @@ java_library( "//common/types", "//common/values", "//common/values:cel_byte_string", + "//common/values:cel_value", "//compiler:compiler_builder", "//extensions:extension_library", "//parser:macro", @@ -318,3 +320,26 @@ java_library( "@maven//:com_google_guava_guava", ], ) + +java_library( + name = "native", + srcs = ["CelNativeTypesExtensions.java"], + tags = [ + ], + deps = [ + "//checker:checker_builder", + "//common/exceptions:attribute_not_found", + "//common/internal:reflection_util", + "//common/types", + "//common/types:type_providers", + "//common/values", + "//common/values:cel_byte_string", + "//common/values:cel_value", + "//common/values:cel_value_provider", + "//compiler:compiler_builder", + "//runtime", + "@maven//:com_google_errorprone_error_prone_annotations", + "@maven//:com_google_guava_guava", + "@maven//:org_jspecify_jspecify", + ], +) diff --git a/extensions/src/main/java/dev/cel/extensions/CelExtensions.java b/extensions/src/main/java/dev/cel/extensions/CelExtensions.java index 8f1770f3f..8adc39384 100644 --- a/extensions/src/main/java/dev/cel/extensions/CelExtensions.java +++ b/extensions/src/main/java/dev/cel/extensions/CelExtensions.java @@ -15,13 +15,13 @@ package dev.cel.extensions; import static com.google.common.collect.ImmutableSet.toImmutableSet; -import static java.util.Arrays.stream; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Streams; import com.google.errorprone.annotations.InlineMe; import dev.cel.common.CelOptions; import dev.cel.extensions.CelMathExtensions.Function; +import java.util.EnumSet; import java.util.Set; /** @@ -350,6 +350,18 @@ public static CelComprehensionsExtensions comprehensions() { return COMPREHENSIONS_EXTENSIONS; } + /** + * Extensions for supporting native Java types (POJOs) in CEL. + * + *

Refer to README.md for details on property discovery, type mapping, and limitations. + * + *

Note: Passing classes with unsupported types or anonymous/local classes will result in an + * {@link IllegalArgumentException} when the runtime is built. + */ + public static CelNativeTypesExtensions nativeTypes(Class... classes) { + return CelNativeTypesExtensions.nativeTypes(classes); + } + /** * Retrieves all function names used by every extension libraries. * @@ -359,18 +371,17 @@ public static CelComprehensionsExtensions comprehensions() { */ public static ImmutableSet getAllFunctionNames() { return Streams.concat( - stream(CelMathExtensions.Function.values()) - .map(CelMathExtensions.Function::getFunction), - stream(CelStringExtensions.Function.values()) + EnumSet.allOf(Function.class).stream().map(CelMathExtensions.Function::getFunction), + EnumSet.allOf(CelStringExtensions.Function.class).stream() .map(CelStringExtensions.Function::getFunction), - stream(SetsFunction.values()).map(SetsFunction::getFunction), - stream(CelEncoderExtensions.Function.values()) + EnumSet.allOf(SetsFunction.class).stream().map(SetsFunction::getFunction), + EnumSet.allOf(CelEncoderExtensions.Function.class).stream() .map(CelEncoderExtensions.Function::getFunction), - stream(CelListsExtensions.Function.values()) + EnumSet.allOf(CelListsExtensions.Function.class).stream() .map(CelListsExtensions.Function::getFunction), - stream(CelRegexExtensions.Function.values()) + EnumSet.allOf(CelRegexExtensions.Function.class).stream() .map(CelRegexExtensions.Function::getFunction), - stream(CelComprehensionsExtensions.Function.values()) + EnumSet.allOf(CelComprehensionsExtensions.Function.class).stream() .map(CelComprehensionsExtensions.Function::getFunction)) .collect(toImmutableSet()); } diff --git a/extensions/src/main/java/dev/cel/extensions/CelNativeTypesExtensions.java b/extensions/src/main/java/dev/cel/extensions/CelNativeTypesExtensions.java new file mode 100644 index 000000000..fd579a3bc --- /dev/null +++ b/extensions/src/main/java/dev/cel/extensions/CelNativeTypesExtensions.java @@ -0,0 +1,1041 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.extensions; + +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.util.Arrays.stream; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.primitives.Primitives; +import com.google.common.primitives.UnsignedLong; +import com.google.common.reflect.TypeToken; +import com.google.errorprone.annotations.Immutable; +import dev.cel.checker.CelCheckerBuilder; +import dev.cel.common.exceptions.CelAttributeNotFoundException; +import dev.cel.common.internal.ReflectionUtil; +import dev.cel.common.types.CelType; +import dev.cel.common.types.CelTypeProvider; +import dev.cel.common.types.ListType; +import dev.cel.common.types.MapType; +import dev.cel.common.types.OptionalType; +import dev.cel.common.types.SimpleType; +import dev.cel.common.types.StructType; +import dev.cel.common.types.StructTypeReference; +import dev.cel.common.values.CelByteString; +import dev.cel.common.values.CelValue; +import dev.cel.common.values.CelValueConverter; +import dev.cel.common.values.CelValueProvider; +import dev.cel.common.values.StructValue; +import dev.cel.compiler.CelCompilerLibrary; +import dev.cel.runtime.CelRuntimeBuilder; +import dev.cel.runtime.CelRuntimeLibrary; +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.lang.reflect.Type; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayDeque; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Queue; +import java.util.Set; +import java.util.function.BiConsumer; +import java.util.function.Function; +import org.jspecify.annotations.Nullable; + +/** + * Extension for supporting native Java types (POJOs) in CEL. + * + *

This allows seamless plugin and evaluation of message creations and field selections without + * involving protobuf. + */ +@Immutable +public final class CelNativeTypesExtensions implements CelCompilerLibrary, CelRuntimeLibrary { + + private final NativeTypeRegistry registry; + + // Set of all standard java.lang.Object method names. + private static final ImmutableSet OBJECT_METHOD_NAMES = + stream(Object.class.getDeclaredMethods()).map(Method::getName).collect(toImmutableSet()); + + private static final ImmutableMap, CelType> JAVA_TO_CEL_TYPE_MAP = + ImmutableMap., CelType>builder() + .put(boolean.class, SimpleType.BOOL) + .put(Boolean.class, SimpleType.BOOL) + .put(String.class, SimpleType.STRING) + .put(int.class, SimpleType.INT) + .put(Integer.class, SimpleType.INT) + .put(long.class, SimpleType.INT) + .put(Long.class, SimpleType.INT) + .put(UnsignedLong.class, SimpleType.UINT) + .put(float.class, SimpleType.DOUBLE) + .put(Float.class, SimpleType.DOUBLE) + .put(double.class, SimpleType.DOUBLE) + .put(Double.class, SimpleType.DOUBLE) + .put(byte[].class, SimpleType.BYTES) + .put(CelByteString.class, SimpleType.BYTES) + .put(Duration.class, SimpleType.DURATION) + .put(Instant.class, SimpleType.TIMESTAMP) + .put(Object.class, SimpleType.DYN) + .buildOrThrow(); + + private static final ImmutableMap, Object> JAVA_TO_DEFAULT_VALUE_MAP = + ImmutableMap., Object>builder() + .put(boolean.class, false) + .put(Boolean.class, false) + .put(String.class, "") + .put(int.class, 0L) + .put(Integer.class, 0L) + .put(long.class, 0L) + .put(Long.class, 0L) + .put(UnsignedLong.class, UnsignedLong.ZERO) + .put(float.class, 0.0) + .put(Float.class, 0.0) + .put(double.class, 0.0) + .put(Double.class, 0.0) + .put(byte[].class, new byte[0]) + .put(CelByteString.class, CelByteString.EMPTY) + .put(Duration.class, Duration.ZERO) + .put(Instant.class, Instant.EPOCH) + .put(Optional.class, Optional.empty()) + .buildOrThrow(); + + /** Creates a new instance of {@link CelNativeTypesExtensions} for the given classes. */ + static CelNativeTypesExtensions nativeTypes(Class... classes) { + return new CelNativeTypesExtensions(new NativeTypeRegistry(NativeTypeScanner.scan(classes))); + } + + @VisibleForTesting + NativeTypeRegistry getRegistry() { + return registry; + } + + @Override + public void setRuntimeOptions(CelRuntimeBuilder runtimeBuilder) { + runtimeBuilder.setValueProvider(registry); + runtimeBuilder.setTypeProvider(registry); + } + + @Override + public void setCheckerOptions(CelCheckerBuilder checkerBuilder) { + checkerBuilder.setTypeProvider(registry); + } + + /** + * NativeTypeScanner scans registered Java classes to extract properties and compile accessors. + */ + @VisibleForTesting + static final class NativeTypeScanner { + private static final MethodHandles.Lookup LOOKUP = MethodHandles.lookup(); + + private NativeTypeScanner() {} + + private static final class ScanResult { + private final ImmutableMap> classMap; + private final ImmutableMap typeMap; + private final ImmutableMap, StructType> classToTypeMap; + private final ImmutableMap, ImmutableMap> accessorMap; + + ScanResult( + ImmutableMap> classMap, + ImmutableMap typeMap, + ImmutableMap, StructType> classToTypeMap, + ImmutableMap, ImmutableMap> accessorMap) { + this.classMap = classMap; + this.typeMap = typeMap; + this.classToTypeMap = classToTypeMap; + this.accessorMap = accessorMap; + } + } + + private static ScanResult scan(Class... classes) { + ImmutableMap.Builder> classMapBuilder = ImmutableMap.builder(); + ImmutableMap.Builder typeMapBuilder = ImmutableMap.builder(); + ImmutableMap.Builder, StructType> classToTypeMapBuilder = ImmutableMap.builder(); + ImmutableMap.Builder, ImmutableMap> accessorMapBuilder = + ImmutableMap.builder(); + + Set> visited = new HashSet<>(); + Queue> queue = new ArrayDeque<>(Arrays.asList(classes)); + + while (!queue.isEmpty()) { + Class clazz = queue.poll(); + if (shouldSkip(clazz, visited)) { + continue; + } + visited.add(clazz); + + String typeName = getCelTypeName(clazz); + classMapBuilder.put(typeName, clazz); + + ImmutableMap accessors = scanProperties(clazz, queue); + accessorMapBuilder.put(clazz, accessors); + } + + ImmutableMap> classMap = classMapBuilder.buildOrThrow(); + ImmutableMap, ImmutableMap> accessorMap = + accessorMapBuilder.buildOrThrow(); + + for (Map.Entry> entry : classMap.entrySet()) { + String typeName = entry.getKey(); + Class clazz = entry.getValue(); + + StructType structType = createStructType(clazz, classMap, accessorMap); + typeMapBuilder.put(typeName, structType); + classToTypeMapBuilder.put(clazz, structType); + } + + ScanResult result = + new ScanResult( + classMap, + typeMapBuilder.buildOrThrow(), + classToTypeMapBuilder.buildOrThrow(), + accessorMap); + + validateRegisteredClasses(result.classToTypeMap, result.classMap, result.accessorMap); + + return result; + } + + private static void validateRegisteredClasses( + ImmutableMap, StructType> classToTypeMap, + ImmutableMap> classMap, + ImmutableMap, ImmutableMap> accessorMap) { + for (Class clazz : classToTypeMap.keySet()) { + for (String prop : getProperties(clazz)) { + try { + getPropertyType(clazz, prop, classMap, accessorMap); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException( + "Unsupported type for property '" + prop + "' in class " + clazz.getName(), e); + } + } + } + } + + private static boolean shouldSkip(Class clazz, Set> visited) { + return clazz == null + || visited.contains(clazz) + || clazz.isInterface() + || isSupportedType(clazz); + } + + private static boolean isSupportedType(Class type) { + return JAVA_TO_CEL_TYPE_MAP.containsKey(type) + || type == Optional.class + || List.class.isAssignableFrom(type) + || Map.class.isAssignableFrom(type) + || type.isArray(); + } + + private static StructType createStructType( + Class clazz, + ImmutableMap> classMap, + ImmutableMap, ImmutableMap> accessorMap) { + return StructType.create( + getCelTypeName(clazz), + getProperties(clazz), + fieldName -> Optional.of(getPropertyType(clazz, fieldName, classMap, accessorMap))); + } + + private static CelType getPropertyType( + Class clazz, + String propertyName, + ImmutableMap> classMap, + ImmutableMap, ImmutableMap> accessorMap) { + ImmutableMap accessors = accessorMap.get(clazz); + if (accessors != null) { + PropertyAccessor accessor = accessors.get(propertyName); + if (accessor != null) { + return mapJavaTypeToCelType(accessor.targetType, accessor.genericTargetType, classMap); + } + } + throw new IllegalArgumentException("No public field or getter for " + propertyName); + } + + private static CelType mapJavaTypeToCelType( + Class type, Type genericType, ImmutableMap> classMap) { + + CelType celType = JAVA_TO_CEL_TYPE_MAP.get(type); + if (celType != null) { + return celType; + } + + if (type.isInterface() + && !List.class.isAssignableFrom(type) + && !Map.class.isAssignableFrom(type)) { + throw new IllegalArgumentException("Unsupported interface type: " + type.getName()); + } + + TypeToken token = TypeToken.of(genericType); + + if (List.class.isAssignableFrom(type)) { + Type elementType = ReflectionUtil.resolveGenericParameter(token, List.class, 0); + return ListType.create( + mapJavaTypeToCelType(ReflectionUtil.getRawType(elementType), elementType, classMap)); + } + + if (Map.class.isAssignableFrom(type)) { + Type keyType = ReflectionUtil.resolveGenericParameter(token, Map.class, 0); + Type valueType = ReflectionUtil.resolveGenericParameter(token, Map.class, 1); + + CelType celKeyType = + mapJavaTypeToCelType(ReflectionUtil.getRawType(keyType), keyType, classMap); + if (celKeyType == SimpleType.DOUBLE) { + throw new IllegalArgumentException("Decimals are not allowed as map keys in CEL."); + } + + return MapType.create( + celKeyType, + mapJavaTypeToCelType(ReflectionUtil.getRawType(valueType), valueType, classMap)); + } + + // Optional is a final class, so reference equality is equivalent to isAssignableFrom + // but slightly more performant than tree traversal. + if (type == Optional.class) { + Type optionalType = ReflectionUtil.resolveGenericParameter(token, Optional.class, 0); + return OptionalType.create( + mapJavaTypeToCelType(ReflectionUtil.getRawType(optionalType), optionalType, classMap)); + } + + String typeName = getCelTypeName(type); + if (classMap.containsKey(typeName)) { + return StructTypeReference.create(typeName); + } + + throw new IllegalArgumentException( + "Unsupported Java type for CEL mapping: " + type.getName()); + } + + private static ImmutableMap scanProperties( + Class clazz, Queue> queue) { + ImmutableMap.Builder builtAccessors = ImmutableMap.builder(); + + for (String propName : getProperties(clazz)) { + buildPropertyAccessor(clazz, propName, queue) + .ifPresent(accessor -> builtAccessors.put(propName, accessor)); + } + + return builtAccessors.buildOrThrow(); + } + + private static Optional buildPropertyAccessor( + Class clazz, String propName, Queue> queue) { + Method getter = findGetter(clazz, propName); + Field field = findField(clazz, propName); + + Class propType = null; + Type genericPropType = null; + Function compiledGetter = null; + BiConsumer compiledSetter = null; + + if (getter != null) { + propType = getter.getReturnType(); + genericPropType = getter.getGenericReturnType(); + discoverCustomTypes(genericPropType, queue); + compiledGetter = compileGetter(getter); + } else if (field != null) { + propType = field.getType(); + genericPropType = field.getGenericType(); + discoverCustomTypes(genericPropType, queue); + compiledGetter = compileFieldGetter(field); + } + + if (propType != null) { + Method setter = findSetter(clazz, propName, propType); + if (setter != null) { + compiledSetter = compileSetter(setter); + } else if (field != null + && !Modifier.isFinal(field.getModifiers()) + && Primitives.wrap(field.getType()) == Primitives.wrap(propType)) { + compiledSetter = compileFieldSetter(field); + } + } + + if (compiledGetter != null) { + return Optional.of( + new PropertyAccessor(compiledGetter, compiledSetter, propType, genericPropType)); + } + + return Optional.empty(); + } + + /** + * Recursively explores a {@link Type} and discovers any transitive, user-defined custom POJO + * classes nested inside multi-level generic collections, lists, maps, or optionals, pushing + * them into the scanning discovery queue. + * + *

"Custom types" are any public non-primitive, non-built-in Java classes that require + * explicit properties reflective scanning and mapping to a CEL StructType schema (as opposed to + * standard built-in types like {@code String}, {@code List}, or {@code Map}). + * + * @param type The Java type token or parameterized collection type to recursively unpack. + * @param queue The central scanning queue where newly discovered custom classes are pushed for + * subsequent properties discovery. + */ + private static void discoverCustomTypes(Type type, Queue> queue) { + Preconditions.checkNotNull(type, "Type to discover cannot be null."); + Preconditions.checkNotNull(queue, "Queue cannot be null."); + TypeToken token = TypeToken.of(type); + Class rawType = token.getRawType(); + + if (List.class.isAssignableFrom(rawType)) { + Type elementType = ReflectionUtil.resolveGenericParameter(token, List.class, 0); + discoverCustomTypes(elementType, queue); + return; + } + + if (Map.class.isAssignableFrom(rawType)) { + Type keyType = ReflectionUtil.resolveGenericParameter(token, Map.class, 0); + Type valueType = ReflectionUtil.resolveGenericParameter(token, Map.class, 1); + discoverCustomTypes(keyType, queue); + discoverCustomTypes(valueType, queue); + return; + } + + if (rawType == Optional.class) { + Type optionalType = ReflectionUtil.resolveGenericParameter(token, Optional.class, 0); + discoverCustomTypes(optionalType, queue); + return; + } + + if (!JAVA_TO_DEFAULT_VALUE_MAP.containsKey(rawType) + && Modifier.isPublic(rawType.getModifiers())) { + queue.add(rawType); + } + } + + private static Function compileGetter(Method getter) { + try { + // Required to unreflect public getters of package-private classes registered from other + // packages. + getter.setAccessible(true); + MethodHandle mh = LOOKUP.unreflect(getter); + return instance -> { + try { + return mh.invoke(instance); + } catch (Throwable t) { + throw new IllegalStateException("Failed to invoke getter for " + getter, t); + } + }; + } catch (IllegalAccessException e) { + throw new IllegalStateException("Failed to unreflect getter", e); + } + } + + private static Function compileFieldGetter(Field field) { + try { + // Required to unreflect public fields of package-private classes registered from other + // packages. + field.setAccessible(true); + MethodHandle mh = LOOKUP.unreflectGetter(field); + return instance -> { + try { + return mh.invoke(instance); + } catch (Throwable t) { + throw new IllegalStateException("Failed to get field " + field, t); + } + }; + } catch (IllegalAccessException e) { + throw new IllegalStateException("Failed to access field " + field, e); + } + } + + private static BiConsumer compileSetter(Method setter) { + try { + setter.setAccessible(true); + MethodHandle mh = LOOKUP.unreflect(setter); + return (instance, value) -> { + try { + mh.invoke(instance, value); + } catch (Throwable t) { + throw new IllegalStateException("Failed to invoke setter for " + setter, t); + } + }; + } catch (IllegalAccessException e) { + throw new IllegalStateException("Failed to unreflect setter", e); + } + } + + private static BiConsumer compileFieldSetter(Field field) { + try { + field.setAccessible(true); + MethodHandle mh = LOOKUP.unreflectSetter(field); + return (instance, value) -> { + try { + mh.invoke(instance, value); + } catch (Throwable t) { + throw new IllegalStateException("Failed to set field " + field, t); + } + }; + } catch (IllegalAccessException e) { + throw new IllegalStateException("Failed to access field " + field, e); + } + } + + private static @Nullable Method findGetter(Class clazz, String propertyName) { + String getterName = buildMethodName("get", propertyName); + String isGetterName = buildMethodName("is", propertyName); + + Method isGetter = null; + Method prefixLess = null; + + for (Method method : clazz.getMethods()) { + if (method.isBridge() || method.isSynthetic()) { + // Ignore compiler-generated duplicates + continue; + } + if (method.getParameterCount() == 0) { + String name = method.getName(); + if (name.equals(getterName)) { + return method; + } + if (name.equals(isGetterName)) { + isGetter = method; + } + if (name.equals(propertyName)) { + prefixLess = method; + } + } + } + + if (isGetter != null) { + return isGetter; + } + return prefixLess; + } + + private static @Nullable Field findField(Class clazz, String propertyName) { + for (Field field : clazz.getFields()) { + if (field.getName().equals(propertyName)) { + return field; + } + } + return null; + } + + private static @Nullable Method findSetter( + Class clazz, String propertyName, Class propertyType) { + String setterName = buildMethodName("set", propertyName); + return stream(clazz.getMethods()) + .filter(m -> !m.isBridge() && !m.isSynthetic()) + .filter(m -> m.getName().equals(setterName)) + .filter(m -> m.getParameterCount() == 1) + .filter(m -> m.getParameterTypes()[0].equals(propertyType)) + .findFirst() + .orElse(null); + } + + private static Set getAllDeclaredFieldNames(Class clazz) { + Set declaredFieldNames = new HashSet<>(); + Class currentClass = clazz; + while (currentClass != null) { + for (Field field : currentClass.getDeclaredFields()) { + declaredFieldNames.add(field.getName()); + } + currentClass = currentClass.getSuperclass(); + } + return declaredFieldNames; + } + + @VisibleForTesting + static ImmutableSet getProperties(Class clazz) { + ImmutableSet.Builder properties = ImmutableSet.builder(); + Set declaredFieldNames = getAllDeclaredFieldNames(clazz); + for (Field field : clazz.getFields()) { + if (Modifier.isStatic(field.getModifiers())) { + continue; + } + properties.add(field.getName()); + } + for (Method method : clazz.getMethods()) { + if (isGetter(method)) { + String propName = getPropertyName(method); + if (method.getName().startsWith("get") || method.getName().startsWith("is")) { + properties.add(propName); + } else if (declaredFieldNames.contains(propName)) { + properties.add(propName); + } + } + } + return properties.build(); + } + + private static boolean isGetter(Method method) { + if (Modifier.isStatic(method.getModifiers())) { + return false; + } + if (!Modifier.isPublic(method.getModifiers()) || method.getParameterCount() != 0) { + return false; + } + if (method.getReturnType() == void.class) { + return false; + } + String name = method.getName(); + if (OBJECT_METHOD_NAMES.contains(name)) { + return false; + } + if (name.startsWith("get")) { + return name.length() > 3; + } + if (name.startsWith("is")) { + return name.length() > 2 && Primitives.wrap(method.getReturnType()) == Boolean.class; + } + return true; + } + + private static String decapitalize(String name) { + Preconditions.checkArgument(name != null && !name.isEmpty()); + if (name.length() > 1 + && Character.isUpperCase(name.charAt(1)) + && Character.isUpperCase(name.charAt(0))) { + return name; + } + char[] chars = name.toCharArray(); + chars[0] = Character.toLowerCase(chars[0]); + return new String(chars); + } + + private static String getPropertyName(Method method) { + String name = method.getName(); + if (name.startsWith("get")) { + return decapitalize(name.substring(3)); + } + if (name.startsWith("is")) { + return decapitalize(name.substring(2)); + } + if (name.startsWith("set")) { + return decapitalize(name.substring(3)); + } + return name; + } + + private static String capitalize(String name) { + return Character.toUpperCase(name.charAt(0)) + name.substring(1); + } + + private static String buildMethodName(String prefix, String propertyName) { + return prefix + capitalize(propertyName); + } + } + + /** + * NativeTypeRegistry holds the state produced by NativeTypeScanner and acts as a CelValueProvider + * and CelTypeProvider for the CEL runtime. + */ + @VisibleForTesting + @Immutable + static final class NativeTypeRegistry implements CelValueProvider, CelTypeProvider { + + private final ImmutableMap> classMap; + private final ImmutableMap typeMap; + private final ImmutableMap, StructType> classToTypeMap; + private final ImmutableMap, ImmutableMap> accessorMap; + private final NativeValueConverter converter; + + private NativeTypeRegistry(NativeTypeScanner.ScanResult scanResult) { + this.classMap = scanResult.classMap; + this.typeMap = scanResult.typeMap; + this.classToTypeMap = scanResult.classToTypeMap; + this.accessorMap = scanResult.accessorMap; + this.converter = new NativeValueConverter(this); + } + + @Override + public ImmutableList types() { + return ImmutableList.copyOf(typeMap.values()); + } + + @Override + public Optional findType(String typeName) { + return Optional.ofNullable(typeMap.get(typeName)); + } + + @Override + public Optional newValue(String typeName, Map fields) { + Class clazz = classMap.get(typeName); + if (clazz == null) { + return Optional.empty(); + } + + try { + Constructor constructor = clazz.getDeclaredConstructor(); + constructor.setAccessible(true); + Object instance = constructor.newInstance(); + ImmutableMap accessors = accessorMap.get(clazz); + + for (Map.Entry entry : fields.entrySet()) { + PropertyAccessor accessor = accessors.get(entry.getKey()); + if (accessor == null) { + throw new IllegalArgumentException( + "Unknown field: " + entry.getKey() + " for type " + typeName); + } + Object value = + converter.toNative(entry.getValue(), accessor.targetType, accessor.genericTargetType); + accessor.setValue(instance, value); + } + + StructType structType = typeMap.get(typeName); + return Optional.of(new PojoStructValue(instance, accessors, structType)); + } catch (NoSuchMethodException e) { + throw new IllegalStateException( + "Failed to create instance of " + + typeName + + ": No public no-argument constructor found.", + e); + } catch (Exception e) { + throw new IllegalStateException("Failed to create instance of " + typeName, e); + } + } + + @Override + public CelValueConverter celValueConverter() { + return this.converter; + } + } + + /** + * PropertyAccessor holds the compiled getter and setter for a property, along with its type + * information. + */ + @Immutable + @SuppressWarnings("Immutable") + private static final class PropertyAccessor { + private final Function getter; + private final @Nullable BiConsumer setter; + private final Class targetType; + private final @Nullable Type genericTargetType; + + private PropertyAccessor( + Function getter, + @Nullable BiConsumer setter, + Class targetType, + @Nullable Type genericTargetType) { + this.getter = getter; + this.setter = setter; + this.targetType = targetType; + this.genericTargetType = genericTargetType; + } + + Object getValue(Object instance) { + return getter.apply(instance); + } + + Object getDefaultValue() { + return getDefaultValue(targetType); + } + + private static Object getDefaultValue(Class targetType) { + Object defaultValue = JAVA_TO_DEFAULT_VALUE_MAP.get(targetType); + if (defaultValue != null) { + return defaultValue; + } + if (List.class.isAssignableFrom(targetType)) { + return ImmutableList.of(); + } + if (Map.class.isAssignableFrom(targetType)) { + return ImmutableMap.of(); + } + + try { + Constructor constructor = targetType.getDeclaredConstructor(); + constructor.setAccessible(true); + return constructor.newInstance(); + } catch (Exception e) { + throw new IllegalStateException( + String.format( + "Failed to instantiate default instance for uninitialized field of type [%s]. " + + "Please ensure the class has a no-argument constructor or is initialized.", + targetType.getName()), + e); + } + } + + void setValue(Object instance, Object value) { + if (setter != null) { + setter.accept(instance, value); + } else { + throw new IllegalStateException("No setter found for property"); + } + } + } + + /** NativeValueConverter handles conversion between Java objects and CEL values. */ + @Immutable + private static final class NativeValueConverter extends CelValueConverter { + + private final NativeTypeRegistry registry; + + private NativeValueConverter(NativeTypeRegistry registry) { + this.registry = registry; + } + + @Override + public Object toRuntimeValue(Object value) { + if (value instanceof CelValue) { + return super.toRuntimeValue(value); + } + + Class clazz = value.getClass(); + ImmutableMap accessors = registry.accessorMap.get(clazz); + + if (accessors != null) { + return new PojoStructValue(value, accessors, registry.classToTypeMap.get(clazz)); + } + + return super.toRuntimeValue(value); + } + + Object toNative(Object value, Class targetType, Type genericType) { + if (value instanceof CelValue && !StructValue.class.isAssignableFrom(targetType)) { + value = super.maybeUnwrap(value); + } + if (targetType == Optional.class) { + if (value instanceof Optional) { + return value; + } + return Optional.ofNullable(value); + } + if (targetType == UnsignedLong.class) { + if (value instanceof UnsignedLong) { + return value; + } + } + if (targetType == byte[].class && value instanceof CelByteString) { + return ((CelByteString) value).toByteArray(); + } + + if (List.class.isAssignableFrom(targetType) && value instanceof List) { + return convertListToNative((List) value, targetType, genericType); + } + + if (Map.class.isAssignableFrom(targetType) && value instanceof Map) { + return convertMapToNative((Map) value, targetType, genericType); + } + + return downcastPrimitives(value, targetType); + } + + // Safe reflection collection cast. + @SuppressWarnings("unchecked") + private Object convertListToNative(List list, Class targetType, Type genericType) { + TypeToken token = TypeToken.of(genericType); + Type elementType = ReflectionUtil.resolveGenericParameter(token, List.class, 0); + Class componentType = ReflectionUtil.getRawType(elementType); + + boolean isConcreteClass = + !targetType.isInterface() && !Modifier.isAbstract(targetType.getModifiers()); + + // Instantiates concrete collection types to prevent ClassCastExceptions. + // For example, if a POJO field is declared as a concrete implementation like + // ArrayList, assigning a Guava ImmutableList will fail at runtime due to type + // mismatch. + if (isConcreteClass) { + List concreteList; + try { + concreteList = (List) targetType.getConstructor().newInstance(); + } catch (Exception e) { + throw new IllegalStateException( + "Failed to instantiate concrete collection class for field target type: " + + targetType.getName(), + e); + } + + for (Object element : list) { + concreteList.add(toNative(element, componentType, elementType)); + } + return concreteList; + } + + ImmutableList.Builder builder = null; + for (int i = 0; i < list.size(); i++) { + Object element = list.get(i); + Object converted = toNative(element, componentType, elementType); + if (!Objects.equals(converted, element) && builder == null) { + builder = ImmutableList.builderWithExpectedSize(list.size()); + for (int j = 0; j < i; j++) { + builder.add(list.get(j)); + } + } + if (builder != null) { + builder.add(converted); + } + } + + if (builder == null) { + return list; + } + return builder.build(); + } + + // Safe reflection collection cast. + @SuppressWarnings("unchecked") + private Object convertMapToNative(Map map, Class targetType, Type genericType) { + TypeToken token = TypeToken.of(genericType); + Type keyType = ReflectionUtil.resolveGenericParameter(token, Map.class, 0); + Type valueType = ReflectionUtil.resolveGenericParameter(token, Map.class, 1); + Class rawKeyType = ReflectionUtil.getRawType(keyType); + Class rawValueType = ReflectionUtil.getRawType(valueType); + + boolean isConcreteClass = + !targetType.isInterface() && !Modifier.isAbstract(targetType.getModifiers()); + + // Instantiates concrete map types to prevent ClassCastExceptions. + // For example, if a POJO field is declared as a concrete implementation like HashMap, + // assigning a Guava ImmutableMap will fail at runtime due to type mismatch. + if (isConcreteClass) { + Map concreteMap; + try { + concreteMap = (Map) targetType.getConstructor().newInstance(); + } catch (Exception e) { + throw new IllegalStateException( + "Failed to instantiate concrete map class for field target type: " + + targetType.getName(), + e); + } + + for (Map.Entry entry : map.entrySet()) { + concreteMap.put( + toNative(entry.getKey(), rawKeyType, keyType), + toNative(entry.getValue(), rawValueType, valueType)); + } + return concreteMap; + } + + ImmutableMap.Builder builder = null; + for (Map.Entry entry : map.entrySet()) { + Object key = entry.getKey(); + Object val = entry.getValue(); + Object convertedKey = toNative(key, rawKeyType, keyType); + Object convertedVal = toNative(val, rawValueType, valueType); + + if ((!Objects.equals(convertedKey, key) || !Objects.equals(convertedVal, val)) + && builder == null) { + builder = ImmutableMap.builderWithExpectedSize(map.size()); + for (Map.Entry prevEntry : map.entrySet()) { + if (Objects.equals(prevEntry.getKey(), entry.getKey())) { + break; + } + builder.put(prevEntry.getKey(), prevEntry.getValue()); + } + } + + if (builder != null) { + builder.put(convertedKey, convertedVal); + } + } + + if (builder == null) { + return map; + } + return builder.buildOrThrow(); + } + + private Object downcastPrimitives(Object value, Class targetType) { + Class wrappedTargetType = Primitives.wrap(targetType); + if (wrappedTargetType == Integer.class && value instanceof Long) { + return ((Long) value).intValue(); + } + if (wrappedTargetType == Float.class && value instanceof Double) { + return ((Double) value).floatValue(); + } + + return value; + } + } + + /** PojoStructValue represents a native Java object as a CEL struct value. */ + @SuppressWarnings("Immutable") + private static final class PojoStructValue extends StructValue { + private final Object instance; + private final ImmutableMap accessors; + private final StructType celType; + + private PojoStructValue( + Object instance, ImmutableMap accessors, StructType celType) { + this.instance = instance; + this.accessors = accessors; + this.celType = celType; + } + + @Override + public Object value() { + return instance; + } + + @Override + public boolean isZeroValue() { + throw new UnsupportedOperationException( + "isZeroValue is unsupported for ordinary Java POJOs. Please implement StructValue" + + " directly on the backing class if zero-value trait support is required."); + } + + @Override + public CelType celType() { + return celType; + } + + @Override + public Object select(String field) { + // Intentionally not proxying `find` here to avoid Optional wrapper allocations. + PropertyAccessor accessor = accessors.get(field); + if (accessor != null) { + Object value = accessor.getValue(instance); + if (value == null) { + return accessor.getDefaultValue(); + } + return value; + } + throw CelAttributeNotFoundException.forFieldResolution(field); + } + + @Override + public Optional find(String field) { + PropertyAccessor accessor = accessors.get(field); + if (accessor == null) { + return Optional.empty(); + } + Object value = accessor.getValue(instance); + return Optional.ofNullable(value); + } + } + + private static String getCelTypeName(Class clazz) { + String canonicalName = clazz.getCanonicalName(); + if (canonicalName == null) { + throw new IllegalArgumentException( + "Cannot get canonical name for class: " + + clazz.getName() + + ". Anonymous or local classes are not supported."); + } + return canonicalName; + } + + private CelNativeTypesExtensions(NativeTypeRegistry registry) { + this.registry = registry; + } +} diff --git a/extensions/src/main/java/dev/cel/extensions/CelOptionalLibrary.java b/extensions/src/main/java/dev/cel/extensions/CelOptionalLibrary.java index a3777c759..87a31341f 100644 --- a/extensions/src/main/java/dev/cel/extensions/CelOptionalLibrary.java +++ b/extensions/src/main/java/dev/cel/extensions/CelOptionalLibrary.java @@ -53,6 +53,7 @@ import dev.cel.common.types.TypeParamType; import dev.cel.common.types.TypeType; import dev.cel.common.values.CelByteString; +import dev.cel.common.values.CelValue; import dev.cel.common.values.NullValue; import dev.cel.compiler.CelCompilerLibrary; import dev.cel.parser.CelMacro; @@ -415,9 +416,6 @@ private static ImmutableList elideOptionalCollection(Collection variables) throws Exception { + CelAbstractSyntaxTree ast = isParseOnly ? CEL.parse(expr).getAst() : CEL.compile(expr).getAst(); + return CEL.createProgram(ast).eval(variables); + } + + @Test + public void nativeTypes_createStructAndSelect() throws Exception { + Object result = + eval( + "TestAllTypesPublicFieldsPojo{boolVal:" + + " true, stringVal: 'hello'}.stringVal == 'hello'"); + + assertThat(result).isEqualTo(true); + } + + @Test + public void nativeTypes_createNestedStruct() throws Exception { + Object result = + eval( + "TestAllTypesPublicFieldsPojo{nestedVal:" + + " TestNestedType{value:" + + " 'nested'}}.nestedVal.value == 'nested'"); + + assertThat(result).isEqualTo(true); + } + + @Test + public void nativeTypes_resolveVariableWithNestedField() throws Exception { + Cel cel = + CelFactory.plannerCelBuilder() + .setContainer(CelContainer.ofName("dev.cel.extensions.CelNativeTypesExtensionsTest")) + .addVar( + "pojo", + StructTypeReference.create(TestAllTypesPublicFieldsPojo.class.getCanonicalName())) + .addCompilerLibraries(NATIVE_TYPE_EXTENSIONS) + .addRuntimeLibraries(NATIVE_TYPE_EXTENSIONS) + .build(); + CelAbstractSyntaxTree ast = + isParseOnly + ? cel.parse("pojo.nestedVal.value == 'nested'").getAst() + : cel.compile("pojo.nestedVal.value == 'nested'").getAst(); + CelRuntime.Program program = cel.createProgram(ast); + TestAllTypesPublicFieldsPojo pojo = new TestAllTypesPublicFieldsPojo(); + TestNestedType nested = new TestNestedType(); + nested.value = "nested"; + pojo.nestedVal = nested; + + Object result = program.eval(ImmutableMap.of("pojo", pojo)); + + assertThat(result).isEqualTo(true); + } + + @Test + public void nativeTypes_createStructWithComplexTypes() throws Exception { + assertThat( + eval( + "TestAllTypesPublicFieldsPojo{" + + " durationVal: duration('5s')," + + " listVal: ['a', 'b']," + + " mapVal: {'key': 'value'}" + + "}.durationVal == duration('5s')")) + .isEqualTo(true); + } + + @Test + public void nativeTypes_transitiveDiscoveryThroughMap() throws Exception { + PojoWithCustomMap pojo = new PojoWithCustomMap(); + HashMap map = new HashMap<>(); + TestNestedType nested = new TestNestedType(); + nested.value = "hello"; + map.put("key", nested); + pojo.mapVal = map; + + CelNativeTypesExtensions extensions = + CelNativeTypesExtensions.nativeTypes(PojoWithCustomMap.class); + Cel cel = + CelFactory.plannerCelBuilder() + .setContainer(CelContainer.ofName("dev.cel.extensions.CelNativeTypesExtensionsTest")) + .addVar("pojo", StructTypeReference.create(PojoWithCustomMap.class.getCanonicalName())) + .addCompilerLibraries(extensions) + .addRuntimeLibraries(extensions) + .build(); + + CelAbstractSyntaxTree ast = cel.compile("pojo.mapVal['key'].value == 'hello'").getAst(); + Object result = cel.createProgram(ast).eval(ImmutableMap.of("pojo", pojo)); + + assertThat(result).isEqualTo(true); + } + + @Test + public void nativeTypes_createStructWithOptionalField() throws Exception { + Cel cel = + CelFactory.plannerCelBuilder() + .setContainer(CelContainer.ofName("dev.cel.extensions.CelNativeTypesExtensionsTest")) + .addCompilerLibraries( + CelExtensions.nativeTypes(TestRefValFieldType.class), CelExtensions.optional()) + .addRuntimeLibraries( + CelExtensions.nativeTypes(TestRefValFieldType.class), CelExtensions.optional()) + .build(); + CelAbstractSyntaxTree ast = + cel.parse( + "TestRefValFieldType{optionalName: optional.of('my name')}.optionalName.orValue('')" + + " == 'my name'") + .getAst(); + CelRuntime.Program program = cel.createProgram(ast); + + Object result = program.eval(); + + assertThat(result).isEqualTo(true); + } + + @Test + public void nativeTypes_createComprehensiveStruct() throws Exception { + String expr = + "ComprehensiveTestAllTypes{\n" + + " nestedVal: ComprehensiveTestNestedType{nestedMapVal: {1: false}},\n" + + " boolVal: true,\n" + + " bytesVal: b'hello',\n" + + " durationVal: duration('5s'),\n" + + " doubleVal: 1.5,\n" + + " floatVal: 2.5,\n" + + " int32Val: 10,\n" + + " int64Val: 20,\n" + + " stringVal: 'hello world',\n" + + " timestampVal: timestamp('2011-08-06T01:23:45Z'),\n" + + " uint32Val: 100,\n" + + " uint64Val: 200,\n" + + " listVal: [\n" + + " ComprehensiveTestNestedType{\n" + + " nestedListVal:['goodbye', 'cruel', 'world'],\n" + + " nestedMapVal: {42: true},\n" + + " customName: 'name'\n" + + " }\n" + + " ],\n" + + " arrayVal: [\n" + + " ComprehensiveTestNestedType{\n" + + " nestedListVal:['goodbye', 'cruel', 'world'],\n" + + " nestedMapVal: {42: true},\n" + + " customName: 'name'\n" + + " }\n" + + " ],\n" + + " mapVal: {'map-key': ComprehensiveTestAllTypes{boolVal: true}},\n" + + " customSliceVal: [TestNestedSliceType{value: 'none'}],\n" + + " customMapVal: {'even': TestMapVal{value: 'more'}},\n" + + " customName: 'name'\n" + + "}"; + + CelAbstractSyntaxTree ast = CEL.parse(expr).getAst(); + CelRuntime.Program program = CEL.createProgram(ast); + Object result = program.eval(); + + // Construct expected output + ComprehensiveTestAllTypes expected = new ComprehensiveTestAllTypes(); + expected.boolVal = true; + expected.bytesVal = "hello".getBytes(UTF_8); + expected.durationVal = Duration.ofSeconds(5); + expected.doubleVal = 1.5; + expected.floatVal = 2.5f; + expected.int32Val = 10; + expected.int64Val = 20; + expected.stringVal = "hello world"; + expected.timestampVal = Instant.parse("2011-08-06T01:23:45Z"); + expected.uint32Val = 100; + expected.uint64Val = 200; + expected.customName = "name"; + + ComprehensiveTestNestedType nested1 = new ComprehensiveTestNestedType(); + nested1.nestedMapVal = ImmutableMap.of(1L, false); + expected.nestedVal = nested1; + + ComprehensiveTestNestedType nested2 = new ComprehensiveTestNestedType(); + nested2.nestedListVal = ImmutableList.of("goodbye", "cruel", "world"); + nested2.nestedMapVal = ImmutableMap.of(42L, true); + nested2.customName = "name"; + expected.listVal = ImmutableList.of(nested2); + expected.arrayVal = ImmutableList.of(nested2); + + ComprehensiveTestAllTypes mapValElement = new ComprehensiveTestAllTypes(); + mapValElement.boolVal = true; + expected.mapVal = ImmutableMap.of("map-key", mapValElement); + + TestNestedSliceType sliceElem = new TestNestedSliceType(); + sliceElem.value = "none"; + expected.customSliceVal = ImmutableList.of(sliceElem); + + TestMapVal mapValElem = new TestMapVal(); + mapValElem.value = "more"; + expected.customMapVal = ImmutableMap.of("even", mapValElem); + + assertThat(result).isEqualTo(expected); + } + + @Test + public void nativeTypes_staticErrors() throws Exception { + // undeclared reference + CelValidationException e = + assertThrows(CelValidationException.class, () -> CEL.compile("UnknownType{}").getAst()); + assertThat(e).hasMessageThat().contains("reference"); + + // undefined field + e = + assertThrows( + CelValidationException.class, + () -> CEL.compile("ComprehensiveTestAllTypes{undefinedField: true}").getAst()); + assertThat(e).hasMessageThat().contains("undefined field"); + } + + @Test + public void nativeTypes_anonymousClass_throwsException() { + Object anon = new Object() {}; + + Class clazz = anon.getClass(); + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> CelExtensions.nativeTypes(clazz)); + assertThat(exception).hasMessageThat().contains("Anonymous or local classes are not supported"); + } + + @Test + public void nativeTypes_createStruct_privateConstructor() throws Exception { + Object result = eval("TestPrivateConstructorPojo{value:" + " 'hello'}"); + + assertThat(result).isInstanceOf(TestPrivateConstructorPojo.class); + assertThat(((TestPrivateConstructorPojo) result).value).isEqualTo("hello"); + } + + @Test + public void nativeTypes_precedence_getterOverField() throws Exception { + assertThat(eval("TestPrecedencePojo{}.value")).isEqualTo("hello"); + } + + @Test + public void nativeTypes_protoPrecedence() throws Exception { + CelValueProvider customProvider = + (structType, fields) -> { + if (structType.equals("cel.expr.conformance.proto3.TestAllTypes")) { + return Optional.of("POJO_WINS"); + } + return Optional.empty(); + }; + Cel cel = + CelFactory.plannerCelBuilder() + .setContainer(CelContainer.ofName("dev.cel.extensions.CelNativeTypesExtensionsTest")) + .setValueProvider(customProvider) + .addMessageTypes(TestAllTypes.getDescriptor()) + .build(); + CelAbstractSyntaxTree ast = cel.compile("cel.expr.conformance.proto3.TestAllTypes{}").getAst(); + + Object result = cel.createProgram(ast).eval(); + + assertThat(result).isNotEqualTo("POJO_WINS"); + assertThat(result).isInstanceOf(TestAllTypes.class); + } + + @Test + public void nativeTypes_createWithSetterAndSelectWithGetter() throws Exception { + assertThat(eval("TestGetterSetterPojo{value: 'hello', active: true}.value == 'hello'")) + .isEqualTo(true); + } + + @Test + public void nativeTypes_missingNoArgConstructor_throws() throws Exception { + CelEvaluationException exception = + assertThrows( + CelEvaluationException.class, + () -> eval("TestMissingNoArgConstructorPojo{value: 'hello'}")); + + assertThat(exception).hasMessageThat().contains("No public no-argument constructor found"); + } + + @Test + public void nativeTypes_createWithDeepConversion() throws Exception { + Object result = eval("TestDeepConversionPojo{ints: [1, 2], floats: {'a': 1.0, 'b': 2.0}}"); + + assertThat(result).isInstanceOf(TestDeepConversionPojo.class); + TestDeepConversionPojo pojo = (TestDeepConversionPojo) result; + assertThat(pojo.ints.get(0)).isEqualTo(1); + assertThat(pojo.floats).containsEntry("a", 1.0f); + } + + @Test + public void nativeTypes_wildcardList_success() throws Exception { + assertThat(eval("TestWildcardPojo{values: ['hello']}.values[0] == 'hello'")).isEqualTo(true); + } + + @Test + public void nativeTypes_unsupportedTypeSet_throwsOnRegistration() throws Exception { + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, + () -> CelExtensions.nativeTypes(TestUnsupportedSetPojo.class)); + assertThat(e).hasMessageThat().contains("Unsupported type for property 'strings'"); + } + + @Test + public void nativeTypes_arrayType_throwsOnRegistration() throws Exception { + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, () -> CelExtensions.nativeTypes(TestArrayPojo.class)); + assertThat(e).hasMessageThat().contains("Unsupported type for property 'values'"); + } + + @Test + public void nativeTypes_packagePrivateClass_fieldAccess_success() throws Exception { + assertThat(eval("TestPackagePrivatePojo{value: 'hello'}.value == 'hello'")).isEqualTo(true); + } + + @Test + public void nativeTypes_packagePrivateClass_methodAccess_success() throws Exception { + assertThat(eval("TestPackagePrivateWithGetterPojo{value: 'hello'}.value == 'hello'")) + .isEqualTo(true); + } + + @Test + public void nativeTypes_privateField_notExposed() throws Exception { + CelNativeTypesExtensions extensions = CelExtensions.nativeTypes(TestPrivateFieldPojo.class); + CelCompiler compiler = + CelCompilerFactory.standardCelCompilerBuilder() + .setContainer(CelContainer.ofName("dev.cel.extensions.CelNativeTypesExtensionsTest")) + .addLibraries(extensions) + .build(); + + CelValidationException e = + assertThrows( + CelValidationException.class, + () -> compiler.compile("TestPrivateFieldPojo{secret: 'hello'}").getAst()); + assertThat(e).hasMessageThat().contains("undefined field"); + } + + @Test + public void nativeTypes_inheritance_success() throws Exception { + // Accessing child's prefix-less getter + assertThat(eval("TestChildPojo{}.childValue")).isEqualTo("child"); + // Accessing parent's standard getter + assertThat(eval("TestChildPojo{}.standardValue")).isEqualTo("standard"); + // Accessing parent's prefix-less getter + assertThat(eval("TestChildPojo{}.parentValue")).isEqualTo("parent"); + } + + @Test + public void nativeTypes_standardType_cannotBeConstructedAsStruct() throws Exception { + CelValidationException e = + assertThrows( + CelValidationException.class, () -> CEL.compile("java.lang.String{}").getAst()); + assertThat(e).hasMessageThat().contains("undeclared reference"); + } + + @Test + public void nativeTypes_doubleMapKey_throwsOnRegistration() throws Exception { + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, + () -> CelExtensions.nativeTypes(TestDoubleMapKeyPojo.class)); + assertThat(e).hasCauseThat().hasMessageThat().contains("Decimals are not allowed as map keys"); + } + + @Test + public void nativeTypes_optionalCustomStruct_registered() throws Exception { + CelNativeTypesExtensions extensions = CelExtensions.nativeTypes(TestOptionalUrlPojo.class); + CelNativeTypesExtensions.NativeTypeRegistry registry = extensions.getRegistry(); + + Optional type = registry.findType(TestURLPojo.class.getCanonicalName()); + + assertThat(type).isPresent(); + } + + @Test + public void nativeTypes_abstractClass_throwsOnConstruction() throws Exception { + CelAbstractSyntaxTree ast = CEL.parse("TestAbstractPojo{}").getAst(); + CelRuntime.Program program = CEL.createProgram(ast); + + CelEvaluationException e = assertThrows(CelEvaluationException.class, () -> program.eval()); + assertThat(e).hasMessageThat().contains("Failed to create instance of"); + assertThat(e).hasCauseThat().isInstanceOf(InstantiationException.class); + } + + @Test + public void nativeTypes_nestedList_registered() throws Exception { + CelNativeTypesExtensions extensions = + CelExtensions.nativeTypes(TestAllTypesPublicFieldsPojo.class); + CelNativeTypesExtensions.NativeTypeRegistry registry = extensions.getRegistry(); + + Optional type = + registry.findType(TestAllTypesPublicFieldsPojo.class.getCanonicalName()); + + assertThat(type).isPresent(); + StructType structType = (StructType) type.get(); + assertThat(structType.findField("nestedListVal")).isPresent(); + } + + @Test + public void nativeTypes_invalidGetters_notRegistered() throws Exception { + ImmutableSet properties = + CelNativeTypesExtensions.NativeTypeScanner.getProperties( + TestAllTypesPublicFieldsPojo.class); + + assertThat(properties).doesNotContain("invalidParam"); + assertThat(properties).doesNotContain("invalidString"); + } + + @Test + public void nativeTypes_celByteString_success() throws Exception { + assertThat(eval("TestAllTypesPublicFieldsPojo{}.celBytesVal" + " == b'\\x01\\x02\\x03'")) + .isEqualTo(true); + } + + @Test + public void nativeTypes_celByteString_construction_success() throws Exception { + assertThat( + eval( + "dev.cel.extensions.CelNativeTypesExtensionsTest.TestAllTypesPublicFieldsPojo{celBytesVal:" + + " b'\\x01\\x02\\x03'}.celBytesVal == b'\\x01\\x02\\x03'")) + .isEqualTo(true); + } + + @Test + public void nativeTypes_singleLetterGetter_success() throws Exception { + Object result = eval("TestAllTypesPublicFieldsPojo{}.a == 'a'"); + assertThat(result).isEqualTo(true); + } + + @Test + public void nativeTypes_getterNamedGet_rejected() throws Exception { + CelValidationException e = + assertThrows( + CelValidationException.class, + () -> CEL.compile("TestAllTypesPublicFieldsPojo{}.get").getAst()); + assertThat(e).hasMessageThat().contains("undefined field 'get'"); + } + + @Test + public void nativeTypes_circularReference_success() throws Exception { + CelNativeTypesExtensions extensions = CelExtensions.nativeTypes(TestCircularA.class); + CelNativeTypesExtensions.NativeTypeRegistry registry = extensions.getRegistry(); + + Optional typeA = registry.findType(TestCircularA.class.getCanonicalName()); + Optional typeB = registry.findType(TestCircularB.class.getCanonicalName()); + + assertThat(typeA).isPresent(); + assertThat(typeB).isPresent(); + } + + @Test + public void nativeTypes_specialDecapitalization_success() throws Exception { + Object result = eval("dev.cel.extensions.CelNativeTypesExtensionsTest.TestURLPojo{}.URL"); + + assertThat(result).isEqualTo("https://google.com"); + } + + @Test + public void nativeTypes_prefixLessGetter_success() throws Exception { + CelNativeTypesExtensions extensions = CelExtensions.nativeTypes(TestPrefixLessGetterPojo.class); + CelRuntime celRuntime = + CelRuntimeFactory.plannerRuntimeBuilder() + .setContainer(CelContainer.ofName("dev.cel.extensions.CelNativeTypesExtensionsTest")) + .addLibraries(extensions) + .build(); + CelCompiler celCompiler = + CelCompilerFactory.standardCelCompilerBuilder() + .setContainer(CelContainer.ofName("dev.cel.extensions.CelNativeTypesExtensionsTest")) + .addLibraries(extensions) + .build(); + CelAbstractSyntaxTree ast = + celCompiler + .compile( + "dev.cel.extensions.CelNativeTypesExtensionsTest.TestPrefixLessGetterPojo{}.value") + .getAst(); + CelRuntime.Program program = celRuntime.createProgram(ast); + + Object result = program.eval(); + + assertThat(result).isEqualTo("hello"); + } + + @Test + public void nativeTypes_isGetter_success() throws Exception { + CelNativeTypesExtensions extensions = CelExtensions.nativeTypes(TestGetterSetterPojo.class); + CelRuntime celRuntime = + CelRuntimeFactory.plannerRuntimeBuilder() + .setContainer(CelContainer.ofName("dev.cel.extensions.CelNativeTypesExtensionsTest")) + .addLibraries(extensions) + .build(); + CelCompiler celCompiler = + CelCompilerFactory.standardCelCompilerBuilder() + .setContainer(CelContainer.ofName("dev.cel.extensions.CelNativeTypesExtensionsTest")) + .addLibraries(extensions) + .build(); + CelAbstractSyntaxTree ast = + celCompiler + .compile( + "dev.cel.extensions.CelNativeTypesExtensionsTest.TestGetterSetterPojo{active:" + + " true}.active") + .getAst(); + CelRuntime.Program program = celRuntime.createProgram(ast); + + Object result = program.eval(); + + assertThat(result).isEqualTo(true); + } + + @Test + public void nativeTypes_selectUndefinedField_parsedOnly_throwsException() throws Exception { + + CelNativeTypesExtensions extensions = + CelExtensions.nativeTypes(TestAllTypesPublicFieldsPojo.class); + + CelRuntime celRuntime = + CelRuntimeFactory.plannerRuntimeBuilder() + .setContainer(CelContainer.ofName("dev.cel.extensions.CelNativeTypesExtensionsTest")) + .addLibraries(extensions) + .build(); + + CelCompiler celCompiler = + CelCompilerFactory.standardCelCompilerBuilder() + .setContainer(CelContainer.ofName("dev.cel.extensions.CelNativeTypesExtensionsTest")) + .addLibraries(extensions) + .build(); + + CelAbstractSyntaxTree ast = celCompiler.parse("pojo.undefinedField").getAst(); + CelRuntime.Program program = celRuntime.createProgram(ast); + + TestAllTypesPublicFieldsPojo pojo = new TestAllTypesPublicFieldsPojo(); + + CelEvaluationException e = + assertThrows( + CelEvaluationException.class, () -> program.eval(ImmutableMap.of("pojo", pojo))); + assertThat(e).hasCauseThat().isInstanceOf(CelAttributeNotFoundException.class); + } + + @Test + public void nativeTypes_createWithUint_fromUnsignedLong() throws Exception { + CelNativeTypesExtensions extensions = + CelExtensions.nativeTypes(TestAllTypesPublicFieldsPojo.class); + CelRuntime celRuntime = + CelRuntimeFactory.plannerRuntimeBuilder() + .setContainer(CelContainer.ofName("dev.cel.extensions.CelNativeTypesExtensionsTest")) + .addLibraries(extensions) + .build(); + CelCompiler celCompiler = + CelCompilerFactory.standardCelCompilerBuilder() + .setContainer(CelContainer.ofName("dev.cel.extensions.CelNativeTypesExtensionsTest")) + .addLibraries(extensions) + .build(); + CelAbstractSyntaxTree ast = + celCompiler + .compile( + "dev.cel.extensions.CelNativeTypesExtensionsTest.TestAllTypesPublicFieldsPojo{uintVal:" + + " 42u}") + .getAst(); + CelRuntime.Program program = celRuntime.createProgram(ast); + + Object result = program.eval(); + + assertThat(result).isInstanceOf(TestAllTypesPublicFieldsPojo.class); + TestAllTypesPublicFieldsPojo pojo = (TestAllTypesPublicFieldsPojo) result; + assertThat(pojo.uintVal).isEqualTo(UnsignedLong.fromLongBits(42L)); + } + + @Test + public void nativeTypes_mapJavaTypeToCelType_allSupportedTypes() throws Exception { + CelNativeTypesExtensions extensions = + CelExtensions.nativeTypes(TestAllTypesPublicFieldsPojo.class); + CelNativeTypesExtensions.NativeTypeRegistry registry = extensions.getRegistry(); + + Optional type = + registry.findType(TestAllTypesPublicFieldsPojo.class.getCanonicalName()); + + assertThat(type).isPresent(); + assertThat(type.get()).isInstanceOf(StructType.class); + StructType structType = (StructType) type.get(); + + assertThat(structType.findField("boolVal").map(StructType.Field::type)) + .hasValue(SimpleType.BOOL); + assertThat(structType.findField("boolObjVal").map(StructType.Field::type)) + .hasValue(SimpleType.BOOL); + assertThat(structType.findField("int32Val").map(StructType.Field::type)) + .hasValue(SimpleType.INT); + assertThat(structType.findField("intObjVal").map(StructType.Field::type)) + .hasValue(SimpleType.INT); + assertThat(structType.findField("int64Val").map(StructType.Field::type)) + .hasValue(SimpleType.INT); + assertThat(structType.findField("longObjVal").map(StructType.Field::type)) + .hasValue(SimpleType.INT); + assertThat(structType.findField("uintVal").map(StructType.Field::type)) + .hasValue(SimpleType.UINT); + assertThat(structType.findField("floatVal").map(StructType.Field::type)) + .hasValue(SimpleType.DOUBLE); + assertThat(structType.findField("floatObjVal").map(StructType.Field::type)) + .hasValue(SimpleType.DOUBLE); + assertThat(structType.findField("doubleVal").map(StructType.Field::type)) + .hasValue(SimpleType.DOUBLE); + assertThat(structType.findField("doubleObjVal").map(StructType.Field::type)) + .hasValue(SimpleType.DOUBLE); + assertThat(structType.findField("stringVal").map(StructType.Field::type)) + .hasValue(SimpleType.STRING); + assertThat(structType.findField("bytesVal").map(StructType.Field::type)) + .hasValue(SimpleType.BYTES); + assertThat(structType.findField("durationVal").map(StructType.Field::type)) + .hasValue(SimpleType.DURATION); + assertThat(structType.findField("timestampVal").map(StructType.Field::type)) + .hasValue(SimpleType.TIMESTAMP); + + assertThat(structType.findField("listVal").map(StructType.Field::type).get()) + .isInstanceOf(ListType.class); + ListType listType = + (ListType) structType.findField("listVal").map(StructType.Field::type).get(); + assertThat(listType.elemType()).isEqualTo(SimpleType.STRING); + + assertThat(structType.findField("mapIntVal").map(StructType.Field::type).get()) + .isInstanceOf(MapType.class); + MapType mapType = (MapType) structType.findField("mapIntVal").map(StructType.Field::type).get(); + assertThat(mapType.keyType()).isEqualTo(SimpleType.STRING); + assertThat(mapType.valueType()).isEqualTo(SimpleType.INT); + + assertThat(structType.findField("optionalVal").map(StructType.Field::type).get()) + .isInstanceOf(OptionalType.class); + OptionalType optionalType = + (OptionalType) structType.findField("optionalVal").map(StructType.Field::type).get(); + assertThat(optionalType.parameters().get(0)).isEqualTo(SimpleType.STRING); + } + + @Test + public void nativeTypes_mapJavaTypeToCelType_customCollectionSubclasses() throws Exception { + CelNativeTypesExtensions extensions = CelExtensions.nativeTypes(TestCustomCollectionPojo.class); + CelNativeTypesExtensions.NativeTypeRegistry registry = extensions.getRegistry(); + + Optional type = registry.findType(TestCustomCollectionPojo.class.getCanonicalName()); + StructType structType = (StructType) type.get(); + + assertThat(structType.findField("customList").map(StructType.Field::type)) + .hasValue(ListType.create(SimpleType.STRING)); + assertThat(structType.findField("customMap").map(StructType.Field::type)) + .hasValue(MapType.create(SimpleType.STRING, SimpleType.INT)); + } + + @Test + public void nativeTypes_objectMethods_notExposed() throws Exception { + CelNativeTypesExtensions extensions = + CelExtensions.nativeTypes(TestAllTypesPublicFieldsPojo.class); + CelCompiler compiler = + CelCompilerFactory.standardCelCompilerBuilder() + .setContainer(CelContainer.ofName("dev.cel.extensions.CelNativeTypesExtensionsTest")) + .addLibraries(extensions) + .build(); + + CelValidationException e = + assertThrows( + CelValidationException.class, + () -> compiler.compile("TestAllTypesPublicFieldsPojo{}.toString").getAst()); + assertThat(e).hasMessageThat().contains("undefined field"); + } + + @Test + public void nativeTypes_nullSafeTraversal() throws Exception { + Cel cel = + CelFactory.plannerCelBuilder() + .setContainer(CelContainer.ofName("dev.cel.extensions.CelNativeTypesExtensionsTest")) + .addCompilerLibraries(NATIVE_TYPE_EXTENSIONS) + .addRuntimeLibraries(NATIVE_TYPE_EXTENSIONS) + .addVar( + "pojo", + StructTypeReference.create(TestAllTypesPublicFieldsPojo.class.getCanonicalName())) + .build(); + + TestAllTypesPublicFieldsPojo pojo = new TestAllTypesPublicFieldsPojo(); + ImmutableMap vars = ImmutableMap.of("pojo", pojo); + + assertThat(cel.createProgram(cel.compile("pojo.stringVal").getAst()).eval(vars)).isEqualTo(""); + assertThat(cel.createProgram(cel.compile("pojo.int64Val").getAst()).eval(vars)).isEqualTo(0L); + assertThat(cel.createProgram(cel.compile("pojo.nestedVal.value").getAst()).eval(vars)) + .isEqualTo(""); + CelAbstractSyntaxTree abstractPojoAst = cel.compile("pojo.abstractPojo.value").getAst(); + CelRuntime.Program abstractPojoProgram = cel.createProgram(abstractPojoAst); + CelEvaluationException e = + assertThrows(CelEvaluationException.class, () -> abstractPojoProgram.eval(vars)); + assertThat(e).hasMessageThat().contains("Failed to instantiate default instance"); + } + + @Test + public void nativeTypes_presenceTest() throws Exception { + Cel cel = + CelFactory.plannerCelBuilder() + .setContainer(CelContainer.ofName("dev.cel.extensions.CelNativeTypesExtensionsTest")) + .addCompilerLibraries(NATIVE_TYPE_EXTENSIONS) + .addRuntimeLibraries(NATIVE_TYPE_EXTENSIONS) + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .addVar( + "pojo", + StructTypeReference.create(TestAllTypesPublicFieldsPojo.class.getCanonicalName())) + .build(); + + TestAllTypesPublicFieldsPojo pojo = new TestAllTypesPublicFieldsPojo(); + ImmutableMap nullVars = ImmutableMap.of("pojo", pojo); + + TestAllTypesPublicFieldsPojo pojoWithValues = new TestAllTypesPublicFieldsPojo(); + pojoWithValues.stringVal = "hello"; + ImmutableMap valueVars = ImmutableMap.of("pojo", pojoWithValues); + + boolean hasPopulatedString = + (boolean) cel.createProgram(cel.compile("has(pojo.stringVal)").getAst()).eval(valueVars); + assertThat(hasPopulatedString).isTrue(); + + boolean hasNullString = + (boolean) cel.createProgram(cel.compile("has(pojo.stringVal)").getAst()).eval(nullVars); + assertThat(hasNullString).isFalse(); + + assertThrows( + CelValidationException.class, () -> cel.compile("has(pojo.nonExistentField)").getAst()); + } + + @Test + public void nativeTypes_zeroValue_collections_comprehensions() throws Exception { + assertThat(eval("TestAllTypesPublicFieldsPojo{}.listVal.filter(x, true) == []")) + .isEqualTo(true); + assertThat(eval("TestAllTypesPublicFieldsPojo{}.listVal.map(x, x + 'foo') == []")) + .isEqualTo(true); + assertThat(eval("TestAllTypesPublicFieldsPojo{}.listVal.exists(x, true)")).isEqualTo(false); + assertThat(eval("TestAllTypesPublicFieldsPojo{}.listVal.all(x, true)")).isEqualTo(true); + assertThat(eval("TestAllTypesPublicFieldsPojo{}.mapVal.exists(k, true)")).isEqualTo(false); + assertThat(eval("TestAllTypesPublicFieldsPojo{}.mapVal.all(k, true)")).isEqualTo(true); + } + + @Test + public void nativeTypes_customStructValue_optionalOfNonZeroValue() throws Exception { + CelNativeTypesExtensions extensions = + CelExtensions.nativeTypes(TestCustomStructValuePojo.class); + Cel cel = + CelFactory.plannerCelBuilder() + .setContainer(CelContainer.ofName("dev.cel.extensions.CelNativeTypesExtensionsTest")) + .addCompilerLibraries(extensions, CelExtensions.optional()) + .addRuntimeLibraries(extensions, CelExtensions.optional()) + .addVar( + "pojo", + StructTypeReference.create(TestCustomStructValuePojo.class.getCanonicalName())) + .build(); + + TestCustomStructValuePojo emptyPojo = + new TestCustomStructValuePojo(ImmutableMap.of("value", "")); + ImmutableMap emptyVars = ImmutableMap.of("pojo", emptyPojo); + boolean isEmptyNone = + (boolean) + cel.createProgram(cel.compile("!optional.ofNonZeroValue(pojo).hasValue()").getAst()) + .eval(emptyVars); + assertThat(isEmptyNone).isTrue(); + + TestCustomStructValuePojo populatedPojo = + new TestCustomStructValuePojo(ImmutableMap.of("value", "hello")); + ImmutableMap populatedVars = ImmutableMap.of("pojo", populatedPojo); + boolean isPopulatedPresent = + (boolean) + cel.createProgram(cel.compile("optional.ofNonZeroValue(pojo).hasValue()").getAst()) + .eval(populatedVars); + assertThat(isPopulatedPresent).isTrue(); + } + + @Test + public void nativeTypes_staticMembers_skipped() throws Exception { + ImmutableSet properties = + CelNativeTypesExtensions.NativeTypeScanner.getProperties(TestStaticMembersPojo.class); + + assertThat(properties).contains("instanceField"); + assertThat(properties).doesNotContain("STATIC_FIELD"); + assertThat(properties).doesNotContain("staticGetter"); + assertThat(properties).doesNotContain("staticProperty"); + } + + @Test + public void nativeTypes_deeplyNestedGenerics_discovered() throws Exception { + CelNativeTypesExtensions extensions = CelExtensions.nativeTypes(TestNestedGenericsPojo.class); + Cel cel = + CelFactory.plannerCelBuilder() + .setContainer(CelContainer.ofName("dev.cel.extensions.CelNativeTypesExtensionsTest")) + .addCompilerLibraries(extensions) + .addRuntimeLibraries(extensions) + .addVar( + "pojo", StructTypeReference.create(TestNestedGenericsPojo.class.getCanonicalName())) + .build(); + + TestNestedSimplePojo simplePojo = new TestNestedSimplePojo(); + TestNestedGenericsPojo pojo = new TestNestedGenericsPojo(); + pojo.nestedList = ImmutableList.of(ImmutableList.of(simplePojo)); + + boolean result = + (boolean) + cel.createProgram(cel.compile("pojo.nestedList[0][0].value == 'nested'").getAst()) + .eval(ImmutableMap.of("pojo", pojo)); + + assertThat(result).isTrue(); + } + + @Test + public void nativeTypes_concreteCollectionInstantiation_success() throws Exception { + TestCustomCollectionPojo result = + (TestCustomCollectionPojo) + eval("TestCustomCollectionPojo{customList: ['a', 'b'], customMap: {'key': 1}}"); + + assertThat(result).isNotNull(); + assertThat(result.customList).containsExactly("a", "b"); + assertThat(result.customMap).containsEntry("key", 1L); + } + + @Test + public void nativeTypes_getterFieldTypeMismatch_readOnly() throws Exception { + CelAbstractSyntaxTree ast = + CEL.compile("TestGetterFieldTypeMismatchPojo{mismatchField: 'hello'}").getAst(); + + CelRuntime.Program program = CEL.createProgram(ast); + CelEvaluationException exception = + assertThrows(CelEvaluationException.class, () -> program.eval(ImmutableMap.of())); + + assertThat(exception.getMessage()).contains("Failed to create instance"); + } + + public static class TestAllTypesPublicFieldsPojo { + public void doNothing() {} + + public String getA() { + return "a"; + } + + public String get() { + return "get"; + } + + public boolean boolVal; + public String stringVal; + public long int64Val; + public int int32Val; + public double doubleVal; + public float floatVal; + public byte[] bytesVal; + public Duration durationVal; + public Instant timestampVal; + public TestNestedType nestedVal; + public List listVal; + public Map mapVal; + + public Boolean boolObjVal; + public Integer intObjVal; + public Long longObjVal; + public UnsignedLong uintVal; + public Float floatObjVal; + public Double doubleObjVal; + public Optional optionalVal; + public Optional optionalNestedVal; + public Map mapIntVal; + public List> nestedListVal; + public CelByteString celBytesVal = CelByteString.of(new byte[] {1, 2, 3}); + public TestAbstractPojo abstractPojo; + + public String getInvalidParam(String param) { + return "invalid"; + } + + public String isInvalidString() { + return "invalid"; + } + } + + public static class PojoWithCustomMap { + public Map mapVal; + } + + public static class TestNestedType { + public String value; + } + + static class TestPackagePrivatePojo { + public String value; + } + + static class TestPackagePrivateWithGetterPojo { + private String value; + + public String getValue() { + return value; + } + + public void setValue(String value) { + this.value = value; + } + } + + public static class TestPrivateConstructorPojo { + public String value; + + private TestPrivateConstructorPojo() { + this.value = "default"; + } + } + + public static class TestPrecedencePojo { + public int value = 1; + + public String getValue() { + return "hello"; + } + } + + static final class TestGetterSetterPojo { + private String value; + private boolean active; + + public String getValue() { + return value; + } + + public void setValue(String value) { + this.value = value; + } + + public boolean isActive() { + return active; + } + + public void setActive(boolean active) { + this.active = active; + } + } + + public static final class TestUnsupportedSetPojo { + public Set strings; + } + + public static final class TestDeepConversionPojo { + public List ints; + public Map floats; + } + + public static final class TestMissingNoArgConstructorPojo { + public String value; + + public TestMissingNoArgConstructorPojo(String value) { + this.value = value; + } + } + + public static class TestRefValFieldType { + public Optional optionalName; + public int intVal; + public Instant time; + } + + public static class ComprehensiveTestNestedType { + public List nestedListVal; + public Map nestedMapVal; + public String customName; + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof ComprehensiveTestNestedType)) { + return false; + } + ComprehensiveTestNestedType that = (ComprehensiveTestNestedType) o; + return Objects.equals(nestedListVal, that.nestedListVal) + && Objects.equals(nestedMapVal, that.nestedMapVal) + && Objects.equals(customName, that.customName); + } + + @Override + public int hashCode() { + return Objects.hash(nestedListVal, nestedMapVal, customName); + } + } + + public static class TestNestedSliceType { + public String value; + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof TestNestedSliceType)) { + return false; + } + TestNestedSliceType that = (TestNestedSliceType) o; + return Objects.equals(value, that.value); + } + + @Override + public int hashCode() { + return Objects.hashCode(value); + } + } + + public static class TestMapVal { + public String value; + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof TestMapVal)) { + return false; + } + TestMapVal that = (TestMapVal) o; + return Objects.equals(value, that.value); + } + + @Override + public int hashCode() { + return Objects.hashCode(value); + } + } + + public static class ComprehensiveTestAllTypes { + public ComprehensiveTestNestedType nestedVal; + public ComprehensiveTestNestedType nestedStructVal; + public boolean boolVal; + public byte[] bytesVal; + public Duration durationVal; + public double doubleVal; + public float floatVal; + public int int32Val; + public long int64Val; + public String stringVal; + public Instant timestampVal; + public long uint32Val; + public long uint64Val; + public List listVal; + public List arrayVal; + public byte[] bytesArrayVal; + public Map mapVal; + public List customSliceVal; + public Map customMapVal; + public String customName; + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof ComprehensiveTestAllTypes)) { + return false; + } + ComprehensiveTestAllTypes that = (ComprehensiveTestAllTypes) o; + return boolVal == that.boolVal + && doubleVal == that.doubleVal + && floatVal == that.floatVal + && int32Val == that.int32Val + && int64Val == that.int64Val + && uint32Val == that.uint32Val + && uint64Val == that.uint64Val + && Objects.equals(nestedVal, that.nestedVal) + && Objects.equals(nestedStructVal, that.nestedStructVal) + && Arrays.equals(bytesVal, that.bytesVal) + && Objects.equals(durationVal, that.durationVal) + && Objects.equals(stringVal, that.stringVal) + && Objects.equals(timestampVal, that.timestampVal) + && Objects.equals(listVal, that.listVal) + && Objects.equals(arrayVal, that.arrayVal) + && Arrays.equals(bytesArrayVal, that.bytesArrayVal) + && Objects.equals(mapVal, that.mapVal) + && Objects.equals(customSliceVal, that.customSliceVal) + && Objects.equals(customMapVal, that.customMapVal) + && Objects.equals(customName, that.customName); + } + + @Override + public int hashCode() { + int result = + Objects.hash( + nestedVal, + nestedStructVal, + boolVal, + durationVal, + doubleVal, + floatVal, + int32Val, + int64Val, + stringVal, + timestampVal, + uint32Val, + uint64Val, + listVal, + arrayVal, + mapVal, + customSliceVal, + customMapVal, + customName); + result = 31 * result + Arrays.hashCode(bytesVal); + result = 31 * result + Arrays.hashCode(bytesArrayVal); + return result; + } + } + + public static final class TestPrivateFieldPojo { + // Intentionally unread to test private fields are not exposed + @SuppressWarnings("UnusedVariable") + private String secret; + } + + public static class TestPrefixLessGetterPojo { + private String value = "hello"; + + public String value() { + return value; + } + } + + public static class TestParentPojo { + private String parentValue = "parent"; + private String standardValue = "standard"; + + public String parentValue() { + return parentValue; + } + + public String getStandardValue() { + return standardValue; + } + } + + public static class TestChildPojo extends TestParentPojo { + private String childValue = "child"; + + public String childValue() { + return childValue; + } + } + + // Intentionally violating style guide to test special decapitalization. + @SuppressWarnings("IdentifierName") + public static class TestURLPojo { + public String getURL() { + return "https://google.com"; + } + } + + public static class TestDoubleMapKeyPojo { + public Map map; + } + + public static class TestWildcardPojo { + public List values; + } + + public static class TestArrayPojo { + public String[] values; + } + + public static class TestOptionalUrlPojo { + public Optional optionalUrl; + } + + public abstract static class TestAbstractPojo { + public String value; + } + + public static class TestCircularA { + public TestCircularB b; + } + + public static class TestCircularB { + public TestCircularA a; + } + + public static class CustomListImplementation extends ArrayList {} + + public static class CustomMapImplementation extends HashMap {} + + public static class TestCustomCollectionPojo { + public CustomListImplementation customList; + public CustomMapImplementation customMap; + } + + @SuppressWarnings("Immutable") + static final class TestCustomStructValuePojo extends StructValue { + private final ImmutableMap fields; + + public TestCustomStructValuePojo(ImmutableMap fields) { + this.fields = fields; + } + + @Override + public Object value() { + return this; + } + + @Override + public boolean isZeroValue() { + for (Object val : fields.values()) { + if (val != null && !val.equals("") && !val.equals(0L)) { + return false; + } + } + return true; + } + + @Override + public CelType celType() { + return StructTypeReference.create(TestCustomStructValuePojo.class.getCanonicalName()); + } + + @Override + public Optional find(String field) { + return Optional.ofNullable(fields.get(field)); + } + + @Override + public Object select(String field) { + Object val = fields.get(field); + if (val == null) { + throw new NoSuchElementException("Field not found: " + field); + } + return val; + } + } + + public static class TestStaticMembersPojo { + public static final String STATIC_FIELD = "static_value"; + + public static String getStaticGetter() { + return "static_getter_value"; + } + + public static String staticProperty() { + return "static_property_value"; + } + + public String instanceField = "instance_value"; + } + + public static class TestNestedGenericsPojo { + public List> nestedList; + public Map> nestedMap; + } + + public static class TestNestedSimplePojo { + public String value = "nested"; + } + + public static class TestGetterFieldTypeMismatchPojo { + public int mismatchField = 10; + + public String getMismatchField() { + return "mismatch"; + } + } +} diff --git a/runtime/src/main/java/dev/cel/runtime/planner/NamespacedAttribute.java b/runtime/src/main/java/dev/cel/runtime/planner/NamespacedAttribute.java index 0000ad764..561e25f7f 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/NamespacedAttribute.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/NamespacedAttribute.java @@ -192,7 +192,9 @@ private static Object applyQualifiers( // Avoid enhanced for loop to prevent UnmodifiableIterator from being allocated for (int i = 0; i < qualifiers.size(); i++) { - obj = qualifiers.get(i).qualify(obj); + Qualifier element = qualifiers.get(i); + obj = element.qualify(obj); + obj = celValueConverter.toRuntimeValue(obj); } return celValueConverter.maybeUnwrap(obj); diff --git a/runtime/src/main/java/dev/cel/runtime/planner/RelativeAttribute.java b/runtime/src/main/java/dev/cel/runtime/planner/RelativeAttribute.java index addbeb4d0..38f733c79 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/RelativeAttribute.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/RelativeAttribute.java @@ -42,7 +42,9 @@ public Object resolve(long exprId, GlobalResolver ctx, ExecutionFrame frame) { // Avoid enhanced for loop to prevent UnmodifiableIterator from being allocated for (int i = 0; i < qualifiers.size(); i++) { - obj = qualifiers.get(i).qualify(obj); + Qualifier element = qualifiers.get(i); + obj = element.qualify(obj); + obj = celValueConverter.toRuntimeValue(obj); } return celValueConverter.maybeUnwrap(obj); From 7bf3e04351570ce8a9e49e947aae18c335abb2bc Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Thu, 14 May 2026 16:32:52 -0700 Subject: [PATCH 63/66] Optimize list/map adaptations PiperOrigin-RevId: 915678440 --- .../java/dev/cel/common/values/BUILD.bazel | 32 +++++--- .../cel/common/values/CelPreAdaptedList.java | 49 ++++++++++++ .../cel/common/values/CelValueConverter.java | 74 +++++++++++++++---- .../common/values/ProtoCelValueConverter.java | 12 +++ 4 files changed, 144 insertions(+), 23 deletions(-) create mode 100644 common/src/main/java/dev/cel/common/values/CelPreAdaptedList.java diff --git a/common/src/main/java/dev/cel/common/values/BUILD.bazel b/common/src/main/java/dev/cel/common/values/BUILD.bazel index d572bb2bc..5ccc498fd 100644 --- a/common/src/main/java/dev/cel/common/values/BUILD.bazel +++ b/common/src/main/java/dev/cel/common/values/BUILD.bazel @@ -60,7 +60,6 @@ java_library( deps = [ "//common/values", "@maven//:com_google_errorprone_error_prone_annotations", - "@maven//:com_google_guava_guava", ], ) @@ -72,7 +71,6 @@ cel_android_library( deps = [ "//common/values:values_android", "@maven//:com_google_errorprone_error_prone_annotations", - "@maven_android//:com_google_guava_guava", ], ) @@ -118,7 +116,6 @@ java_library( deps = [ ":values", "//common/annotations", - "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", "@maven//:org_jspecify_jspecify", ], @@ -134,12 +131,31 @@ cel_android_library( deps = [ ":values_android", "//common/annotations", - "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:org_jspecify_jspecify", "@maven_android//:com_google_guava_guava", ], ) +java_library( + name = "preadapted_list", + srcs = [ + "CelPreAdaptedList.java", + ], + tags = [ + ], + deps = ["//common/annotations"], +) + +cel_android_library( + name = "preadapted_list_android", + srcs = [ + "CelPreAdaptedList.java", + ], + tags = [ + ], + deps = ["//common/annotations"], +) + java_library( name = "values", srcs = CEL_VALUES_SOURCES, @@ -148,6 +164,7 @@ java_library( deps = [ ":cel_byte_string", ":cel_value", + ":preadapted_list", "//:auto_value", "//common/annotations", "//common/types", @@ -198,6 +215,7 @@ cel_android_library( deps = [ ":cel_byte_string", ":cel_value_android", + ":preadapted_list_android", "//:auto_value", "//common/annotations", "//common/types:type_providers_android", @@ -226,7 +244,6 @@ java_library( ], deps = [ ":cel_byte_string", - ":values", "//common/annotations", "//common/internal:proto_time_utils", "//common/internal:well_known_proto", @@ -261,6 +278,7 @@ java_library( ], deps = [ ":base_proto_cel_value_converter", + ":preadapted_list", ":values", "//:auto_value", "//common:options", @@ -273,7 +291,6 @@ java_library( "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", "@maven//:com_google_protobuf_protobuf_java", - "@maven//:org_jspecify_jspecify", ], ) @@ -316,8 +333,6 @@ java_library( "//protobuf:cel_lite_descriptor", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", - "@maven//:com_google_protobuf_protobuf_java", - "@maven//:org_jspecify_jspecify", "@maven_android//:com_google_protobuf_protobuf_javalite", ], ) @@ -343,7 +358,6 @@ cel_android_library( "//protobuf:cel_lite_descriptor", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", - "@maven//:org_jspecify_jspecify", "@maven_android//:com_google_guava_guava", "@maven_android//:com_google_protobuf_protobuf_javalite", ], diff --git a/common/src/main/java/dev/cel/common/values/CelPreAdaptedList.java b/common/src/main/java/dev/cel/common/values/CelPreAdaptedList.java new file mode 100644 index 000000000..c0ff25e45 --- /dev/null +++ b/common/src/main/java/dev/cel/common/values/CelPreAdaptedList.java @@ -0,0 +1,49 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.common.values; + +import dev.cel.common.annotations.Internal; +import java.util.AbstractList; +import java.util.List; +import java.util.RandomAccess; + +/** + * A zero-allocation view over a list we know is already adapted. + * + *

This class purely exists as an optimization scheme to avoid redundant collection traversals in + * {@link CelValueConverter}, and is not intended for general use. + */ +@Internal +final class CelPreAdaptedList extends AbstractList implements RandomAccess { + private final List delegate; + + private CelPreAdaptedList(List delegate) { + this.delegate = delegate; + } + + static CelPreAdaptedList wrap(List safeList) { + return new CelPreAdaptedList<>(safeList); + } + + @Override + public E get(int index) { + return delegate.get(index); + } + + @Override + public int size() { + return delegate.size(); + } +} diff --git a/common/src/main/java/dev/cel/common/values/CelValueConverter.java b/common/src/main/java/dev/cel/common/values/CelValueConverter.java index 89f5ab100..20deef1d3 100644 --- a/common/src/main/java/dev/cel/common/values/CelValueConverter.java +++ b/common/src/main/java/dev/cel/common/values/CelValueConverter.java @@ -20,8 +20,11 @@ import com.google.errorprone.annotations.Immutable; import dev.cel.common.annotations.Internal; import java.util.Collection; +import java.util.Iterator; +import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.RandomAccess; import java.util.function.Function; /** @@ -53,16 +56,11 @@ public static CelValueConverter getDefaultInstance() { *

The value may be a {@link CelValue}, a {@link Collection} or a {@link Map}. */ public Object maybeUnwrap(Object value) { - if (value instanceof CelValue) { - return unwrap((CelValue) value); + if (value instanceof CelValue || value instanceof CelPreAdaptedList) { + return value instanceof CelValue ? unwrap((CelValue) value) : value; } - Object mapped = mapContainer(value, maybeUnwrapFunction); - if (mapped != value) { - return mapped; - } - - return value; + return mapContainer(value, maybeUnwrapFunction); } /** @@ -70,6 +68,34 @@ public Object maybeUnwrap(Object value) { * Returns the original value if it's not a supported container. */ protected Object mapContainer(Object value, Function mapper) { + + // Zero allocation path for standard lists that support O(1) indexing + // Generally, protobuf lists (backed by arrays) fall into this category + if (value instanceof List && value instanceof RandomAccess) { + List list = (List) value; + for (int i = 0; i < list.size(); i++) { + Object element = list.get(i); + Object mapped = mapper.apply(element); + + if (mapped != element) { + ImmutableList.Builder builder = + ImmutableList.builderWithExpectedSize(list.size()); + for (int j = 0; j < i; j++) { + builder.add(list.get(j)); + } + builder.add(mapped); + for (int j = i + 1; j < list.size(); j++) { + builder.add(mapper.apply(list.get(j))); + } + return builder.build(); + } + } + + // Zero allocations if unmodified + return value; + } + + // Fallback for lists that are unordered if (value instanceof Collection) { Collection collection = (Collection) value; ImmutableList.Builder builder = @@ -82,12 +108,32 @@ protected Object mapContainer(Object value, Function mapper) { if (value instanceof Map) { Map map = (Map) value; - ImmutableMap.Builder builder = - ImmutableMap.builderWithExpectedSize(map.size()); - for (Map.Entry entry : map.entrySet()) { - builder.put(mapper.apply(entry.getKey()), mapper.apply(entry.getValue())); + Iterator> iterator = map.entrySet().iterator(); + + while (iterator.hasNext()) { + Map.Entry entry = iterator.next(); + Object mappedKey = mapper.apply(entry.getKey()); + Object mappedValue = mapper.apply(entry.getValue()); + + if (mappedKey != entry.getKey() || mappedValue != entry.getValue()) { + ImmutableMap.Builder builder = + ImmutableMap.builderWithExpectedSize(map.size()); + + for (Map.Entry prevEntry : map.entrySet()) { + if (prevEntry.getKey() == entry.getKey()) { + break; + } + builder.put(mapper.apply(prevEntry.getKey()), mapper.apply(prevEntry.getValue())); + } + builder.put(mappedKey, mappedValue); + while (iterator.hasNext()) { + Map.Entry nextEntry = iterator.next(); + builder.put(mapper.apply(nextEntry.getKey()), mapper.apply(nextEntry.getValue())); + } + return builder.buildOrThrow(); + } } - return builder.buildOrThrow(); + return value; } return value; @@ -96,7 +142,7 @@ protected Object mapContainer(Object value, Function mapper) { public Object toRuntimeValue(Object value) { Preconditions.checkNotNull(value); - if (value instanceof CelValue) { + if (value instanceof CelValue || value instanceof CelPreAdaptedList) { return value; } diff --git a/common/src/main/java/dev/cel/common/values/ProtoCelValueConverter.java b/common/src/main/java/dev/cel/common/values/ProtoCelValueConverter.java index 565c65438..948df759c 100644 --- a/common/src/main/java/dev/cel/common/values/ProtoCelValueConverter.java +++ b/common/src/main/java/dev/cel/common/values/ProtoCelValueConverter.java @@ -167,6 +167,18 @@ public Object fromProtoMessageFieldToCelValue(Message message, FieldDescriptor f break; } + if (fieldDescriptor.isRepeated()) { + switch (fieldDescriptor.getType()) { + case INT64: + case BOOL: + case STRING: + case DOUBLE: + return CelPreAdaptedList.wrap((List) result); + default: + break; + } + } + return toRuntimeValue(result); } From 0837a5190501466d3b4545c72f26ef4e8dbbedef Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Thu, 14 May 2026 18:21:01 -0700 Subject: [PATCH 64/66] Fix double qualification in NamespacedAttribute PiperOrigin-RevId: 915718942 --- .../cel/runtime/planner/NamespacedAttribute.java | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/runtime/src/main/java/dev/cel/runtime/planner/NamespacedAttribute.java b/runtime/src/main/java/dev/cel/runtime/planner/NamespacedAttribute.java index 561e25f7f..95a4489fd 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/NamespacedAttribute.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/NamespacedAttribute.java @@ -171,19 +171,15 @@ private GlobalResolver unwrapToNonLocal(GlobalResolver resolver) { @Override public NamespacedAttribute addQualifier(Qualifier qualifier) { - ImmutableMap.Builder attributesBuilder = ImmutableMap.builder(); - CelAttribute.Qualifier celQualifier = CelAttribute.Qualifier.fromGeneric(qualifier.value()); - - for (Map.Entry entry : candidateAttributes.entrySet()) { - attributesBuilder.put(entry.getKey(), entry.getValue().qualify(celQualifier)); - } - return new NamespacedAttribute( typeProvider, celValueConverter, - attributesBuilder.buildOrThrow(), + candidateAttributes, disambiguateNames, - ImmutableList.builder().addAll(qualifiers).add(qualifier).build()); + ImmutableList.builderWithExpectedSize(qualifiers.size() + 1) + .addAll(qualifiers) + .add(qualifier) + .build()); } private static Object applyQualifiers( From bfc4fdfa9a846c493028ca986c06eda88b2a207e Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Thu, 14 May 2026 19:43:25 -0700 Subject: [PATCH 65/66] Refactor native extensions to separately hold type references PiperOrigin-RevId: 915744276 --- .../extensions/CelNativeTypesExtensions.java | 75 ++++++++++--------- 1 file changed, 41 insertions(+), 34 deletions(-) diff --git a/extensions/src/main/java/dev/cel/extensions/CelNativeTypesExtensions.java b/extensions/src/main/java/dev/cel/extensions/CelNativeTypesExtensions.java index fd579a3bc..ae9483f7c 100644 --- a/extensions/src/main/java/dev/cel/extensions/CelNativeTypesExtensions.java +++ b/extensions/src/main/java/dev/cel/extensions/CelNativeTypesExtensions.java @@ -356,12 +356,12 @@ private static Optional buildPropertyAccessor( if (getter != null) { propType = getter.getReturnType(); genericPropType = getter.getGenericReturnType(); - discoverCustomTypes(genericPropType, queue); + queue.addAll(TypeReferenceCollector.collect(genericPropType)); compiledGetter = compileGetter(getter); } else if (field != null) { propType = field.getType(); genericPropType = field.getGenericType(); - discoverCustomTypes(genericPropType, queue); + queue.addAll(TypeReferenceCollector.collect(genericPropType)); compiledGetter = compileFieldGetter(field); } @@ -386,46 +386,53 @@ private static Optional buildPropertyAccessor( /** * Recursively explores a {@link Type} and discovers any transitive, user-defined custom POJO - * classes nested inside multi-level generic collections, lists, maps, or optionals, pushing - * them into the scanning discovery queue. + * classes nested inside multi-level generic collections, lists, maps, or optionals, collecting + * them for subsequent properties discovery. * *

"Custom types" are any public non-primitive, non-built-in Java classes that require * explicit properties reflective scanning and mapping to a CEL StructType schema (as opposed to * standard built-in types like {@code String}, {@code List}, or {@code Map}). - * - * @param type The Java type token or parameterized collection type to recursively unpack. - * @param queue The central scanning queue where newly discovered custom classes are pushed for - * subsequent properties discovery. */ - private static void discoverCustomTypes(Type type, Queue> queue) { - Preconditions.checkNotNull(type, "Type to discover cannot be null."); - Preconditions.checkNotNull(queue, "Queue cannot be null."); - TypeToken token = TypeToken.of(type); - Class rawType = token.getRawType(); - - if (List.class.isAssignableFrom(rawType)) { - Type elementType = ReflectionUtil.resolveGenericParameter(token, List.class, 0); - discoverCustomTypes(elementType, queue); - return; - } + private static final class TypeReferenceCollector { + private final Set> collectedTypes = new HashSet<>(); + + /** + * Traverses the given type and returns an immutable set of all custom POJO classes found. + * + * @param type The Java type token or parameterized collection type to recursively unpack. + */ + private static ImmutableSet> collect(Type type) { + TypeReferenceCollector collector = new TypeReferenceCollector(); + collector.discover(type); + return ImmutableSet.copyOf(collector.collectedTypes); + } + + private void discover(Type type) { + Preconditions.checkNotNull(type, "Type to discover cannot be null."); + TypeToken token = TypeToken.of(type); + Class rawType = token.getRawType(); + + if (List.class.isAssignableFrom(rawType)) { + discover(ReflectionUtil.resolveGenericParameter(token, List.class, 0)); + return; + } - if (Map.class.isAssignableFrom(rawType)) { - Type keyType = ReflectionUtil.resolveGenericParameter(token, Map.class, 0); - Type valueType = ReflectionUtil.resolveGenericParameter(token, Map.class, 1); - discoverCustomTypes(keyType, queue); - discoverCustomTypes(valueType, queue); - return; - } + if (Map.class.isAssignableFrom(rawType)) { + discover(ReflectionUtil.resolveGenericParameter(token, Map.class, 0)); + discover(ReflectionUtil.resolveGenericParameter(token, Map.class, 1)); + return; + } - if (rawType == Optional.class) { - Type optionalType = ReflectionUtil.resolveGenericParameter(token, Optional.class, 0); - discoverCustomTypes(optionalType, queue); - return; - } + if (rawType == Optional.class) { + discover(ReflectionUtil.resolveGenericParameter(token, Optional.class, 0)); + return; + } - if (!JAVA_TO_DEFAULT_VALUE_MAP.containsKey(rawType) - && Modifier.isPublic(rawType.getModifiers())) { - queue.add(rawType); + // Custom types are non-builtin, public classes + if (!JAVA_TO_DEFAULT_VALUE_MAP.containsKey(rawType) + && Modifier.isPublic(rawType.getModifiers())) { + collectedTypes.add(rawType); + } } } From c57d9d02259795c761e82f8870303b93a7a672c8 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Fri, 15 May 2026 16:48:33 -0700 Subject: [PATCH 66/66] Release 0.13.0 PiperOrigin-RevId: 916246350 --- MODULE.bazel | 2 +- README.md | 4 ++-- publish/cel_version.bzl | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/MODULE.bazel b/MODULE.bazel index 895715a5f..6689158c6 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -46,7 +46,7 @@ TRUTH_VERSION = "1.4.4" PROTOBUF_JAVA_VERSION = "4.33.5" -CEL_VERSION = "0.12.0" +CEL_VERSION = "0.13.0" # Compile only artifacts [ diff --git a/README.md b/README.md index 40bd9deac..78e38961f 100644 --- a/README.md +++ b/README.md @@ -55,14 +55,14 @@ CEL-Java is available in Maven Central Repository. [Download the JARs here][8] o dev.cel cel - 0.12.0 + 0.13.0 ``` **Gradle** ```gradle -implementation 'dev.cel:cel:0.12.0' +implementation 'dev.cel:cel:0.13.0' ``` Then run this example: diff --git a/publish/cel_version.bzl b/publish/cel_version.bzl index b40addd73..70fa1a010 100644 --- a/publish/cel_version.bzl +++ b/publish/cel_version.bzl @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. """Maven artifact version for CEL.""" -CEL_VERSION = "0.12.0" +CEL_VERSION = "0.13.0"