From 23c9913ef7f8ed8835165fc6c65617e3090f1630 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 28 Apr 2026 16:26:10 -0700 Subject: [PATCH 01/87] Add option to specify an arena and persist the resolved types from type checking. PiperOrigin-RevId: 907236091 --- checker/BUILD | 5 ++++ checker/internal/type_checker_impl.cc | 34 +++++++++++++++++++------ checker/internal/type_checker_impl.h | 4 +-- checker/type_checker.cc | 36 +++++++++++++++++++++++++++ checker/type_checker.h | 15 ++++++++--- checker/validation_result.h | 20 ++++++++++++++- compiler/BUILD | 2 ++ compiler/compiler.h | 12 +++++++-- compiler/compiler_factory.cc | 7 +++--- compiler/compiler_factory_test.cc | 23 +++++++++++++++++ 10 files changed, 139 insertions(+), 19 deletions(-) create mode 100644 checker/type_checker.cc diff --git a/checker/BUILD b/checker/BUILD index f1e0cef3c..27a1eb84e 100644 --- a/checker/BUILD +++ b/checker/BUILD @@ -50,7 +50,9 @@ cc_library( ":type_check_issue", "//common:ast", "//common:source", + "//common:type", "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -74,11 +76,14 @@ cc_test( cc_library( name = "type_checker", + srcs = ["type_checker.cc"], hdrs = ["type_checker.h"], deps = [ ":validation_result", "//common:ast", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", ], ) diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 8f67efbde..05601fdbb 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -1176,11 +1177,13 @@ class ResolveRewriter : public AstRewriterBase { explicit ResolveRewriter(const ResolveVisitor& visitor, const TypeInferenceContext& inference_context, const CheckerOptions& options, - Ast::ReferenceMap& references, Ast::TypeMap& types) + Ast::ReferenceMap& references, Ast::TypeMap& types, + ValidationResult::TypeMap& resolved_types) : visitor_(visitor), inference_context_(inference_context), reference_map_(references), type_map_(types), + resolved_types_(resolved_types), options_(options) {} bool PostVisitRewrite(Expr& expr) override { bool rewritten = false; @@ -1235,6 +1238,7 @@ class ResolveRewriter : public AstRewriterBase { return rewritten; } type_map_[expr.id()] = *std::move(flattened_type); + resolved_types_[expr.id()] = iter->second; rewritten = true; } @@ -1249,23 +1253,28 @@ class ResolveRewriter : public AstRewriterBase { const TypeInferenceContext& inference_context_; Ast::ReferenceMap& reference_map_; Ast::TypeMap& type_map_; + ValidationResult::TypeMap& resolved_types_; const CheckerOptions& options_; }; } // namespace -absl::StatusOr TypeCheckerImpl::Check( - std::unique_ptr ast) const { - google::protobuf::Arena type_arena; +absl::StatusOr TypeCheckerImpl::CheckImpl( + std::unique_ptr ast, google::protobuf::Arena* arena) const { + std::optional type_arena; + if (arena == nullptr) { + type_arena.emplace(); + arena = &(*type_arena); + } std::vector issues; CEL_ASSIGN_OR_RETURN(auto generator, NamespaceGenerator::Create(env_.container())); TypeInferenceContext type_inference_context( - &type_arena, options_.enable_legacy_null_assignment); + arena, options_.enable_legacy_null_assignment); ResolveVisitor visitor(std::move(generator), env_, *ast, - type_inference_context, issues, &type_arena); + type_inference_context, issues, arena); TraversalOptions opts; opts.use_comprehension_callbacks = true; @@ -1310,9 +1319,10 @@ absl::StatusOr TypeCheckerImpl::Check( // Apply updates as needed. // Happens in a second pass to simplify validating that pointers haven't // been invalidated by other updates. + ValidationResult::TypeMap resolved_types; ResolveRewriter rewriter(visitor, type_inference_context, options_, ast->mutable_reference_map(), - ast->mutable_type_map()); + ast->mutable_type_map(), resolved_types); AstRewrite(ast->mutable_root_expr(), rewriter); CEL_RETURN_IF_ERROR(rewriter.status()); @@ -1325,7 +1335,15 @@ absl::StatusOr TypeCheckerImpl::Check( {cel::ExtensionSpec::Component::kRuntime})); } - return ValidationResult(std::move(ast), std::move(issues)); + auto result = ValidationResult(std::move(ast), std::move(issues)); + if (!type_arena.has_value()) { + // cel::Type values will expire after this function returns when the local + // arena is destructed. Only set the resolved type map if we're using the + // caller's arena. + result.SetResolvedTypeMap(std::move(resolved_types)); + } + + return result; } std::unique_ptr TypeCheckerImpl::ToBuilder() const { diff --git a/checker/internal/type_checker_impl.h b/checker/internal/type_checker_impl.h index 71683276d..9ee9a50d0 100644 --- a/checker/internal/type_checker_impl.h +++ b/checker/internal/type_checker_impl.h @@ -42,8 +42,8 @@ class TypeCheckerImpl : public TypeChecker { TypeCheckerImpl(TypeCheckerImpl&&) = delete; TypeCheckerImpl& operator=(TypeCheckerImpl&&) = delete; - absl::StatusOr Check( - std::unique_ptr ast) const override; + absl::StatusOr CheckImpl( + std::unique_ptr ast, google::protobuf::Arena* arena) const override; std::unique_ptr ToBuilder() const override; diff --git a/checker/type_checker.cc b/checker/type_checker.cc new file mode 100644 index 000000000..6d59e144d --- /dev/null +++ b/checker/type_checker.cc @@ -0,0 +1,36 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/type_checker.h" + +namespace cel { +absl::StatusOr TypeChecker::Check( + std::unique_ptr ast) const { + return CheckImpl(std::move(ast), nullptr); +} + +absl::StatusOr TypeChecker::Check( + std::unique_ptr ast, google::protobuf::Arena* arena) const { + return CheckImpl(std::move(ast), arena); +} + +absl::StatusOr TypeChecker::Check(const Ast& ast) const { + return CheckImpl(std::make_unique(ast), nullptr); +} + +absl::StatusOr TypeChecker::Check( + const Ast& ast, google::protobuf::Arena* arena) const { + return CheckImpl(std::make_unique(ast), arena); +} +} // namespace cel diff --git a/checker/type_checker.h b/checker/type_checker.h index e47b7dca6..edb6cc91f 100644 --- a/checker/type_checker.h +++ b/checker/type_checker.h @@ -16,10 +16,13 @@ #define THIRD_PARTY_CEL_CPP_CHECKER_TYPE_CHECKER_H_ #include +#include +#include "absl/base/nullability.h" #include "absl/status/statusor.h" #include "checker/validation_result.h" #include "common/ast.h" +#include "google/protobuf/arena.h" namespace cel { @@ -42,13 +45,19 @@ class TypeChecker { // A non-ok status is returned if type checking can't reasonably complete // (e.g. if an internal precondition is violated or an extension returns an // error). - virtual absl::StatusOr Check( - std::unique_ptr ast) const = 0; + absl::StatusOr Check(std::unique_ptr ast) const; + absl::StatusOr Check(std::unique_ptr ast, + google::protobuf::Arena* arena) const; + absl::StatusOr Check(const Ast& ast) const; + absl::StatusOr Check(const Ast& ast, + google::protobuf::Arena* arena) const; // Returns a builder initialized with the configuration of this type checker. virtual std::unique_ptr ToBuilder() const = 0; - // TODO(uncreated-issue/73): add overload for cref AST. + private: + virtual absl::StatusOr CheckImpl( + std::unique_ptr ast, google::protobuf::Arena* absl_nullable arena) const = 0; }; } // namespace cel diff --git a/checker/validation_result.h b/checker/validation_result.h index 8c84a84da..f424e7f6f 100644 --- a/checker/validation_result.h +++ b/checker/validation_result.h @@ -15,26 +15,31 @@ #ifndef THIRD_PARTY_CEL_CPP_CHECKER_VALIDATION_RESULT_H_ #define THIRD_PARTY_CEL_CPP_CHECKER_VALIDATION_RESULT_H_ +#include #include #include #include #include #include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" #include "checker/type_check_issue.h" #include "common/ast.h" #include "common/source.h" +#include "common/type.h" namespace cel { -// ValidationResult holds the result of TypeChecking. +// ValidationResult holds the result of type checking. // // Error states are captured as type check issues where possible. class ValidationResult { public: + using TypeMap = absl::flat_hash_map; + ValidationResult(std::unique_ptr ast, std::vector issues) : ast_(std::move(ast)), issues_(std::move(issues)) {} @@ -71,6 +76,18 @@ class ValidationResult { return std::move(source_); } + // Returns the resolved type map for the AST. + // + // Only populated if the AST was checked with an explicit arena. + // + // The type entries may have storage in the arena or reference type + // information from the type checker that produced the AST. This means the map + // is only valid as long as both the type checker and the arena are valid. + const TypeMap& GetResolvedTypeMap() const { return resolved_type_map_; } + void SetResolvedTypeMap(TypeMap resolved_type_map) { + resolved_type_map_ = std::move(resolved_type_map); + } + // Returns a string representation of the issues in the result suitable for // display. // @@ -89,6 +106,7 @@ class ValidationResult { private: absl_nullable std::unique_ptr ast_; + TypeMap resolved_type_map_; std::vector issues_; absl_nullable std::unique_ptr source_; }; diff --git a/compiler/BUILD b/compiler/BUILD index 170f1068b..50bc1c9fa 100644 --- a/compiler/BUILD +++ b/compiler/BUILD @@ -28,9 +28,11 @@ cc_library( "//parser:options", "//parser:parser_interface", "//validator", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:protobuf", ], ) diff --git a/compiler/compiler.h b/compiler/compiler.h index 48fa4e0b1..6d07e72c2 100644 --- a/compiler/compiler.h +++ b/compiler/compiler.h @@ -19,6 +19,7 @@ #include #include +#include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -29,6 +30,7 @@ #include "parser/options.h" #include "parser/parser_interface.h" #include "validator/validator.h" +#include "google/protobuf/arena.h" namespace cel { @@ -126,10 +128,16 @@ class Compiler { virtual ~Compiler() = default; virtual absl::StatusOr Compile( - absl::string_view source, absl::string_view description) const = 0; + absl::string_view source, absl::string_view description, + google::protobuf::Arena* absl_nullable arena) const = 0; absl::StatusOr Compile(absl::string_view source) const { - return Compile(source, ""); + return Compile(source, "", nullptr); + } + + absl::StatusOr Compile( + absl::string_view source, absl::string_view description) const { + return Compile(source, description, nullptr); } // Accessor for the underlying type checker. diff --git a/compiler/compiler_factory.cc b/compiler/compiler_factory.cc index 3e9871706..14586825e 100644 --- a/compiler/compiler_factory.cc +++ b/compiler/compiler_factory.cc @@ -33,6 +33,7 @@ #include "parser/parser.h" #include "parser/parser_interface.h" #include "validator/validator.h" +#include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" namespace cel { @@ -50,13 +51,13 @@ class CompilerImpl : public Compiler { validator_(std::move(validator)) {} absl::StatusOr Compile( - absl::string_view expression, - absl::string_view description) const override { + absl::string_view expression, absl::string_view description, + google::protobuf::Arena* arena) const override { CEL_ASSIGN_OR_RETURN(auto source, cel::NewSource(expression, std::string(description))); CEL_ASSIGN_OR_RETURN(auto ast, parser_->Parse(*source)); CEL_ASSIGN_OR_RETURN(ValidationResult result, - type_checker_->Check(std::move(ast))); + type_checker_->Check(std::move(ast), arena)); result.SetSource(std::move(source)); if (!validator_.validations().empty()) { diff --git a/compiler/compiler_factory_test.cc b/compiler/compiler_factory_test.cc index d217e4cc7..214c23765 100644 --- a/compiler/compiler_factory_test.cc +++ b/compiler/compiler_factory_test.cc @@ -37,6 +37,7 @@ #include "parser/parser_interface.h" #include "testutil/baseline_tests.h" #include "validator/timestamp_literal_validator.h" +#include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" namespace cel { @@ -390,5 +391,27 @@ TEST(CompilerFactoryTest, ToBuilderWorks) { EXPECT_TRUE(result.IsValid()); } +TEST(CompilerFactoryTest, SpecifyArenaKeepsResolvedTypes) { + ASSERT_OK_AND_ASSIGN( + auto builder, + NewCompilerBuilder(cel::internal::GetSharedTestingDescriptorPool())); + + ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); + ASSERT_THAT(builder->AddLibrary(OptionalCompilerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build()); + + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(ValidationResult result, + compiler->Compile("[[1, 2, 3]][?0]", "", &arena)); + ASSERT_OK_AND_ASSIGN(auto ast, result.ReleaseAst()); + auto it = result.GetResolvedTypeMap().find(ast->root_expr().id()); + ASSERT_TRUE(it != result.GetResolvedTypeMap().end()); + EXPECT_TRUE( + it->second.IsOptional() && + it->second.GetOptional().GetParameter().IsList() && + it->second.GetOptional().GetParameter().GetList().GetElement().IsInt()); +} + } // namespace } // namespace cel From 9e73d93f77a159a149edd9a465d092cff03d702a Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 30 Apr 2026 13:04:30 -0700 Subject: [PATCH 02/87] Deprecate option to switch to old accumulator var name in parser. cel::ParserOption::enable_hidden_accumulator_var now has no effect and will be removed in a later change. The standard / extension macros should always use `@result` now. PiperOrigin-RevId: 908333116 --- common/expr.h | 4 +- common/expr_factory.h | 1 - extensions/comprehensions_v2_macros.cc | 72 +++---- parser/macro.cc | 44 ++-- parser/macro_expr_factory.h | 3 +- parser/macro_expr_factory_test.cc | 2 +- parser/options.h | 4 +- parser/parser.cc | 19 +- parser/parser_test.cc | 270 ------------------------- 9 files changed, 72 insertions(+), 347 deletions(-) diff --git a/common/expr.h b/common/expr.h index 9c6f508c6..7305c2c9f 100644 --- a/common/expr.h +++ b/common/expr.h @@ -45,7 +45,9 @@ class MapExprEntry; class MapExpr; class ComprehensionExpr; -inline constexpr absl::string_view kAccumulatorVariableName = "__result__"; +inline constexpr absl::string_view kAccumulatorVariableName = "@result"; +inline constexpr absl::string_view kDeprecatedAccumulatorVariableName = + "__result__"; bool operator==(const Expr& lhs, const Expr& rhs); diff --git a/common/expr_factory.h b/common/expr_factory.h index c8a9b831f..b9769b457 100644 --- a/common/expr_factory.h +++ b/common/expr_factory.h @@ -357,7 +357,6 @@ class ExprFactory { friend class ParserMacroExprFactory; ExprFactory() : accu_var_(kAccumulatorVariableName) {} - explicit ExprFactory(absl::string_view accu_var) : accu_var_(accu_var) {} std::string accu_var_; }; diff --git a/extensions/comprehensions_v2_macros.cc b/extensions/comprehensions_v2_macros.cc index a8de3a103..134fb80ff 100644 --- a/extensions/comprehensions_v2_macros.cc +++ b/extensions/comprehensions_v2_macros.cc @@ -56,15 +56,15 @@ absl::optional ExpandAllMacro2(MacroExprFactory& factory, Expr& target, args[0], "all() second variable must be different from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("all() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("all() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } auto init = factory.NewBoolConst(true); auto condition = @@ -102,15 +102,15 @@ absl::optional ExpandExistsMacro2(MacroExprFactory& factory, Expr& target, args[0], "exists() second variable must be different from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("exists() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("exists() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } auto init = factory.NewBoolConst(false); auto condition = factory.NewCall( @@ -153,15 +153,15 @@ absl::optional ExpandExistsOneMacro2(MacroExprFactory& factory, "existsOne() second variable must be different " "from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("existsOne() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("existsOne() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } auto init = factory.NewIntConst(0); auto condition = factory.NewBoolConst(true); @@ -205,15 +205,15 @@ absl::optional ExpandTransformList3Macro(MacroExprFactory& factory, "transformList() second variable must be " "different from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("transformList() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("transformList() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } std::string iter_var = args[0].ident_expr().name(); std::string iter_var2 = args[1].ident_expr().name(); @@ -254,15 +254,15 @@ absl::optional ExpandTransformList4Macro(MacroExprFactory& factory, "transformList() second variable must be " "different from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("transformList() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("transformList() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } std::string iter_var = args[0].ident_expr().name(); std::string iter_var2 = args[1].ident_expr().name(); @@ -305,15 +305,15 @@ absl::optional ExpandTransformMap3Macro(MacroExprFactory& factory, "transformMap() second variable must be " "different from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("transformMap() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("transformMap() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } std::string iter_var = args[0].ident_expr().name(); std::string iter_var2 = args[1].ident_expr().name(); @@ -353,15 +353,15 @@ absl::optional ExpandTransformMap4Macro(MacroExprFactory& factory, "transformMap() second variable must be " "different from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("transformMap() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("transformMap() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } std::string iter_var = args[0].ident_expr().name(); std::string iter_var2 = args[1].ident_expr().name(); @@ -403,17 +403,17 @@ absl::optional ExpandTransformMapEntry3Macro(MacroExprFactory& factory, "transformMapEntry() second variable must be " "different from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("transformMapEntry() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("transformMapEntry() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } std::string iter_var = args[0].ident_expr().name(); std::string iter_var2 = args[1].ident_expr().name(); @@ -453,17 +453,17 @@ absl::optional ExpandTransformMapEntry4Macro(MacroExprFactory& factory, "transformMapEntry() second variable must be " "different from the first variable"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[0], absl::StrCat("transformMapEntry() first variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } - if (args[1].ident_expr().name() == kAccumulatorVariableName) { + if (args[1].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("transformMapEntry() second variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } std::string iter_var = args[0].ident_expr().name(); std::string iter_var2 = args[1].ident_expr().name(); diff --git a/parser/macro.cc b/parser/macro.cc index eaa1ebd1a..8f8c9e596 100644 --- a/parser/macro.cc +++ b/parser/macro.cc @@ -91,10 +91,10 @@ absl::optional ExpandAllMacro(MacroExprFactory& factory, Expr& target, return factory.ReportErrorAt( args[0], "all() variable name must be a simple identifier"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { - return factory.ReportErrorAt(args[1], - absl::StrCat("all() variable name cannot be ", - kAccumulatorVariableName)); + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("all() variable name cannot be ", + kDeprecatedAccumulatorVariableName)); } auto init = factory.NewBoolConst(true); auto condition = @@ -123,10 +123,10 @@ absl::optional ExpandExistsMacro(MacroExprFactory& factory, Expr& target, return factory.ReportErrorAt( args[0], "exists() variable name must be a simple identifier"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("exists() variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } auto init = factory.NewBoolConst(false); auto condition = factory.NewCall( @@ -157,10 +157,10 @@ absl::optional ExpandExistsOneMacro(MacroExprFactory& factory, return factory.ReportErrorAt( args[0], "exists_one() variable name must be a simple identifier"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("exists_one() variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } auto init = factory.NewIntConst(0); auto condition = factory.NewBoolConst(true); @@ -196,10 +196,10 @@ absl::optional ExpandMap2Macro(MacroExprFactory& factory, Expr& target, return factory.ReportErrorAt( args[0], "map() variable name must be a simple identifier"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { - return factory.ReportErrorAt(args[1], - absl::StrCat("map() variable name cannot be ", - kAccumulatorVariableName)); + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("map() variable name cannot be ", + kDeprecatedAccumulatorVariableName)); } auto init = factory.NewList(); auto condition = factory.NewBoolConst(true); @@ -229,10 +229,10 @@ absl::optional ExpandMap3Macro(MacroExprFactory& factory, Expr& target, return factory.ReportErrorAt( args[0], "map() variable name must be a simple identifier"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { - return factory.ReportErrorAt(args[1], - absl::StrCat("map() variable name cannot be ", - kAccumulatorVariableName)); + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { + return factory.ReportErrorAt( + args[1], absl::StrCat("map() variable name cannot be ", + kDeprecatedAccumulatorVariableName)); } auto init = factory.NewList(); auto condition = factory.NewBoolConst(true); @@ -264,10 +264,10 @@ absl::optional ExpandFilterMacro(MacroExprFactory& factory, Expr& target, return factory.ReportErrorAt( args[0], "filter() variable name must be a simple identifier"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("filter() variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } auto name = args[0].ident_expr().name(); @@ -302,10 +302,10 @@ absl::optional ExpandOptMapMacro(MacroExprFactory& factory, Expr& target, return factory.ReportErrorAt( args[0], "optMap() variable name must be a simple identifier"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("optMap() variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } auto var_name = args[0].ident_expr().name(); @@ -341,10 +341,10 @@ absl::optional ExpandOptFlatMapMacro(MacroExprFactory& factory, return factory.ReportErrorAt( args[0], "optFlatMap() variable name must be a simple identifier"); } - if (args[0].ident_expr().name() == kAccumulatorVariableName) { + if (args[0].ident_expr().name() == kDeprecatedAccumulatorVariableName) { return factory.ReportErrorAt( args[1], absl::StrCat("optFlatMap() variable name cannot be ", - kAccumulatorVariableName)); + kDeprecatedAccumulatorVariableName)); } auto var_name = args[0].ident_expr().name(); diff --git a/parser/macro_expr_factory.h b/parser/macro_expr_factory.h index ffba5e2f2..c66aa4fe0 100644 --- a/parser/macro_expr_factory.h +++ b/parser/macro_expr_factory.h @@ -319,8 +319,7 @@ class MacroExprFactory : protected ExprFactory { friend class ParserMacroExprFactory; friend class TestMacroExprFactory; - explicit MacroExprFactory(absl::string_view accu_var) - : ExprFactory(accu_var) {} + explicit MacroExprFactory() = default; }; } // namespace cel diff --git a/parser/macro_expr_factory_test.cc b/parser/macro_expr_factory_test.cc index 04705eec6..489538be1 100644 --- a/parser/macro_expr_factory_test.cc +++ b/parser/macro_expr_factory_test.cc @@ -27,7 +27,7 @@ namespace cel { class TestMacroExprFactory final : public MacroExprFactory { public: - TestMacroExprFactory() : MacroExprFactory(kAccumulatorVariableName) {} + TestMacroExprFactory() = default; ExprId id() const { return id_; } diff --git a/parser/options.h b/parser/options.h index a41d16104..916a941f0 100644 --- a/parser/options.h +++ b/parser/options.h @@ -51,7 +51,9 @@ struct ParserOptions final { // Disable standard macros (has, all, exists, exists_one, filter, map). bool disable_standard_macros = false; - // Enable hidden accumulator variable '@result' for builtin comprehensions. + // Deprecated: The builtin and extension macros now always use the new + // accumulator variable name. + // This option has no effect. bool enable_hidden_accumulator_var = true; // Enables support for identifier quoting syntax: diff --git a/parser/parser.cc b/parser/parser.cc index d9f74e712..f4ee3a1c5 100644 --- a/parser/parser.cc +++ b/parser/parser.cc @@ -163,9 +163,8 @@ SourceRange SourceRangeFromParserRuleContext( class ParserMacroExprFactory final : public MacroExprFactory { public: - explicit ParserMacroExprFactory(const cel::Source& source, - absl::string_view accu_var) - : MacroExprFactory(accu_var), source_(source) {} + explicit ParserMacroExprFactory(const cel::Source& source) + : source_(source) {} void BeginMacro(SourceRange macro_position) { macro_position_ = macro_position; @@ -607,13 +606,12 @@ class ParserVisitor final : public CelBaseVisitor, public antlr4::BaseErrorListener { public: ParserVisitor(const cel::Source& source, int max_recursion_depth, - absl::string_view accu_var, const cel::MacroRegistry& macro_registry, bool add_macro_calls = false, bool enable_optional_syntax = false, bool enable_quoted_identifiers = false) : source_(source), - factory_(source_, accu_var), + factory_(source_), macro_registry_(macro_registry), recursion_depth_(0), max_recursion_depth_(max_recursion_depth), @@ -1654,14 +1652,9 @@ absl::StatusOr ParseImpl(const cel::Source& source, CommonTokenStream tokens(&lexer); CelParser parser(&tokens); ExprRecursionListener listener(options.max_recursion_depth); - absl::string_view accu_var = cel::kAccumulatorVariableName; - if (options.enable_hidden_accumulator_var) { - accu_var = cel::kHiddenAccumulatorVariableName; - } - ParserVisitor visitor(source, options.max_recursion_depth, accu_var, - registry, options.add_macro_calls, - options.enable_optional_syntax, - options.enable_quoted_identifiers); + ParserVisitor visitor( + source, options.max_recursion_depth, registry, options.add_macro_calls, + options.enable_optional_syntax, options.enable_quoted_identifiers); lexer.removeErrorListeners(); parser.removeErrorListeners(); diff --git a/parser/parser_test.cc b/parser/parser_test.cc index 3659fd8fd..a1a65481d 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -1473,7 +1473,6 @@ class ExpressionTest : public testing::TestWithParam {}; TEST_P(ExpressionTest, Parse) { const TestInfo& test_info = GetParam(); ParserOptions options; - options.enable_hidden_accumulator_var = true; if (!test_info.M.empty()) { options.add_macro_calls = true; } @@ -1628,271 +1627,6 @@ TEST(ExpressionTest, RecursionDepthIgnoresParentheses) { EXPECT_THAT(result, IsOk()); } -const std::vector& UpdatedAccuVarTestCases() { - static const std::vector* kInstance = new std::vector{ - {"[].exists(x, x > 0)", - "__comprehension__(\n" - " // Variable\n" - " x,\n" - " // Target\n" - " []^#1:Expr.CreateList#,\n" - " // Accumulator\n" - " __result__,\n" - " // Init\n" - " false^#7:bool#,\n" - " // LoopCondition\n" - " @not_strictly_false(\n" - " !_(\n" - " __result__^#8:Expr.Ident#\n" - " )^#9:Expr.Call#\n" - " )^#10:Expr.Call#,\n" - " // LoopStep\n" - " _||_(\n" - " __result__^#11:Expr.Ident#,\n" - " _>_(\n" - " x^#4:Expr.Ident#,\n" - " 0^#6:int64#\n" - " )^#5:Expr.Call#\n" - " )^#12:Expr.Call#,\n" - " // Result\n" - " __result__^#13:Expr.Ident#)^#14:Expr.Comprehension#"}, - {"[].exists_one(x, x > 0)", - "__comprehension__(\n" - " // Variable\n" - " x,\n" - " // Target\n" - " []^#1:Expr.CreateList#,\n" - " // Accumulator\n" - " __result__,\n" - " // Init\n" - " 0^#7:int64#,\n" - " // LoopCondition\n" - " true^#8:bool#,\n" - " // LoopStep\n" - " _?_:_(\n" - " _>_(\n" - " x^#4:Expr.Ident#,\n" - " 0^#6:int64#\n" - " )^#5:Expr.Call#,\n" - " _+_(\n" - " __result__^#9:Expr.Ident#,\n" - " 1^#10:int64#\n" - " )^#11:Expr.Call#,\n" - " __result__^#12:Expr.Ident#\n" - " )^#13:Expr.Call#,\n" - " // Result\n" - " _==_(\n" - " __result__^#14:Expr.Ident#,\n" - " 1^#15:int64#\n" - " )^#16:Expr.Call#)^#17:Expr.Comprehension#"}, - {"[].all(x, x > 0)", - "__comprehension__(\n" - " // Variable\n" - " x,\n" - " // Target\n" - " []^#1:Expr.CreateList#,\n" - " // Accumulator\n" - " __result__,\n" - " // Init\n" - " true^#7:bool#,\n" - " // LoopCondition\n" - " @not_strictly_false(\n" - " __result__^#8:Expr.Ident#\n" - " )^#9:Expr.Call#,\n" - " // LoopStep\n" - " _&&_(\n" - " __result__^#10:Expr.Ident#,\n" - " _>_(\n" - " x^#4:Expr.Ident#,\n" - " 0^#6:int64#\n" - " )^#5:Expr.Call#\n" - " )^#11:Expr.Call#,\n" - " // Result\n" - " __result__^#12:Expr.Ident#)^#13:Expr.Comprehension#"}, - {"[].map(x, x + 1)", - "__comprehension__(\n" - " // Variable\n" - " x,\n" - " // Target\n" - " []^#1:Expr.CreateList#,\n" - " // Accumulator\n" - " __result__,\n" - " // Init\n" - " []^#7:Expr.CreateList#,\n" - " // LoopCondition\n" - " true^#8:bool#,\n" - " // LoopStep\n" - " _+_(\n" - " __result__^#9:Expr.Ident#,\n" - " [\n" - " _+_(\n" - " x^#4:Expr.Ident#,\n" - " 1^#6:int64#\n" - " )^#5:Expr.Call#\n" - " ]^#10:Expr.CreateList#\n" - " )^#11:Expr.Call#,\n" - " // Result\n" - " __result__^#12:Expr.Ident#)^#13:Expr.Comprehension#"}, - {"[].map(x, x > 0, x + 1)", - "__comprehension__(\n" - " // Variable\n" - " x,\n" - " // Target\n" - " []^#1:Expr.CreateList#,\n" - " // Accumulator\n" - " __result__,\n" - " // Init\n" - " []^#10:Expr.CreateList#,\n" - " // LoopCondition\n" - " true^#11:bool#,\n" - " // LoopStep\n" - " _?_:_(\n" - " _>_(\n" - " x^#4:Expr.Ident#,\n" - " 0^#6:int64#\n" - " )^#5:Expr.Call#,\n" - " _+_(\n" - " __result__^#12:Expr.Ident#,\n" - " [\n" - " _+_(\n" - " x^#7:Expr.Ident#,\n" - " 1^#9:int64#\n" - " )^#8:Expr.Call#\n" - " ]^#13:Expr.CreateList#\n" - " )^#14:Expr.Call#,\n" - " __result__^#15:Expr.Ident#\n" - " )^#16:Expr.Call#,\n" - " // Result\n" - " __result__^#17:Expr.Ident#)^#18:Expr.Comprehension#"}, - {"[].filter(x, x > 0)", - "__comprehension__(\n" - " // Variable\n" - " x,\n" - " // Target\n" - " []^#1:Expr.CreateList#,\n" - " // Accumulator\n" - " __result__,\n" - " // Init\n" - " []^#7:Expr.CreateList#,\n" - " // LoopCondition\n" - " true^#8:bool#,\n" - " // LoopStep\n" - " _?_:_(\n" - " _>_(\n" - " x^#4:Expr.Ident#,\n" - " 0^#6:int64#\n" - " )^#5:Expr.Call#,\n" - " _+_(\n" - " __result__^#9:Expr.Ident#,\n" - " [\n" - " x^#3:Expr.Ident#\n" - " ]^#10:Expr.CreateList#\n" - " )^#11:Expr.Call#,\n" - " __result__^#12:Expr.Ident#\n" - " )^#13:Expr.Call#,\n" - " // Result\n" - " __result__^#14:Expr.Ident#)^#15:Expr.Comprehension#"}, - // Maintain restriction on '__result__' variable name until the default is - // changed everywhere. - { - "[].map(__result__, true)", - /*.P=*/"", - /*.E=*/ - "ERROR: :1:20: map() variable name cannot be __result__\n" - " | [].map(__result__, true)\n" - " | ...................^", - }, - { - "[].map(__result__, true, false)", - /*.P=*/"", - /*.E=*/ - "ERROR: :1:20: map() variable name cannot be __result__\n" - " | [].map(__result__, true, false)\n" - " | ...................^", - }, - { - "[].filter(__result__, true)", - /*.P=*/"", - /*.E=*/ - "ERROR: :1:23: filter() variable name cannot be __result__\n" - " | [].filter(__result__, true)\n" - " | ......................^", - }, - { - "[].exists(__result__, true)", - /*.P=*/"", - /*.E=*/ - "ERROR: :1:23: exists() variable name cannot be __result__\n" - " | [].exists(__result__, true)\n" - " | ......................^", - }, - { - "[].all(__result__, true)", - /*.P=*/"", - /*.E=*/ - "ERROR: :1:20: all() variable name cannot be __result__\n" - " | [].all(__result__, true)\n" - " | ...................^", - }, - { - "[].exists_one(__result__, true)", - /*.P=*/"", - /*.E=*/ - "ERROR: :1:27: exists_one() variable name cannot be " - "__result__\n" - " | [].exists_one(__result__, true)\n" - " | ..........................^", - }}; - return *kInstance; -} - -class UpdatedAccuVarDisabledTest : public testing::TestWithParam {}; - -TEST_P(UpdatedAccuVarDisabledTest, Parse) { - const TestInfo& test_info = GetParam(); - ParserOptions options; - options.enable_hidden_accumulator_var = false; - if (!test_info.M.empty()) { - options.add_macro_calls = true; - } - - auto result = - EnrichedParse(test_info.I, Macro::AllMacros(), "", options); - if (test_info.E.empty()) { - EXPECT_THAT(result, IsOk()); - } else { - EXPECT_THAT(result, Not(IsOk())); - EXPECT_EQ(test_info.E, result.status().message()); - } - - if (!test_info.P.empty()) { - KindAndIdAdorner kind_and_id_adorner; - ExprPrinter w(kind_and_id_adorner); - std::string adorned_string = w.PrintProto(result->parsed_expr().expr()); - EXPECT_EQ(test_info.P, adorned_string) - << result->parsed_expr().ShortDebugString(); - } - - if (!test_info.L.empty()) { - LocationAdorner location_adorner(result->parsed_expr().source_info()); - ExprPrinter w(location_adorner); - std::string adorned_string = w.PrintProto(result->parsed_expr().expr()); - EXPECT_EQ(test_info.L, adorned_string) - << result->parsed_expr().ShortDebugString(); - } - - if (!test_info.R.empty()) { - EXPECT_EQ(test_info.R, ConvertEnrichedSourceInfoToString( - result->enriched_source_info())); - } - - if (!test_info.M.empty()) { - EXPECT_EQ(test_info.M, ConvertMacroCallsToString( - result.value().parsed_expr().source_info())) - << result->parsed_expr().ShortDebugString(); - } -} - TEST(NewParserBuilderTest, Defaults) { auto builder = cel::NewParserBuilder(); ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); @@ -2058,9 +1792,5 @@ std::string TestName(const testing::TestParamInfo& test_info) { INSTANTIATE_TEST_SUITE_P(CelParserTest, ExpressionTest, testing::ValuesIn(test_cases), TestName); -INSTANTIATE_TEST_SUITE_P(UpdatedAccuVarTest, UpdatedAccuVarDisabledTest, - testing::ValuesIn(UpdatedAccuVarTestCases()), - TestName); - } // namespace } // namespace google::api::expr::parser From 1cf21eec91baa4181b481ff4bf6e25b9b5e9afe9 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 4 May 2026 10:35:24 -0700 Subject: [PATCH 03/87] Optionally track structured error messages in parser. Add a parse overload that reports issues with location to an output parameter where reasonable. This is used make the error handling more consistent when using bundled parse + typecheck. PiperOrigin-RevId: 910112013 --- compiler/BUILD | 2 + compiler/compiler.h | 2 + compiler/compiler_factory.cc | 49 +++++++++++++++----- compiler/compiler_factory_test.cc | 14 ++++++ extensions/math_ext_test.cc | 38 +++++----------- parser/BUILD | 2 + parser/parser.cc | 74 +++++++++++++++++++++---------- parser/parser_interface.h | 50 ++++++++++++++++++++- parser/parser_test.cc | 19 ++++++++ 9 files changed, 186 insertions(+), 64 deletions(-) diff --git a/compiler/BUILD b/compiler/BUILD index 50bc1c9fa..d4a0ab4ac 100644 --- a/compiler/BUILD +++ b/compiler/BUILD @@ -42,10 +42,12 @@ cc_library( hdrs = ["compiler_factory.h"], deps = [ ":compiler", + "//checker:type_check_issue", "//checker:type_checker", "//checker:type_checker_builder", "//checker:type_checker_builder_factory", "//checker:validation_result", + "//common:ast", "//common:source", "//internal:noop_delete", "//internal:status_macros", diff --git a/compiler/compiler.h b/compiler/compiler.h index 6d07e72c2..27237df60 100644 --- a/compiler/compiler.h +++ b/compiler/compiler.h @@ -97,6 +97,8 @@ struct CompilerLibrarySubset { struct CompilerOptions { ParserOptions parser_options; CheckerOptions checker_options; + // If true, parse errors will be adapted to issues where possible. + bool adapt_parser_errors = false; }; // Interface for CEL CompilerBuilder objects. diff --git a/compiler/compiler_factory.cc b/compiler/compiler_factory.cc index 14586825e..ed22c5630 100644 --- a/compiler/compiler_factory.cc +++ b/compiler/compiler_factory.cc @@ -17,16 +17,19 @@ #include #include #include +#include #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "checker/type_check_issue.h" #include "checker/type_checker.h" #include "checker/type_checker_builder.h" #include "checker/type_checker_builder_factory.h" #include "checker/validation_result.h" +#include "common/ast.h" #include "common/source.h" #include "compiler/compiler.h" #include "internal/status_macros.h" @@ -45,19 +48,38 @@ class CompilerImpl : public Compiler { CompilerImpl(std::unique_ptr type_checker, std::unique_ptr parser, // Copy the validator in case builder is reused. - Validator validator) + Validator validator, CompilerOptions options) : type_checker_(std::move(type_checker)), parser_(std::move(parser)), - validator_(std::move(validator)) {} + validator_(std::move(validator)), + options_(options) {} absl::StatusOr Compile( absl::string_view expression, absl::string_view description, google::protobuf::Arena* arena) const override { CEL_ASSIGN_OR_RETURN(auto source, cel::NewSource(expression, std::string(description))); - CEL_ASSIGN_OR_RETURN(auto ast, parser_->Parse(*source)); + std::vector parse_issues; + absl::StatusOr> ast = + parser_->Parse(*source, &parse_issues); + if (!ast.ok()) { + if (!options_.adapt_parser_errors || + ast.status().code() != absl::StatusCode::kInvalidArgument || + parse_issues.empty()) { + return ast.status(); + } + std::vector check_issues; + check_issues.reserve(parse_issues.size()); + for (const auto& issue : parse_issues) { + check_issues.push_back(TypeCheckIssue::CreateError( + issue.location(), std::string(issue.message()))); + } + ValidationResult result(std::move(check_issues)); + result.SetSource(std::move(source)); + return result; + } CEL_ASSIGN_OR_RETURN(ValidationResult result, - type_checker_->Check(std::move(ast), arena)); + type_checker_->Check(*std::move(ast), arena)); result.SetSource(std::move(source)); if (!validator_.validations().empty()) { @@ -76,16 +98,18 @@ class CompilerImpl : public Compiler { std::unique_ptr type_checker_; std::unique_ptr parser_; Validator validator_; + CompilerOptions options_; }; class CompilerBuilderImpl : public CompilerBuilder { public: CompilerBuilderImpl(std::unique_ptr type_checker_builder, std::unique_ptr parser_builder, - Validator validator = Validator()) + Validator validator, CompilerOptions options) : type_checker_builder_(std::move(type_checker_builder)), parser_builder_(std::move(parser_builder)), - validator_(std::move(validator)) {} + validator_(std::move(validator)), + options_(options) {} absl::Status AddLibrary(CompilerLibrary library) override { if (!library.id.empty()) { @@ -146,23 +170,23 @@ class CompilerBuilderImpl : public CompilerBuilder { absl::StatusOr> Build() override { CEL_ASSIGN_OR_RETURN(auto parser, parser_builder_->Build()); CEL_ASSIGN_OR_RETURN(auto type_checker, type_checker_builder_->Build()); - return std::make_unique(std::move(type_checker), - std::move(parser), validator_); + return std::make_unique( + std::move(type_checker), std::move(parser), validator_, options_); } private: std::unique_ptr type_checker_builder_; std::unique_ptr parser_builder_; Validator validator_; + CompilerOptions options_; absl::flat_hash_set library_ids_; absl::flat_hash_set subsets_; }; std::unique_ptr CompilerImpl::ToBuilder() const { - auto builder = std::make_unique( - type_checker_->ToBuilder(), parser_->ToBuilder(), validator_); - return builder; + return std::make_unique( + type_checker_->ToBuilder(), parser_->ToBuilder(), validator_, options_); } } // namespace @@ -179,7 +203,8 @@ absl::StatusOr> NewCompilerBuilder( auto parser_builder = NewParserBuilder(options.parser_options); return std::make_unique(std::move(type_checker_builder), - std::move(parser_builder)); + std::move(parser_builder), + Validator(), options); } } // namespace cel diff --git a/compiler/compiler_factory_test.cc b/compiler/compiler_factory_test.cc index 214c23765..035fd8aa6 100644 --- a/compiler/compiler_factory_test.cc +++ b/compiler/compiler_factory_test.cc @@ -413,5 +413,19 @@ TEST(CompilerFactoryTest, SpecifyArenaKeepsResolvedTypes) { it->second.GetOptional().GetParameter().GetList().GetElement().IsInt()); } +TEST(CompilerFactoryTest, ReturnsIssuesFromParser) { + CompilerOptions opts; + opts.adapt_parser_errors = true; + ASSERT_OK_AND_ASSIGN( + auto builder, NewCompilerBuilder( + cel::internal::GetSharedTestingDescriptorPool(), opts)); + + ASSERT_OK_AND_ASSIGN(auto compiler, builder->Build()); + + ASSERT_OK_AND_ASSIGN(ValidationResult result, compiler->Compile("a +")); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), testing::Not(testing::IsEmpty())); +} + } // namespace } // namespace cel diff --git a/extensions/math_ext_test.cc b/extensions/math_ext_test.cc index 3088e6fa8..72605648f 100644 --- a/extensions/math_ext_test.cc +++ b/extensions/math_ext_test.cc @@ -23,7 +23,6 @@ #include "absl/algorithm/container.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" -#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" @@ -110,19 +109,6 @@ struct MacroTestCase { absl::string_view err = ""; }; -std::string FormatIssues(const cel::ValidationResult& result) { - std::string issues; - for (const auto& issue : result.GetIssues()) { - if (!issues.empty()) { - absl::StrAppend(&issues, "\n", - issue.ToDisplayString(*result.GetSource())); - } else { - issues = issue.ToDisplayString(*result.GetSource()); - } - } - return issues; -} - class TestFunction : public CelFunction { public: explicit TestFunction(absl::string_view name) @@ -352,10 +338,11 @@ TEST_P(MathExtMacroParamsTest, ParserTests) { TEST_P(MathExtMacroParamsTest, ParserAndCheckerTests) { const MacroTestCase& test_case = GetParam(); - - ASSERT_OK_AND_ASSIGN( - auto compiler_builder, - cel::NewCompilerBuilder(internal::GetTestingDescriptorPool())); + CompilerOptions compile_opts; + compile_opts.adapt_parser_errors = true; + ASSERT_OK_AND_ASSIGN(auto compiler_builder, + cel::NewCompilerBuilder( + internal::GetTestingDescriptorPool(), compile_opts)); ASSERT_THAT(compiler_builder->AddLibrary(StandardCheckerLibrary()), IsOk()); ASSERT_THAT(compiler_builder->AddLibrary(MathCompilerLibrary()), IsOk()); @@ -381,16 +368,16 @@ TEST_P(MathExtMacroParamsTest, ParserAndCheckerTests) { ASSERT_OK_AND_ASSIGN(auto compiler, std::move(*compiler_builder).Build()); - auto result = compiler->Compile(test_case.expr, ""); + ASSERT_OK_AND_ASSIGN(auto result, + compiler->Compile(test_case.expr, "")); if (!test_case.err.empty()) { - EXPECT_THAT(result.status(), StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr(test_case.err))); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatError(), HasSubstr(test_case.err)); return; } - ASSERT_THAT(result, IsOk()); - ASSERT_TRUE(result->IsValid()) << FormatIssues(*result); + ASSERT_TRUE(result.IsValid()) << result.FormatError(); RuntimeOptions opts; ASSERT_OK_AND_ASSIGN( @@ -411,9 +398,8 @@ TEST_P(MathExtMacroParamsTest, ParserAndCheckerTests) { IsOk()); ASSERT_OK_AND_ASSIGN(auto runtime, std::move(runtime_builder).Build()); - - ASSERT_OK_AND_ASSIGN(auto program, - runtime->CreateProgram(*result->ReleaseAst())); + ASSERT_OK_AND_ASSIGN(auto ast, result.ReleaseAst()); + ASSERT_OK_AND_ASSIGN(auto program, runtime->CreateProgram(std::move(ast))); google::protobuf::Arena arena; cel::Activation activation; diff --git a/parser/BUILD b/parser/BUILD index 63813bb59..6650d9fe9 100644 --- a/parser/BUILD +++ b/parser/BUILD @@ -244,9 +244,11 @@ cc_library( ":options", "//common:ast", "//common:source", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", ], ) diff --git a/parser/parser.cc b/parser/parser.cc index f4ee3a1c5..709e2fd41 100644 --- a/parser/parser.cc +++ b/parser/parser.cc @@ -112,13 +112,12 @@ struct ParserError { }; std::string DisplayParserError(const cel::Source& source, - const ParserError& error) { - auto location = - source.GetLocation(error.range.begin).value_or(SourceLocation{}); + SourceLocation location, + absl::string_view message) { return absl::StrCat(absl::StrFormat("ERROR: %s:%zu:%zu: %s", source.description(), location.line, // add one to the 0-based column - location.column + 1, error.message), + location.column + 1, message), source.DisplayErrorLocation(location)); } @@ -209,7 +208,7 @@ class ParserMacroExprFactory final : public MacroExprFactory { bool HasErrors() const { return error_count_ != 0; } - std::string ErrorMessage() { + std::vector CollectIssues() { // Errors are collected as they are encountered, not by their location // within the source. To have a more stable error message as implementation // details change, we sort the collected errors by their source location @@ -226,20 +225,23 @@ class ParserMacroExprFactory final : public MacroExprFactory { }); // Build the summary error message using the sorted errors. bool errors_truncated = error_count_ > 100; - std::vector messages; - messages.reserve( + std::vector issues; + issues.reserve( errors_.size() + errors_truncated); // Reserve space for the transform and an // additional element when truncation occurs. - std::transform(errors_.begin(), errors_.end(), std::back_inserter(messages), - [this](const ParserError& error) { - return cel::DisplayParserError(source_, error); - }); + std::transform( + errors_.begin(), errors_.end(), std::back_inserter(issues), + [this](const ParserError& error) { + auto location = + source_.GetLocation(error.range.begin).value_or(SourceLocation{}); + return cel::ParseIssue(location, error.message); + }); if (errors_truncated) { - messages.emplace_back( - absl::StrCat(error_count_ - 100, " more errors were truncated.")); + issues.push_back(cel::ParseIssue( + absl::StrCat(error_count_ - 100, " more errors were truncated."))); } - return absl::StrJoin(messages, "\n"); + return issues; } void AddMacroCall(int64_t macro_id, absl::string_view function, @@ -602,6 +604,15 @@ Expr ExpressionBalancer::BalancedTree(int lo, int hi) { return factory_.NewCall(ops_[mid], function_, std::move(arguments)); } +std::string FormatIssues(const cel::Source& source, + absl::Span issues) { + return absl::StrJoin( + issues, "\n", [&source](std::string* out, const cel::ParseIssue& issue) { + absl::StrAppend(out, cel::DisplayParserError(source, issue.location(), + issue.message())); + }); +} + class ParserVisitor final : public CelBaseVisitor, public antlr4::BaseErrorListener { public: @@ -673,7 +684,7 @@ class ParserVisitor final : public CelBaseVisitor, const std::string& msg, std::exception_ptr e) override; bool HasErrored() const; - std::string ErrorMessage(); + std::vector CollectIssues(); private: template @@ -1434,7 +1445,9 @@ void ParserVisitor::syntaxError(antlr4::Recognizer* recognizer, bool ParserVisitor::HasErrored() const { return factory_.HasErrors(); } -std::string ParserVisitor::ErrorMessage() { return factory_.ErrorMessage(); } +std::vector ParserVisitor::CollectIssues() { + return factory_.CollectIssues(); +} Expr ParserVisitor::GlobalCallOrMacroImpl(int64_t expr_id, absl::string_view function, @@ -1638,9 +1651,10 @@ struct ParseResult { EnrichedSourceInfo enriched_source_info; }; -absl::StatusOr ParseImpl(const cel::Source& source, - const cel::MacroRegistry& registry, - const ParserOptions& options) { +absl::StatusOr ParseImpl( + const cel::Source& source, const cel::MacroRegistry& registry, + const ParserOptions& options, + std::vector* parse_issues = nullptr) { try { CodePointStream input(source.content(), source.description()); if (input.size() > options.expression_size_codepoint_limit) { @@ -1673,13 +1687,23 @@ absl::StatusOr ParseImpl(const cel::Source& source, expr = ExprFromAny(visitor.visit(parser.start())); } catch (const ParseCancellationException& e) { if (visitor.HasErrored()) { - return absl::InvalidArgumentError(visitor.ErrorMessage()); + auto issues = visitor.CollectIssues(); + std::string error_message = FormatIssues(source, issues); + if (parse_issues != nullptr) { + *parse_issues = std::move(issues); + } + return absl::InvalidArgumentError(error_message); } return absl::CancelledError(e.what()); } if (visitor.HasErrored()) { - return absl::InvalidArgumentError(visitor.ErrorMessage()); + auto issues = visitor.CollectIssues(); + std::string error_message = FormatIssues(source, issues); + if (parse_issues != nullptr) { + *parse_issues = std::move(issues); + } + return absl::InvalidArgumentError(error_message); } return { @@ -1706,10 +1730,12 @@ class ParserImpl : public cel::Parser { macro_registry_(std::move(macro_registry)), library_ids_(std::move(library_ids)) {} - absl::StatusOr> Parse( - const cel::Source& source) const override { + absl::StatusOr> ParseImpl( + const cel::Source& source, + std::vector* parse_issues) const override { CEL_ASSIGN_OR_RETURN(auto parse_result, - ParseImpl(source, macro_registry_, options_)); + ::google::api::expr::parser::ParseImpl( + source, macro_registry_, options_, parse_issues)); return std::make_unique(std::move(parse_result.expr), std::move(parse_result.source_info)); } diff --git a/parser/parser_interface.h b/parser/parser_interface.h index 7cc21ff26..ad6e8ca84 100644 --- a/parser/parser_interface.h +++ b/parser/parser_interface.h @@ -16,10 +16,14 @@ #include #include +#include +#include +#include "absl/base/nullability.h" #include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "common/ast.h" #include "common/source.h" #include "parser/macro.h" @@ -73,6 +77,26 @@ class ParserBuilder { virtual absl::StatusOr> Build() = 0; }; +// Information about a parse failure. +class ParseIssue { + public: + explicit ParseIssue(std::string message) : message_(std::move(message)) {} + ParseIssue(SourceLocation location, std::string message) + : location_(location), message_(std::move(message)) {} + + ParseIssue(const ParseIssue& other) = default; + ParseIssue& operator=(const ParseIssue& other) = default; + ParseIssue(ParseIssue&& other) = default; + ParseIssue& operator=(ParseIssue&& other) = default; + + SourceLocation location() const { return location_; } + absl::string_view message() const { return message_; } + + private: + SourceLocation location_; + std::string message_; +}; + // Interface for stateful CEL parser objects for use with a `Compiler` // (bundled parse and type check). This is not needed for most users: // prefer using the free functions in `parser.h` for more flexibility. @@ -81,13 +105,35 @@ class Parser { virtual ~Parser() = default; // Parses the given source into a CEL AST. - virtual absl::StatusOr> Parse( - const cel::Source& source) const = 0; + absl::StatusOr> Parse( + const cel::Source& source) const; + + // Parses the given source into a CEL AST, collecting parse errors in + // `issues`. If `issues` is non-null, it will be cleared and all parse + // issues will be appended to it. + absl::StatusOr> Parse( + const cel::Source& source, std::vector* issues) const; // Returns a builder initialized with the configuration of this parser. virtual std::unique_ptr ToBuilder() const = 0; + + protected: + virtual absl::StatusOr> ParseImpl( + const cel::Source& source, + std::vector* absl_nullable parse_issues) const = 0; }; +inline absl::StatusOr> Parser::Parse( + const cel::Source& source) const { + return ParseImpl(source, nullptr); +} + +inline absl::StatusOr> Parser::Parse( + const cel::Source& source, std::vector* issues) const { + if (issues != nullptr) issues->clear(); + return ParseImpl(source, issues); +} + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_PARSER_PARSER_INTERFACE_H_ diff --git a/parser/parser_test.cc b/parser/parser_test.cc index a1a65481d..587b63a30 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -1782,6 +1782,25 @@ TEST(NewParserBuilderTest, ToBuilderPreservesStdlibAndOptionalFromOptions) { EXPECT_FALSE(ast->IsChecked()); } +TEST(ParserTest, ParseFailurePopulatesIssues) { + auto builder = cel::NewParserBuilder(); + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource("a +", "test.cel")); + std::vector issues; + auto ast_result = parser->Parse(*source, &issues); + EXPECT_THAT(ast_result, Not(IsOk())); + ASSERT_THAT(issues, testing::SizeIs(1)); + EXPECT_THAT(ast_result.status().message(), + HasSubstr("ERROR: test.cel:1:4: Syntax error: mismatched input " + "'' expecting")); + EXPECT_THAT(issues[0].message(), + HasSubstr("Syntax error: mismatched input '' expecting")); + EXPECT_EQ(issues[0].location().line, 1); + // 0-based, but adjusted to 1-based in error message. + EXPECT_EQ(issues[0].location().column, 3); +} + std::string TestName(const testing::TestParamInfo& test_info) { std::string name = absl::StrCat(test_info.index, "-", test_info.param.I); absl::c_replace_if(name, [](char c) { return !absl::ascii_isalnum(c); }, '_'); From 2f06d90f5b593269c2b1f58de3bfd5c8fc2fa895 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 4 May 2026 12:32:22 -0700 Subject: [PATCH 04/87] Add checker support for block. This is needed for re-checking expressions that were produced as a part of policy compilation. PiperOrigin-RevId: 910179322 --- checker/internal/BUILD | 6 + checker/internal/type_checker_impl.cc | 77 ++++++++- checker/internal/type_checker_impl_test.cc | 95 +++++++++++ conformance/BUILD | 5 +- conformance/service.cc | 115 +------------- extensions/BUILD | 5 +- extensions/bindings_ext.cc | 32 +++- extensions/bindings_ext.h | 6 +- testutil/BUILD | 20 +++ testutil/test_macros.cc | 175 +++++++++++++++++++++ testutil/test_macros.h | 33 ++++ 11 files changed, 446 insertions(+), 123 deletions(-) create mode 100644 testutil/test_macros.cc create mode 100644 testutil/test_macros.h diff --git a/checker/internal/BUILD b/checker/internal/BUILD index 1c560cdb9..f4c60f937 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -155,6 +155,7 @@ cc_library( "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -179,6 +180,7 @@ cc_test( "//checker:type_checker_builder", "//checker:validation_result", "//common:ast", + "//common:ast_proto", "//common:container", "//common:decl", "//common:expr", @@ -187,13 +189,17 @@ cc_test( "//internal:status_macros", "//internal:testing", "//internal:testing_descriptor_pool", + "//parser", + "//parser:macro_registry", "//testutil:baseline_tests", + "//testutil:test_macros", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 05601fdbb..2472d7def 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -25,6 +25,7 @@ #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" @@ -59,6 +60,15 @@ namespace cel::checker_internal { namespace { +bool MatchesBlock(const Expr& expr) { + if (!expr.has_call_expr()) { + return false; + } + const auto& call = expr.call_expr(); + return call.function() == "cel.@block" && call.args().size() == 2 && + call.args()[0].has_list_expr(); +} + using AstType = cel::TypeSpec; using Severity = TypeCheckIssue::Severity; @@ -204,13 +214,23 @@ class ResolveVisitor : public AstVisitorBase { arena_(arena), current_scope_(&root_scope_) {} - void PreVisitExpr(const Expr& expr) override { expr_stack_.push_back(&expr); } + void PreVisitExpr(const Expr& expr) override { + expr_stack_.push_back(&expr); + if (expr_stack_.size() == 1 && MatchesBlock(expr)) { + ABSL_DCHECK_EQ(expr.call_expr().args().size(), 2); + ABSL_DCHECK(block_init_list_ == nullptr); + block_init_list_ = &expr.call_expr().args()[0]; + } + } void PostVisitExpr(const Expr& expr) override { if (expr_stack_.empty()) { return; } expr_stack_.pop_back(); + if (expr_stack_.size() == 2 && expr_stack_.back() == block_init_list_) { + HandleBlockIndex(&expr); + } } void PostVisitConst(const Expr& expr, const Constant& constant) override; @@ -389,6 +409,7 @@ class ResolveVisitor : public AstVisitorBase { absl::string_view field_name); void HandleOptSelect(const Expr& expr); + void HandleBlockIndex(const Expr* expr); // Get the assigned type of the given subexpression. Should only be called if // the given subexpression is expected to have already been checked. @@ -421,6 +442,7 @@ class ResolveVisitor : public AstVisitorBase { std::vector expr_stack_; absl::flat_hash_map> maybe_namespaced_functions_; + const Expr* block_init_list_ = nullptr; // Select operations that need to be resolved outside of the traversal. // These are handled separately to disambiguate between namespaces and field // accesses @@ -609,8 +631,15 @@ void ResolveVisitor::PostVisitMap(const Expr& expr, const MapExpr& map) { } void ResolveVisitor::PostVisitList(const Expr& expr, const ListExpr& list) { - // Follows list type inferencing behavior in Go (see map comments above). + if (&expr == block_init_list_) { + // Don't try to coalesce list type here because it can influence the + // resolved type of the list elements. cel.@block is always list and + // the elements are treated independently at runtime. + types_[&expr] = ListType(); + return; + } + // Follows list type inferencing behavior in Go (see map comments above). Type overall_elem_type = inference_context_->InstantiateTypeParams(TypeParamType("E")); auto assignability_context = inference_context_->CreateAssignabilityContext(); @@ -1172,6 +1201,44 @@ void ResolveVisitor::HandleOptSelect(const Expr& expr) { } } +void ResolveVisitor::HandleBlockIndex(const Expr* expr) { + ABSL_DCHECK(block_init_list_ != nullptr); + ABSL_DCHECK(block_init_list_->has_list_expr()); + const auto& elements = block_init_list_->list_expr().elements(); + int index = -1; + for (size_t i = 0; i < elements.size(); ++i) { + if (&elements[i].expr() == expr) { + index = i; + break; + } + } + if (index < 0) { + status_.Update(absl::InternalError( + "could not resolve expression as a cel.@block subexpression")); + return; + } + std::string var_name = absl::StrCat("@index", index); + + // Block is typically manually assembled from logically separate + // expressions so fix the type instead of inferring any remaining free type + // params as for normal subexpressions. + auto type = inference_context_->FinalizeType(GetDeducedType(expr)); + + VariableDecl decl = MakeVariableDecl(var_name, std::move(type)); + + // The C++ runtime requires that the indexes are topologically ordered. + // They just come into scope in order as we walk the AST so we don't need + // to do any additional work to check references to other initializers in + // an init expr. + // + // TODO(uncreated-issue/90): This is slightly inconsistent with the java + // runtime implementation which just requires the references to be acyclic. + auto* scope = + comprehension_vars_.emplace_back(current_scope_->MakeNestedScope()).get(); + scope->InsertVariableIfAbsent(std::move(decl)); + current_scope_ = scope; +} + class ResolveRewriter : public AstRewriterBase { public: explicit ResolveRewriter(const ResolveVisitor& visitor, @@ -1230,15 +1297,15 @@ class ResolveRewriter : public AstRewriterBase { if (auto iter = visitor_.types().find(&expr); iter != visitor_.types().end()) { - auto flattened_type = - FlattenType(inference_context_.FinalizeType(iter->second)); + cel::Type finalized_type = inference_context_.FinalizeType(iter->second); + auto flattened_type = FlattenType(finalized_type); if (!flattened_type.ok()) { status_.Update(flattened_type.status()); return rewritten; } type_map_[expr.id()] = *std::move(flattened_type); - resolved_types_[expr.id()] = iter->second; + resolved_types_[expr.id()] = finalized_type; rewritten = true; } diff --git a/checker/internal/type_checker_impl_test.cc b/checker/internal/type_checker_impl_test.cc index e6cd641d6..893f0689d 100644 --- a/checker/internal/type_checker_impl_test.cc +++ b/checker/internal/type_checker_impl_test.cc @@ -26,6 +26,7 @@ #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -36,6 +37,7 @@ #include "checker/type_checker_builder.h" #include "checker/validation_result.h" #include "common/ast.h" +#include "common/ast_proto.h" #include "common/container.h" #include "common/decl.h" #include "common/expr.h" @@ -45,7 +47,10 @@ #include "internal/status_macros.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" +#include "parser/macro_registry.h" +#include "parser/parser.h" #include "testutil/baseline_tests.h" +#include "testutil/test_macros.h" #include "cel/expr/conformance/proto2/test_all_types.pb.h" #include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" @@ -108,6 +113,17 @@ google::protobuf::Arena* absl_nonnull TestTypeArena() { return &(*kArena); } +absl::StatusOr> MakeTestParsedAstWithMacros( + absl::string_view expression, const cel::MacroRegistry& registry) { + CEL_ASSIGN_OR_RETURN( + auto source, + cel::NewSource(expression, /*description=*/std::string(expression))); + CEL_ASSIGN_OR_RETURN(auto parsed_expr, google::api::expr::parser::Parse( + *source, registry, + {.enable_optional_syntax = true})); + return cel::CreateAstFromParsedExpr(parsed_expr); +} + FunctionDecl MakeIdentFunction() { auto decl = MakeFunctionDecl( "identity", @@ -272,6 +288,12 @@ absl::Status RegisterMinimalBuiltins(google::protobuf::Arena* absl_nonnull arena /*return_type=*/TypeType(arena, TypeParamType("A")), TypeParamType("A")))); + Type kParam(TypeParamType("T")); + CEL_ASSIGN_OR_RETURN( + auto block_decl, + MakeFunctionDecl("cel.@block", MakeOverloadDecl("cel_block_list", kParam, + ListType(), kParam))); + env.InsertFunctionIfAbsent(std::move(not_op)); env.InsertFunctionIfAbsent(std::move(not_strictly_false)); env.InsertFunctionIfAbsent(std::move(add_op)); @@ -289,6 +311,7 @@ absl::Status RegisterMinimalBuiltins(google::protobuf::Arena* absl_nonnull arena env.InsertFunctionIfAbsent(std::move(to_type)); env.InsertFunctionIfAbsent(std::move(to_duration)); env.InsertFunctionIfAbsent(std::move(to_timestamp)); + env.InsertFunctionIfAbsent(std::move(block_decl)); return absl::OkStatus(); } @@ -308,6 +331,78 @@ TEST(TypeCheckerImplTest, SmokeTest) { EXPECT_THAT(result.GetIssues(), IsEmpty()); } +TEST(TypeCheckerImplTest, BlockMacroSupport) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + MacroRegistry registry; + ASSERT_THAT(cel::test::RegisterTestMacros(registry), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN( + auto ast, + MakeTestParsedAstWithMacros( + "cel.block([1, 2], cel.index(0) + cel.index(1))", registry)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), IsEmpty()); + + // Overall type should be int. + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + auto root_id = checked_ast->root_expr().id(); + EXPECT_EQ(checked_ast->type_map().at(root_id).primitive(), + PrimitiveType::kInt64); +} + +TEST(TypeCheckerImplTest, BlockMacroSupportMixedTypes) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + MacroRegistry registry; + ASSERT_THAT(cel::test::RegisterTestMacros(registry), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN( + auto ast, MakeTestParsedAstWithMacros("cel.block([1, 'a'], cel.index(1))", + registry)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), IsEmpty()); + + // cel.index(1) refers to 'a' which is string. + // So overall type should be string. + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + auto root_id = checked_ast->root_expr().id(); + EXPECT_EQ(checked_ast->type_map().at(root_id).primitive(), + PrimitiveType::kString); +} + +TEST(TypeCheckerImplTest, BadIndex) { + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + MacroRegistry registry; + ASSERT_THAT(cel::test::RegisterTestMacros(registry), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN( + auto ast, MakeTestParsedAstWithMacros("cel.block([1, 'a'], cel.index(2))", + registry)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatError(), + HasSubstr("undeclared reference to '@index2' (in container")); +} + TEST(TypeCheckerImplTest, SimpleIdentsResolved) { TypeCheckEnv env(GetSharedTestingDescriptorPool()); diff --git a/conformance/BUILD b/conformance/BUILD index 139739891..9b527cf35 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -69,6 +69,7 @@ cc_library( "//runtime:reference_resolver", "//runtime:runtime_options", "//runtime:standard_runtime_builder_factory", + "//testutil:test_macros", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", @@ -221,7 +222,7 @@ _TESTS_TO_SKIP_LEGACY = _TESTS_TO_SKIP + [ "proto3/set_null/list_value", "proto3/set_null/single_struct", - # cel.@block + # no optional support for legacy types "block_ext/basic/optional_list", "block_ext/basic/optional_map", "block_ext/basic/optional_map_chained", @@ -231,7 +232,7 @@ _TESTS_TO_SKIP_LEGACY = _TESTS_TO_SKIP + [ _TESTS_TO_SKIP_CHECKED = [ # block is a post-check optimization that inserts internal variables. The C++ type checker # needs support for a proper optimizer for this to work. - "block_ext", + # "block_ext", ] _TESTS_TO_SKIP_LEGACY_DASHBOARD = [ diff --git a/conformance/service.cc b/conformance/service.cc index 3edc214e6..463334bb5 100644 --- a/conformance/service.cc +++ b/conformance/service.cc @@ -14,7 +14,6 @@ #include "conformance/service.h" -#include #include #include #include @@ -36,11 +35,8 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" -#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/strings/strip.h" -#include "absl/types/optional.h" -#include "absl/types/span.h" #include "checker/optional.h" #include "checker/standard_library.h" #include "checker/type_checker_builder.h" @@ -48,7 +44,6 @@ #include "common/ast.h" #include "common/ast_proto.h" #include "common/decl_proto_v1alpha1.h" -#include "common/expr.h" #include "common/internal/value_conversion.h" #include "common/source.h" #include "common/value.h" @@ -72,8 +67,6 @@ #include "extensions/select_optimization.h" #include "extensions/strings.h" #include "internal/status_macros.h" -#include "parser/macro.h" -#include "parser/macro_expr_factory.h" #include "parser/macro_registry.h" #include "parser/options.h" #include "parser/parser.h" @@ -85,6 +78,7 @@ #include "runtime/runtime.h" #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" +#include "testutil/test_macros.h" #include "cel/expr/conformance/proto2/test_all_types.pb.h" #include "cel/expr/conformance/proto2/test_all_types_extensions.pb.h" #include "cel/expr/conformance/proto3/test_all_types.pb.h" @@ -106,109 +100,6 @@ namespace google::api::expr::runtime { namespace { -bool IsCelNamespace(const cel::Expr& target) { - return target.has_ident_expr() && target.ident_expr().name() == "cel"; -} - -absl::optional CelBlockMacroExpander(cel::MacroExprFactory& factory, - cel::Expr& target, - absl::Span args) { - if (!IsCelNamespace(target)) { - return absl::nullopt; - } - cel::Expr& bindings_arg = args[0]; - if (!bindings_arg.has_list_expr()) { - return factory.ReportErrorAt( - bindings_arg, "cel.block requires the first arg to be a list literal"); - } - return factory.NewCall("cel.@block", args); -} - -absl::optional CelIndexMacroExpander(cel::MacroExprFactory& factory, - cel::Expr& target, - absl::Span args) { - if (!IsCelNamespace(target)) { - return absl::nullopt; - } - cel::Expr& index_arg = args[0]; - if (!index_arg.has_const_expr() || !index_arg.const_expr().has_int_value()) { - return factory.ReportErrorAt( - index_arg, "cel.index requires a single non-negative int constant arg"); - } - int64_t index = index_arg.const_expr().int_value(); - if (index < 0) { - return factory.ReportErrorAt( - index_arg, "cel.index requires a single non-negative int constant arg"); - } - return factory.NewIdent(absl::StrCat("@index", index)); -} - -absl::optional CelIterVarMacroExpander( - cel::MacroExprFactory& factory, cel::Expr& target, - absl::Span args) { - if (!IsCelNamespace(target)) { - return absl::nullopt; - } - cel::Expr& depth_arg = args[0]; - if (!depth_arg.has_const_expr() || !depth_arg.const_expr().has_int_value() || - depth_arg.const_expr().int_value() < 0) { - return factory.ReportErrorAt( - depth_arg, "cel.iterVar requires two non-negative int constant args"); - } - cel::Expr& unique_arg = args[1]; - if (!unique_arg.has_const_expr() || - !unique_arg.const_expr().has_int_value() || - unique_arg.const_expr().int_value() < 0) { - return factory.ReportErrorAt( - unique_arg, "cel.iterVar requires two non-negative int constant args"); - } - return factory.NewIdent( - absl::StrCat("@it:", depth_arg.const_expr().int_value(), ":", - unique_arg.const_expr().int_value())); -} - -absl::optional CelAccuVarMacroExpander( - cel::MacroExprFactory& factory, cel::Expr& target, - absl::Span args) { - if (!IsCelNamespace(target)) { - return absl::nullopt; - } - cel::Expr& depth_arg = args[0]; - if (!depth_arg.has_const_expr() || !depth_arg.const_expr().has_int_value() || - depth_arg.const_expr().int_value() < 0) { - return factory.ReportErrorAt( - depth_arg, "cel.accuVar requires two non-negative int constant args"); - } - cel::Expr& unique_arg = args[1]; - if (!unique_arg.has_const_expr() || - !unique_arg.const_expr().has_int_value() || - unique_arg.const_expr().int_value() < 0) { - return factory.ReportErrorAt( - unique_arg, "cel.accuVar requires two non-negative int constant args"); - } - return factory.NewIdent( - absl::StrCat("@ac:", depth_arg.const_expr().int_value(), ":", - unique_arg.const_expr().int_value())); -} - -absl::Status RegisterCelBlockMacros(cel::MacroRegistry& registry) { - CEL_ASSIGN_OR_RETURN(auto block_macro, - cel::Macro::Receiver("block", 2, CelBlockMacroExpander)); - CEL_RETURN_IF_ERROR(registry.RegisterMacro(block_macro)); - CEL_ASSIGN_OR_RETURN(auto index_macro, - cel::Macro::Receiver("index", 1, CelIndexMacroExpander)); - CEL_RETURN_IF_ERROR(registry.RegisterMacro(index_macro)); - CEL_ASSIGN_OR_RETURN( - auto iter_var_macro, - cel::Macro::Receiver("iterVar", 2, CelIterVarMacroExpander)); - CEL_RETURN_IF_ERROR(registry.RegisterMacro(iter_var_macro)); - CEL_ASSIGN_OR_RETURN( - auto accu_var_macro, - cel::Macro::Receiver("accuVar", 2, CelAccuVarMacroExpander)); - CEL_RETURN_IF_ERROR(registry.RegisterMacro(accu_var_macro)); - return absl::OkStatus(); -} - google::rpc::Code ToGrpcCode(absl::StatusCode code) { return static_cast(code); } @@ -250,7 +141,7 @@ absl::Status LegacyParse(const conformance::v1alpha1::ParseRequest& request, CEL_RETURN_IF_ERROR(cel::extensions::RegisterBindingsMacros(macros, options)); CEL_RETURN_IF_ERROR(cel::extensions::RegisterMathMacros(macros, options)); CEL_RETURN_IF_ERROR(cel::extensions::RegisterProtoMacros(macros, options)); - CEL_RETURN_IF_ERROR(RegisterCelBlockMacros(macros)); + CEL_RETURN_IF_ERROR(cel::test::RegisterTestMacros(macros)); CEL_ASSIGN_OR_RETURN(auto source, cel::NewSource(request.cel_source(), request.source_location())); CEL_ASSIGN_OR_RETURN(auto parsed_expr, @@ -285,6 +176,8 @@ absl::Status CheckImpl(google::protobuf::Arena* arena, if (!request.no_std_env()) { CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCheckerLibrary())); CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::OptionalCheckerLibrary())); + CEL_RETURN_IF_ERROR( + builder->AddLibrary(cel::extensions::BindingsCheckerLibrary())); CEL_RETURN_IF_ERROR( builder->AddLibrary(cel::extensions::StringsCheckerLibrary())); CEL_RETURN_IF_ERROR( diff --git a/extensions/BUILD b/extensions/BUILD index ff37e2c3f..05104a4a5 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -215,7 +215,10 @@ cc_library( srcs = ["bindings_ext.cc"], hdrs = ["bindings_ext.h"], deps = [ - "//common:ast", + "//checker:type_checker_builder", + "//common:decl", + "//common:expr", + "//common:type", "//compiler", "//internal:status_macros", "//parser:macro", diff --git a/extensions/bindings_ext.cc b/extensions/bindings_ext.cc index f097709ca..c59f724bd 100644 --- a/extensions/bindings_ext.cc +++ b/extensions/bindings_ext.cc @@ -21,7 +21,10 @@ #include "absl/status/statusor.h" #include "absl/types/optional.h" #include "absl/types/span.h" -#include "common/ast.h" +#include "checker/type_checker_builder.h" +#include "common/decl.h" +#include "common/expr.h" +#include "common/type.h" #include "compiler/compiler.h" #include "internal/status_macros.h" #include "parser/macro.h" @@ -34,6 +37,8 @@ namespace { static constexpr char kCelNamespace[] = "cel"; static constexpr char kBind[] = "bind"; +static constexpr char kBlock[] = "cel.@block"; +static constexpr char kBlockOverloadId[] = "cel_block_list"; static constexpr char kUnusedIterVar[] = "#unused"; bool IsTargetNamespace(const Expr& target) { @@ -47,6 +52,19 @@ inline absl::Status ConfigureParser(ParserBuilder& parser_builder) { return absl::OkStatus(); } +absl::Status ConfigureChecker(int version, + TypeCheckerBuilder& type_checker_builder) { + if (version < 1) { + return absl::OkStatus(); + } + static Type kParam(TypeParamType("T")); + CEL_ASSIGN_OR_RETURN( + auto decl, + MakeFunctionDecl(kBlock, MakeOverloadDecl(kBlockOverloadId, kParam, + ListType(), kParam))); + return type_checker_builder.AddFunction(std::move(decl)); +} + } // namespace std::vector bindings_macros() { @@ -70,8 +88,16 @@ std::vector bindings_macros() { return {*cel_bind}; } -CompilerLibrary BindingsCompilerLibrary() { - return CompilerLibrary("cel.lib.ext.bindings", &ConfigureParser); +CompilerLibrary BindingsCompilerLibrary(int version) { + return CompilerLibrary( + "cel.lib.ext.bindings", &ConfigureParser, + [version](auto& b) { return ConfigureChecker(version, b); }); +} + +CheckerLibrary BindingsCheckerLibrary(int version) { + return CheckerLibrary{"cel.lib.ext.bindings", [version](auto& b) { + return ConfigureChecker(version, b); + }}; } } // namespace cel::extensions diff --git a/extensions/bindings_ext.h b/extensions/bindings_ext.h index a338b24f6..40b83a37f 100644 --- a/extensions/bindings_ext.h +++ b/extensions/bindings_ext.h @@ -25,6 +25,7 @@ namespace cel::extensions { +constexpr int kBindingsVersionLatest = 1; // bindings_macros() returns a macro for cel.bind() which can be used to support // local variable bindings within expressions. std::vector bindings_macros(); @@ -35,7 +36,10 @@ inline absl::Status RegisterBindingsMacros(MacroRegistry& registry, } // Declarations for the bindings extension library. -CompilerLibrary BindingsCompilerLibrary(); +CompilerLibrary BindingsCompilerLibrary(int version = kBindingsVersionLatest); + +// Declarations for the bindings extension library. +CheckerLibrary BindingsCheckerLibrary(int version = kBindingsVersionLatest); } // namespace cel::extensions diff --git a/testutil/BUILD b/testutil/BUILD index 292696033..782c95ca6 100644 --- a/testutil/BUILD +++ b/testutil/BUILD @@ -62,6 +62,26 @@ cc_library( deps = ["//internal:proto_matchers"], ) +cc_library( + name = "test_macros", + testonly = True, + srcs = ["test_macros.cc"], + hdrs = ["test_macros.h"], + deps = [ + "//common:expr", + "//internal:status_macros", + "//parser:macro", + "//parser:macro_expr_factory", + "//parser:macro_registry", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "baseline_tests", testonly = True, diff --git a/testutil/test_macros.cc b/testutil/test_macros.cc new file mode 100644 index 000000000..158135762 --- /dev/null +++ b/testutil/test_macros.cc @@ -0,0 +1,175 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "testutil/test_macros.h" + +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "common/expr.h" +#include "internal/status_macros.h" +#include "parser/macro.h" +#include "parser/macro_expr_factory.h" +#include "parser/macro_registry.h" + +namespace cel::test { + +namespace { + +bool IsCelNamespace(const Expr& target) { + return target.has_ident_expr() && target.ident_expr().name() == "cel"; +} + +absl::optional CelBlockMacroExpander(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (!IsCelNamespace(target)) { + return absl::nullopt; + } + Expr& bindings_arg = args[0]; + if (!bindings_arg.has_list_expr()) { + return factory.ReportErrorAt( + bindings_arg, "cel.block requires the first arg to be a list literal"); + } + return factory.NewCall("cel.@block", args); +} + +absl::optional CelIndexMacroExpander(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (!IsCelNamespace(target)) { + return absl::nullopt; + } + Expr& index_arg = args[0]; + if (!index_arg.has_const_expr() || !index_arg.const_expr().has_int_value()) { + return factory.ReportErrorAt( + index_arg, "cel.index requires a single non-negative int constant arg"); + } + int64_t index = index_arg.const_expr().int_value(); + if (index < 0) { + return factory.ReportErrorAt( + index_arg, "cel.index requires a single non-negative int constant arg"); + } + return factory.NewIdent(absl::StrCat("@index", index)); +} + +absl::optional CelIterVarMacroExpander(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (!IsCelNamespace(target)) { + return absl::nullopt; + } + Expr& depth_arg = args[0]; + if (!depth_arg.has_const_expr() || !depth_arg.const_expr().has_int_value() || + depth_arg.const_expr().int_value() < 0) { + return factory.ReportErrorAt( + depth_arg, "cel.iterVar requires two non-negative int constant args"); + } + Expr& unique_arg = args[1]; + if (!unique_arg.has_const_expr() || + !unique_arg.const_expr().has_int_value() || + unique_arg.const_expr().int_value() < 0) { + return factory.ReportErrorAt( + unique_arg, "cel.iterVar requires two non-negative int constant args"); + } + return factory.NewIdent( + absl::StrCat("@it:", depth_arg.const_expr().int_value(), ":", + unique_arg.const_expr().int_value())); +} + +absl::optional CelAccuVarMacroExpander(MacroExprFactory& factory, + Expr& target, + absl::Span args) { + if (!IsCelNamespace(target)) { + return absl::nullopt; + } + Expr& depth_arg = args[0]; + if (!depth_arg.has_const_expr() || !depth_arg.const_expr().has_int_value() || + depth_arg.const_expr().int_value() < 0) { + return factory.ReportErrorAt( + depth_arg, "cel.accuVar requires two non-negative int constant args"); + } + Expr& unique_arg = args[1]; + if (!unique_arg.has_const_expr() || + !unique_arg.const_expr().has_int_value() || + unique_arg.const_expr().int_value() < 0) { + return factory.ReportErrorAt( + unique_arg, "cel.accuVar requires two non-negative int constant args"); + } + return factory.NewIdent( + absl::StrCat("@ac:", depth_arg.const_expr().int_value(), ":", + unique_arg.const_expr().int_value())); +} + +Macro MakeCelBlockMacro() { + auto macro_or_status = Macro::Receiver("block", 2, CelBlockMacroExpander); + ABSL_CHECK_OK(macro_or_status); // Crash OK + return std::move(*macro_or_status); +} + +Macro MakeCelIndexMacro() { + auto macro_or_status = Macro::Receiver("index", 1, CelIndexMacroExpander); + ABSL_CHECK_OK(macro_or_status); // Crash OK + return std::move(*macro_or_status); +} + +Macro MakeCelIterVarMacro() { + auto macro_or_status = Macro::Receiver("iterVar", 2, CelIterVarMacroExpander); + ABSL_CHECK_OK(macro_or_status); // Crash OK + return std::move(*macro_or_status); +} + +Macro MakeCelAccuVarMacro() { + auto macro_or_status = Macro::Receiver("accuVar", 2, CelAccuVarMacroExpander); + ABSL_CHECK_OK(macro_or_status); // Crash OK + return std::move(*macro_or_status); +} + +} // namespace + +const Macro& CelBlockMacro() { + static const absl::NoDestructor macro(MakeCelBlockMacro()); + return *macro; +} + +const Macro& CelIndexMacro() { + static const absl::NoDestructor macro(MakeCelIndexMacro()); + return *macro; +} + +const Macro& CelIterVarMacro() { + static const absl::NoDestructor macro(MakeCelIterVarMacro()); + return *macro; +} + +const Macro& CelAccuVarMacro() { + static const absl::NoDestructor macro(MakeCelAccuVarMacro()); + return *macro; +} + +absl::Status RegisterTestMacros(MacroRegistry& registry) { + CEL_RETURN_IF_ERROR(registry.RegisterMacro(CelBlockMacro())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(CelIndexMacro())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(CelIterVarMacro())); + CEL_RETURN_IF_ERROR(registry.RegisterMacro(CelAccuVarMacro())); + return absl::OkStatus(); +} + +} // namespace cel::test diff --git a/testutil/test_macros.h b/testutil/test_macros.h new file mode 100644 index 000000000..cad897999 --- /dev/null +++ b/testutil/test_macros.h @@ -0,0 +1,33 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_TESTUTIL_TEST_MACROS_H_ +#define THIRD_PARTY_CEL_CPP_TESTUTIL_TEST_MACROS_H_ + +#include "absl/status/status.h" +#include "parser/macro.h" +#include "parser/macro_registry.h" + +namespace cel::test { + +const Macro& CelBlockMacro(); +const Macro& CelIndexMacro(); +const Macro& CelIterVarMacro(); +const Macro& CelAccuVarMacro(); + +absl::Status RegisterTestMacros(MacroRegistry& registry); + +} // namespace cel::test + +#endif // THIRD_PARTY_CEL_CPP_TESTUTIL_TEST_MACROS_H_ From ddcece1479ed78ccde1594e47d94eeb841de115f Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Wed, 6 May 2026 13:56:21 -0700 Subject: [PATCH 05/87] Refactor optional dispatch tables. PiperOrigin-RevId: 911533283 --- common/values/optional_value.cc | 255 ++++++++++++++++---------------- 1 file changed, 124 insertions(+), 131 deletions(-) diff --git a/common/values/optional_value.cc b/common/values/optional_value.cc index ad0a65efb..688cf8fb0 100644 --- a/common/values/optional_value.cc +++ b/common/values/optional_value.cc @@ -122,200 +122,185 @@ absl::Status OptionalValueEqual( return absl::OkStatus(); } +google::protobuf::Arena* absl_nullable OptionalValueGetArenaNull( + const OpaqueValueDispatcher* absl_nonnull, OpaqueValueContent) { + return nullptr; +} + +OpaqueValue OptionalValueClone( + const OpaqueValueDispatcher* absl_nonnull dispatcher, + OpaqueValueContent content, google::protobuf::Arena* absl_nonnull arena) { + return common_internal::MakeOptionalValue(dispatcher, content); +} + +bool OptionalValueHasNoValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content) { + return false; +} + +void EmptyOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = + ErrorValue(absl::FailedPreconditionError("optional.none() dereference")); +} + +void NullOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = NullValue(); +} + +void BoolOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = BoolValue(content.To()); +} + +void IntOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = IntValue(content.To()); +} + +void UintOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = UintValue(content.To()); +} + +void DoubleOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = DoubleValue(content.To()); +} + +void DurationOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = UnsafeDurationValue(content.To()); +} + +void TimestampOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = UnsafeTimestampValue(content.To()); +} + ABSL_CONST_INIT const OptionalValueDispatcher empty_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent) - -> google::protobuf::Arena* absl_nullable { return nullptr; }, + .get_arena = &OptionalValueGetArenaNull, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - return common_internal::MakeOptionalValue(dispatcher, content); - }, - }, - [](const OptionalValueDispatcher* absl_nonnull dispatcher, - CustomValueContent content) -> bool { return false; }, - [](const OptionalValueDispatcher* absl_nonnull dispatcher, - CustomValueContent content, - cel::Value* absl_nonnull result) -> void { - *result = ErrorValue( - absl::FailedPreconditionError("optional.none() dereference")); + .clone = &OptionalValueClone, }, + &OptionalValueHasNoValue, + &EmptyOptionalValueValue, }; ABSL_CONST_INIT const OptionalValueDispatcher null_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent) -> google::protobuf::Arena* absl_nullable { - return nullptr; - }, + .get_arena = &OptionalValueGetArenaNull, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - return common_internal::MakeOptionalValue(dispatcher, content); - }, + .clone = &OptionalValueClone, }, &OptionalValueHasValue, - [](const OptionalValueDispatcher* absl_nonnull, CustomValueContent, - cel::Value* absl_nonnull result) -> void { *result = NullValue(); }, + &NullOptionalValueValue, }; ABSL_CONST_INIT const OptionalValueDispatcher bool_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent) -> google::protobuf::Arena* absl_nullable { - return nullptr; - }, + .get_arena = &OptionalValueGetArenaNull, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - return common_internal::MakeOptionalValue(dispatcher, content); - }, + .clone = &OptionalValueClone, }, &OptionalValueHasValue, - [](const OptionalValueDispatcher* absl_nonnull, CustomValueContent content, - cel::Value* absl_nonnull result) -> void { - *result = BoolValue(content.To()); - }, + &BoolOptionalValueValue, }; ABSL_CONST_INIT const OptionalValueDispatcher int_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent) -> google::protobuf::Arena* absl_nullable { - return nullptr; - }, + .get_arena = &OptionalValueGetArenaNull, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - return common_internal::MakeOptionalValue(dispatcher, content); - }, + .clone = &OptionalValueClone, }, &OptionalValueHasValue, - [](const OptionalValueDispatcher* absl_nonnull, CustomValueContent content, - cel::Value* absl_nonnull result) -> void { - *result = IntValue(content.To()); - }, + &IntOptionalValueValue, }; ABSL_CONST_INIT const OptionalValueDispatcher uint_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent) -> google::protobuf::Arena* absl_nullable { - return nullptr; - }, + .get_arena = &OptionalValueGetArenaNull, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - return common_internal::MakeOptionalValue(dispatcher, content); - }, + .clone = &OptionalValueClone, }, &OptionalValueHasValue, - [](const OptionalValueDispatcher* absl_nonnull, CustomValueContent content, - cel::Value* absl_nonnull result) -> void { - *result = UintValue(content.To()); - }, + &UintOptionalValueValue, }; ABSL_CONST_INIT const OptionalValueDispatcher double_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent) - -> google::protobuf::Arena* absl_nullable { return nullptr; }, + .get_arena = &OptionalValueGetArenaNull, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - return common_internal::MakeOptionalValue(dispatcher, content); - }, + .clone = &OptionalValueClone, }, &OptionalValueHasValue, - [](const OptionalValueDispatcher* absl_nonnull, - CustomValueContent content, - cel::Value* absl_nonnull result) -> void { - *result = DoubleValue(content.To()); - }, + &DoubleOptionalValueValue, }; ABSL_CONST_INIT const OptionalValueDispatcher duration_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent) - -> google::protobuf::Arena* absl_nullable { return nullptr; }, + .get_arena = &OptionalValueGetArenaNull, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - return common_internal::MakeOptionalValue(dispatcher, content); - }, + .clone = &OptionalValueClone, }, &OptionalValueHasValue, - [](const OptionalValueDispatcher* absl_nonnull, - CustomValueContent content, - cel::Value* absl_nonnull result) -> void { - *result = UnsafeDurationValue(content.To()); - }, + &DurationOptionalValueValue, }; ABSL_CONST_INIT const OptionalValueDispatcher timestamp_optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent) - -> google::protobuf::Arena* absl_nullable { return nullptr; }, + .get_arena = &OptionalValueGetArenaNull, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - return common_internal::MakeOptionalValue(dispatcher, content); - }, + .clone = &OptionalValueClone, }, &OptionalValueHasValue, - [](const OptionalValueDispatcher* absl_nonnull, - CustomValueContent content, - cel::Value* absl_nonnull result) -> void { - *result = UnsafeTimestampValue(content.To()); - }, + &TimestampOptionalValueValue, }; struct OptionalValueContent { @@ -323,43 +308,51 @@ struct OptionalValueContent { google::protobuf::Arena* absl_nonnull arena; }; +google::protobuf::Arena* absl_nullable GenericOptionalValueGetArena( + const OpaqueValueDispatcher* absl_nonnull, OpaqueValueContent content) { + return content.To().arena; +} + +OpaqueValue GenericOptionalValueClone( + const OpaqueValueDispatcher* absl_nonnull dispatcher, + OpaqueValueContent content, google::protobuf::Arena* absl_nonnull arena); + +void GenericOptionalValueValue(const OptionalValueDispatcher* absl_nonnull, + CustomValueContent content, + cel::Value* absl_nonnull result) { + *result = *content.To().value; +} + ABSL_CONST_INIT const OptionalValueDispatcher optional_value_dispatcher = { { .get_type_id = &OptionalValueGetTypeId, - .get_arena = - [](const OpaqueValueDispatcher* absl_nonnull, - OpaqueValueContent content) -> google::protobuf::Arena* absl_nullable { - return content.To().arena; - }, + .get_arena = &GenericOptionalValueGetArena, .get_type_name = &OptionalValueGetTypeName, .debug_string = &OptionalValueDebugString, .get_runtime_type = &OptionalValueGetRuntimeType, .equal = &OptionalValueEqual, - .clone = [](const OpaqueValueDispatcher* absl_nonnull dispatcher, - OpaqueValueContent content, - google::protobuf::Arena* absl_nonnull arena) -> OpaqueValue { - ABSL_DCHECK(arena != nullptr); - - cel::Value* absl_nonnull result = ::new ( - arena->AllocateAligned(sizeof(cel::Value), alignof(cel::Value))) - cel::Value( - content.To().value->Clone(arena)); - if (!ArenaTraits<>::trivially_destructible(result)) { - arena->OwnDestructor(result); - } - return common_internal::MakeOptionalValue( - &optional_value_dispatcher, - OpaqueValueContent::From( - OptionalValueContent{.value = result, .arena = arena})); - }, + .clone = &GenericOptionalValueClone, }, &OptionalValueHasValue, - [](const OptionalValueDispatcher* absl_nonnull, CustomValueContent content, - cel::Value* absl_nonnull result) -> void { - *result = *content.To().value; - }, + &GenericOptionalValueValue, }; +OpaqueValue GenericOptionalValueClone( + const OpaqueValueDispatcher* absl_nonnull dispatcher, + OpaqueValueContent content, google::protobuf::Arena* absl_nonnull arena) { + ABSL_DCHECK(arena != nullptr); + + cel::Value* absl_nonnull result = + ::new (arena->AllocateAligned(sizeof(cel::Value), alignof(cel::Value))) + cel::Value(content.To().value->Clone(arena)); + if (!ArenaTraits<>::trivially_destructible(result)) { + arena->OwnDestructor(result); + } + return common_internal::MakeOptionalValue( + &optional_value_dispatcher, OpaqueValueContent::From(OptionalValueContent{ + .value = result, .arena = arena})); +} + } // namespace OptionalValue OptionalValue::Of(cel::Value value, From 5806d30ba86ca40d8ab111e59fa78983afe5319c Mon Sep 17 00:00:00 2001 From: Justin King Date: Fri, 8 May 2026 13:03:46 -0700 Subject: [PATCH 06/87] Update conformance test skip list PiperOrigin-RevId: 912656156 --- conformance/BUILD | 20 ++++++++++++++++++++ conformance/run.bzl | 6 +++--- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/conformance/BUILD b/conformance/BUILD index 9b527cf35..726a11b0b 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -201,6 +201,26 @@ _TESTS_TO_SKIP = [ # precision to preserve value. Not available on older compilers where we just use absl::Format. # We should probably update the spec to allow different formats that parse to the same value. "conversions/string/double_hard", + + # Recent changes + "proto2/set_null/repeated_field_timestamp_null_pruned", + "proto2/set_null/repeated_field_duration_null_pruned", + "proto2/set_null/repeated_field_wrapper_null_pruned", + "proto2/set_null/map_timestamp_null_pruned", + "proto2/set_null/map_duration_null_pruned", + "proto2/set_null/map_wrapper_null_pruned", + "proto3/set_null/repeated_field_timestamp_null_pruned", + "proto3/set_null/repeated_field_duration_null_pruned", + "proto3/set_null/repeated_field_wrapper_null_pruned", + "proto3/set_null/map_timestamp_null_pruned", + "proto3/set_null/map_duration_null_pruned", + "proto3/set_null/map_wrapper_null_pruned", + "string_ext/format/default precision for fixed-point clause with int", + "string_ext/format/default precision for fixed-point clause with uint", + "string_ext/format/default precision for scientific notation with int", + "string_ext/format/default precision for scientific notation with uint", + "namespace/namespace_shadowing/basic", + "namespace/namespace_shadowing/comprehension_shadowing_namespaced_selector_disambiguation", ] _TESTS_TO_SKIP_MODERN = _TESTS_TO_SKIP diff --git a/conformance/run.bzl b/conformance/run.bzl index 4fcf325c6..d53fd539c 100644 --- a/conformance/run.bzl +++ b/conformance/run.bzl @@ -70,7 +70,7 @@ def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, args.append("--skip_check") else: args.append("--noskip_check") - args.append("--skip_tests={}".format(",".join(_expand_tests_to_skip(skip_tests)))) + args.append("--skip_tests=\"{}\"".format(",".join(_expand_tests_to_skip(skip_tests)))) if dashboard: args.append("--dashboard") return args @@ -80,8 +80,8 @@ def _conformance_test(name, data, modern, optimize, recursive, select_opt, skip_ name = _conformance_test_name(name, optimize, recursive), args = _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, skip_tests, dashboard) + ["$(location " + test + ")" for test in data] + select( { - "@platforms//os:windows": ["--skip_tests={}".format(",".join(skip_tests + _TESTS_TO_SKIP_WINDOWS))], - "//conditions:default": ["--skip_tests={}".format(",".join(skip_tests))], + "@platforms//os:windows": ["--skip_tests=\"{}\"".format(",".join(skip_tests + _TESTS_TO_SKIP_WINDOWS))], + "//conditions:default": ["--skip_tests=\"{}\"".format(",".join(skip_tests))], }, ), data = data, From cb9dc8a2e71e503655b1992bdba3debc7fda12a7 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Fri, 8 May 2026 15:45:31 -0700 Subject: [PATCH 07/87] Fix command line argument splitting issue for conformance tests. PiperOrigin-RevId: 912731724 --- conformance/run.bzl | 10 +++++----- conformance/run.cc | 8 ++++++++ 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/conformance/run.bzl b/conformance/run.bzl index d53fd539c..15850b0aa 100644 --- a/conformance/run.bzl +++ b/conformance/run.bzl @@ -56,7 +56,7 @@ def _conformance_test_name(name, optimize, recursive): ], ) -def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, skip_tests, dashboard): +def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard): args = [] if modern: args.append("--modern") @@ -70,7 +70,6 @@ def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, args.append("--skip_check") else: args.append("--noskip_check") - args.append("--skip_tests=\"{}\"".format(",".join(_expand_tests_to_skip(skip_tests)))) if dashboard: args.append("--dashboard") return args @@ -78,10 +77,11 @@ def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, def _conformance_test(name, data, modern, optimize, recursive, select_opt, skip_check, skip_tests, tags, dashboard): cc_test( name = _conformance_test_name(name, optimize, recursive), - args = _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, skip_tests, dashboard) + ["$(location " + test + ")" for test in data] + select( + args = _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard) + ["$(location " + test + ")" for test in data], + env = select( { - "@platforms//os:windows": ["--skip_tests=\"{}\"".format(",".join(skip_tests + _TESTS_TO_SKIP_WINDOWS))], - "//conditions:default": ["--skip_tests=\"{}\"".format(",".join(skip_tests))], + "@platforms//os:windows": {"CEL_SKIP_TESTS": ",".join(skip_tests + _TESTS_TO_SKIP_WINDOWS)}, + "//conditions:default": {"CEL_SKIP_TESTS": ",".join(skip_tests)}, }, ), data = data, diff --git a/conformance/run.cc b/conformance/run.cc index d5a919d76..80164d9a4 100644 --- a/conformance/run.cc +++ b/conformance/run.cc @@ -42,6 +42,7 @@ #include "absl/strings/cord.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/strings/strip.h" #include "absl/types/span.h" @@ -273,6 +274,13 @@ int main(int argc, char** argv) { { auto service = NewConformanceServiceFromFlags(); auto tests_to_skip = absl::GetFlag(FLAGS_skip_tests); + if (const char* env_skip = std::getenv("CEL_SKIP_TESTS"); + env_skip != nullptr) { + for (absl::string_view test : + absl::StrSplit(env_skip, ',', absl::SkipEmpty())) { + tests_to_skip.push_back(std::string(test)); + } + } for (int argi = 1; argi < argc; argi++) { ABSL_CHECK_OK(RegisterTestsFromFile(service, tests_to_skip, absl::string_view(argv[argi]))); From cf31ddf620b9d809014418e82428863b54190cbb Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Mon, 11 May 2026 10:50:27 -0700 Subject: [PATCH 08/87] Introduce `Bind` expression factory helper PiperOrigin-RevId: 913778503 --- common/expr_factory.h | 23 ++++++++++++++ parser/macro_expr_factory_test.cc | 51 +++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+) diff --git a/common/expr_factory.h b/common/expr_factory.h index b9769b457..773217ad9 100644 --- a/common/expr_factory.h +++ b/common/expr_factory.h @@ -352,6 +352,29 @@ class ExprFactory { return expr; } + template ::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + Expr NewBind(NextIdFunc next_id, BindVar bind_var, BindExpr bind_expr, + RestExpr rest_expr) { + Expr expr; + expr.set_id(next_id()); + auto& comprehension_expr = expr.mutable_comprehension_expr(); + comprehension_expr.set_iter_var("#unused"); + comprehension_expr.set_iter_range( + NewList(next_id(), std::vector{})); + comprehension_expr.set_accu_var(bind_var); + comprehension_expr.set_accu_init(std::move(bind_expr)); + comprehension_expr.set_loop_condition(NewBoolConst(next_id(), false)); + comprehension_expr.set_loop_step(NewIdent(next_id(), bind_var)); + comprehension_expr.set_result(std::move(rest_expr)); + return expr; + } + private: friend class MacroExprFactory; friend class ParserMacroExprFactory; diff --git a/parser/macro_expr_factory_test.cc b/parser/macro_expr_factory_test.cc index 489538be1..b95cbe16f 100644 --- a/parser/macro_expr_factory_test.cc +++ b/parser/macro_expr_factory_test.cc @@ -15,6 +15,7 @@ #include "parser/macro_expr_factory.h" #include +#include #include #include "absl/strings/string_view.h" @@ -39,6 +40,7 @@ class TestMacroExprFactory final : public MacroExprFactory { return NewUnspecified(NextId()); } + using MacroExprFactory::NewBind; using MacroExprFactory::NewBoolConst; using MacroExprFactory::NewCall; using MacroExprFactory::NewComprehension; @@ -69,6 +71,8 @@ class TestMacroExprFactory final : public MacroExprFactory { namespace { +using ::testing::IsEmpty; + TEST(MacroExprFactory, CopyUnspecified) { TestMacroExprFactory factory; EXPECT_EQ(factory.Copy(factory.NewUnspecified()), factory.NewUnspecified(2)); @@ -147,5 +151,52 @@ TEST(MacroExprFactory, CopyComprehension) { factory.NewIdent(11, "foo"), factory.NewIdent(12, "bar"))); } +TEST(MacroExprFactory, NewBind) { + TestMacroExprFactory factory; + Expr bind_expr = factory.NewIdent(10, "x"); + Expr rest_expr = factory.NewIdent(20, "y"); + + auto next_id = [id = 100]() mutable { return id++; }; + + Expr expr = + factory.NewBind(next_id, "a", std::move(bind_expr), std::move(rest_expr)); + + EXPECT_EQ(expr.id(), 100); + ASSERT_TRUE(expr.has_comprehension_expr()); + + const auto& comp = expr.comprehension_expr(); + EXPECT_EQ(comp.iter_var(), "#unused"); + + ASSERT_TRUE(comp.has_iter_range()); + EXPECT_EQ(comp.iter_range().id(), 101); + EXPECT_EQ(comp.iter_range().kind_case(), ExprKindCase::kListExpr); + EXPECT_THAT(comp.iter_range().list_expr().elements(), IsEmpty()); + + EXPECT_EQ(comp.accu_var(), "a"); + + ASSERT_TRUE(comp.has_accu_init()); + Expr expected_bind_expr; + expected_bind_expr.set_id(10); + expected_bind_expr.mutable_ident_expr().set_name("x"); + EXPECT_EQ(comp.accu_init(), expected_bind_expr); + + ASSERT_TRUE(comp.has_loop_condition()); + EXPECT_EQ(comp.loop_condition().id(), 102); + EXPECT_EQ(comp.loop_condition().kind_case(), ExprKindCase::kConstant); + EXPECT_TRUE(comp.loop_condition().const_expr().has_bool_value()); + EXPECT_FALSE(comp.loop_condition().const_expr().bool_value()); + + ASSERT_TRUE(comp.has_loop_step()); + EXPECT_EQ(comp.loop_step().id(), 103); + EXPECT_EQ(comp.loop_step().kind_case(), ExprKindCase::kIdentExpr); + EXPECT_EQ(comp.loop_step().ident_expr().name(), "a"); + + ASSERT_TRUE(comp.has_result()); + Expr expected_rest_expr; + expected_rest_expr.set_id(20); + expected_rest_expr.mutable_ident_expr().set_name("y"); + EXPECT_EQ(comp.result(), expected_rest_expr); +} + } // namespace } // namespace cel From 2e6e9ff4493bfbe0baf883107f3fb7ce6f675d88 Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Mon, 11 May 2026 21:06:47 -0700 Subject: [PATCH 09/87] Add support for abbreviations and aliases in container configuration for CEL C++ environment YAML. This allows specifying name, abbreviations, and aliases in a container config instead of just a string. The string syntax is preserved as an alternative PiperOrigin-RevId: 914038623 --- env/BUILD | 1 + env/config.h | 11 +++- env/env.cc | 12 +++- env/env_test.cc | 30 ++++++++++ env/env_yaml.cc | 107 +++++++++++++++++++++++++++++++-- env/env_yaml_test.cc | 139 ++++++++++++++++++++++++++++++++++++++++++- 6 files changed, 289 insertions(+), 11 deletions(-) diff --git a/env/BUILD b/env/BUILD index 55297b190..41ffc1723 100644 --- a/env/BUILD +++ b/env/BUILD @@ -52,6 +52,7 @@ cc_library( ":config", "//checker:type_checker_builder", "//common:constant", + "//common:container", "//common:decl", "//common:type", "//compiler", diff --git a/env/config.h b/env/config.h index 10b23d030..e427832ff 100644 --- a/env/config.h +++ b/env/config.h @@ -34,9 +34,16 @@ class Config { struct ContainerConfig { std::string name; - // TODO(uncreated-issue/87): add support for aliases and abbreviations. + std::vector abbreviations; + struct Alias { + std::string alias; + std::string qualified_name; + }; + std::vector aliases; - bool IsEmpty() const { return name.empty(); } + bool IsEmpty() const { + return name.empty() && abbreviations.empty() && aliases.empty(); + } }; void SetContainerConfig(ContainerConfig container_config) { diff --git a/env/env.cc b/env/env.cc index 5a4198497..42652ce59 100644 --- a/env/env.cc +++ b/env/env.cc @@ -24,6 +24,7 @@ #include "absl/strings/string_view.h" #include "checker/type_checker_builder.h" #include "common/constant.h" +#include "common/container.h" #include "common/decl.h" #include "common/type.h" #include "compiler/compiler.h" @@ -130,7 +131,16 @@ absl::StatusOr> Env::NewCompilerBuilder() { cel::TypeCheckerBuilder& checker_builder = compiler_builder->GetCheckerBuilder(); - checker_builder.set_container(config_.GetContainerConfig().name); + ExpressionContainer container; + CEL_RETURN_IF_ERROR( + container.SetContainer(config_.GetContainerConfig().name)); + for (const auto& abbr : config_.GetContainerConfig().abbreviations) { + CEL_RETURN_IF_ERROR(container.AddAbbreviation(abbr)); + } + for (const auto& alias : config_.GetContainerConfig().aliases) { + CEL_RETURN_IF_ERROR(container.AddAlias(alias.alias, alias.qualified_name)); + } + checker_builder.SetExpressionContainer(std::move(container)); if (!config_.GetStandardLibraryConfig().disable) { CEL_RETURN_IF_ERROR( diff --git a/env/env_test.cc b/env/env_test.cc index 076eb57bc..b599aa569 100644 --- a/env/env_test.cc +++ b/env/env_test.cc @@ -314,6 +314,36 @@ TEST(ContainerConfigTest, ContainerConfig) { EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); } +TEST(ContainerConfigTest, ContainerConfigWithAbbreviations) { + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + Config config; + config.SetContainerConfig( + {.name = "cel.expr.conformance", + .abbreviations = {"cel.expr.conformance.proto2.TestAllTypes"}}); + env.SetConfig(config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile("TestAllTypes{}")); + + EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); +} + +TEST(ContainerConfigTest, ContainerConfigWithAliases) { + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + Config config; + config.SetContainerConfig( + {.name = "cel.expr.conformance", + .aliases = { + {.alias = "MyTestType", + .qualified_name = "cel.expr.conformance.proto2.TestAllTypes"}}}); + env.SetConfig(config); + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile("MyTestType{}")); + + EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); +} + struct VariableConfigWithValueTestCase { Config::VariableConfig variable_config; std::string validate_type_expr; diff --git a/env/env_yaml.cc b/env/env_yaml.cc index 4ba16ea84..159786598 100644 --- a/env/env_yaml.cc +++ b/env/env_yaml.cc @@ -150,12 +150,72 @@ absl::Status ParseName(Config& config, absl::string_view yaml, absl::Status ParseContainerConfig(Config& config, absl::string_view yaml, const YAML::Node& root) { const YAML::Node container = root["container"]; - if (container.IsDefined()) { - if (!container.IsScalar()) { - return YamlError(yaml, container, "Node 'container' is not a string"); - } + if (!container.IsDefined()) { + return absl::OkStatus(); + } + + if (container.IsScalar()) { config.SetContainerConfig({.name = GetString(yaml, container)}); + return absl::OkStatus(); } + + if (!container.IsMap()) { + return YamlError(yaml, container, + "Node 'container' is neither a string nor a map"); + } + + Config::ContainerConfig container_config; + + const YAML::Node name = container["name"]; + if (name.IsDefined()) { + if (!name.IsScalar()) { + return YamlError(yaml, name, "Node 'name' in container is not a string"); + } + container_config.name = GetString(yaml, name); + } + + const YAML::Node abbreviations = container["abbreviations"]; + if (abbreviations.IsDefined()) { + if (!abbreviations.IsSequence()) { + return YamlError(yaml, abbreviations, + "Node 'abbreviations' is not a sequence"); + } + for (const YAML::Node& abbr : abbreviations) { + if (!abbr.IsScalar()) { + return YamlError(yaml, abbr, "Abbreviation is not a string"); + } + container_config.abbreviations.push_back(GetString(yaml, abbr)); + } + } + + const YAML::Node aliases = container["aliases"]; + if (aliases.IsDefined()) { + if (!aliases.IsSequence()) { + return YamlError(yaml, aliases, "Node 'aliases' is not a sequence"); + } + for (const YAML::Node& alias_node : aliases) { + if (!alias_node.IsMap()) { + return YamlError(yaml, alias_node, "Alias entry is not a map"); + } + const YAML::Node alias_key = alias_node["alias"]; + const YAML::Node qualified_name_key = alias_node["qualified_name"]; + + if (!alias_key.IsDefined() || !alias_key.IsScalar()) { + return YamlError(yaml, alias_node, + "Alias entry missing 'alias' string"); + } + if (!qualified_name_key.IsDefined() || !qualified_name_key.IsScalar()) { + return YamlError(yaml, alias_node, + "Alias entry missing 'qualified_name' string"); + } + + container_config.aliases.push_back( + {.alias = GetString(yaml, alias_key), + .qualified_name = GetString(yaml, qualified_name_key)}); + } + } + + config.SetContainerConfig(std::move(container_config)); return absl::OkStatus(); } @@ -686,7 +746,44 @@ void EmitContainerConfig(const Config& env_config, YAML::Emitter& out) { } out << YAML::Key << "container"; - out << YAML::Value << YAML::DoubleQuoted << container_config.name; + if (container_config.abbreviations.empty() && + container_config.aliases.empty()) { + out << YAML::Value << YAML::DoubleQuoted << container_config.name; + } else { + out << YAML::Value << YAML::BeginMap; + if (!container_config.name.empty()) { + out << YAML::Key << "name" << YAML::Value << YAML::DoubleQuoted + << container_config.name; + } + if (!container_config.abbreviations.empty()) { + std::vector sorted_abbrs = container_config.abbreviations; + absl::c_sort(sorted_abbrs); + out << YAML::Key << "abbreviations" << YAML::Value << YAML::BeginSeq; + for (const auto& abbr : sorted_abbrs) { + out << YAML::Value << YAML::DoubleQuoted << abbr; + } + out << YAML::EndSeq; + } + if (!container_config.aliases.empty()) { + std::vector sorted_aliases = + container_config.aliases; + absl::c_sort(sorted_aliases, [](const Config::ContainerConfig::Alias& a, + const Config::ContainerConfig::Alias& b) { + return a.alias < b.alias; + }); + out << YAML::Key << "aliases" << YAML::Value << YAML::BeginSeq; + for (const auto& alias : sorted_aliases) { + out << YAML::BeginMap; + out << YAML::Key << "alias" << YAML::Value << YAML::DoubleQuoted + << alias.alias; + out << YAML::Key << "qualified_name" << YAML::Value + << YAML::DoubleQuoted << alias.qualified_name; + out << YAML::EndMap; + } + out << YAML::EndSeq; + } + out << YAML::EndMap; + } } void EmitExtensionConfigs(const Config& env_config, YAML::Emitter& out) { diff --git a/env/env_yaml_test.cc b/env/env_yaml_test.cc index 25cc63206..d19c0dbfb 100644 --- a/env/env_yaml_test.cc +++ b/env/env_yaml_test.cc @@ -55,6 +55,31 @@ TEST(EnvYamlTest, ParseContainerConfig) { Field(&Config::ContainerConfig::name, "test.container")); } +TEST(EnvYamlTest, ParseContainerConfig_AlternativeSyntax) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + container: + name: test.container + abbreviations: + - abbr1.Abbr1 + - abbr2.Abbr2 + aliases: + - alias: alias1 + qualified_name: qual.name1 + - alias: alias2 + qualified_name: qual.name2 + )yaml")); + + const auto& container_config = config.GetContainerConfig(); + EXPECT_EQ(container_config.name, "test.container"); + EXPECT_THAT(container_config.abbreviations, + UnorderedElementsAre("abbr1.Abbr1", "abbr2.Abbr2")); + ASSERT_THAT(container_config.aliases, SizeIs(2)); + EXPECT_EQ(container_config.aliases[0].alias, "alias1"); + EXPECT_EQ(container_config.aliases[0].qualified_name, "qual.name1"); + EXPECT_EQ(container_config.aliases[1].alias, "alias2"); + EXPECT_EQ(container_config.aliases[1].qualified_name, "qual.name2"); +} + TEST(EnvYamlTest, ParseExtensionConfigs) { ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( extensions: @@ -550,9 +575,78 @@ INSTANTIATE_TEST_SUITE_P( container: - error: "error" )yaml", - .expected_error = "3:19: Node 'container' is not a string\n" - "| - error: \"error\"\n" - "| ^", + .expected_error = + "3:19: Node 'container' is neither a string nor a map\n" + "| - error: \"error\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + name: [] + )yaml", + .expected_error = "3:25: Node 'name' in container is not a string\n" + "| name: []\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + abbreviations: "abbr" + )yaml", + .expected_error = "3:34: Node 'abbreviations' is not a sequence\n" + "| abbreviations: \"abbr\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + abbreviations: + - [] + )yaml", + .expected_error = "4:21: Abbreviation is not a string\n" + "| - []\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + aliases: "not a sequence" + )yaml", + .expected_error = "3:28: Node 'aliases' is not a sequence\n" + "| aliases: \"not a sequence\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + aliases: + - "not a map" + )yaml", + .expected_error = "4:21: Alias entry is not a map\n" + "| - \"not a map\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + aliases: + - qualified_name: "qual" + )yaml", + .expected_error = "4:21: Alias entry missing 'alias' string\n" + "| - qualified_name: \"qual\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + container: + aliases: + - alias: "my_alias" + )yaml", + .expected_error = "4:21: Alias entry missing" + " 'qualified_name' string\n" + "| - alias: \"my_alias\"\n" + "| ^", }, ParseTestCase{ .yaml = R"yaml( @@ -946,6 +1040,33 @@ std::vector GetExportTestCases() { container: "test.container" )yaml", }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + config.SetName("test.env"); + config.SetContainerConfig( + {.name = "test.container", + .abbreviations = {"foo", "bar"}, + .aliases = { + {.alias = "foo", .qualified_name = "test.foo"}, + {.alias = "bar", .qualified_name = "test.bar"}, + }}); + return config; + }(), + .expected_yaml = R"yaml( + name: "test.env" + container: + name: "test.container" + abbreviations: + - "bar" + - "foo" + aliases: + - alias: "bar" + qualified_name: "test.bar" + - alias: "foo" + qualified_name: "test.foo" + )yaml", + }, ExportTestCase{ .config = []() -> absl::StatusOr { Config config; @@ -1385,6 +1506,18 @@ std::vector GetRoundTripTestCases() { overloads: - id: "string_to_timestamp" )yaml", + R"yaml( + container: + name: "test.container" + abbreviations: + - "abbr1.Abbr1" + - "abbr2.Abbr2" + aliases: + - alias: "alias1" + qualified_name: "qual.name1" + - alias: "alias2" + qualified_name: "qual.name2" + )yaml", R"yaml( extensions: - name: "bindings" From cd9f059a5833c92576e85e3ffb2eaee2fd328e76 Mon Sep 17 00:00:00 2001 From: Justin King Date: Tue, 12 May 2026 12:32:02 -0700 Subject: [PATCH 10/87] Fix repeated field null pruning for proto2/proto3 PiperOrigin-RevId: 914421409 --- common/values/struct_value_builder.cc | 11 +++++++++++ conformance/BUILD | 6 ------ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/common/values/struct_value_builder.cc b/common/values/struct_value_builder.cc index 359596267..c342d6478 100644 --- a/common/values/struct_value_builder.cc +++ b/common/values/struct_value_builder.cc @@ -812,6 +812,17 @@ ProtoMessageRepeatedFieldFromValueMutator( const google::protobuf::Reflection* absl_nonnull reflection, google::protobuf::Message* absl_nonnull message, const google::protobuf::FieldDescriptor* absl_nonnull field, const Value& value) { + // If the value is null and the target repeated field is anything except + // google.protobuf.{Any,ListValue,Struct,Value}, it should be pruned. + if (value.IsNull()) { + const auto well_known_type = field->message_type()->well_known_type(); + if (well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_ANY) { + return absl::nullopt; + } + } auto* element = reflection->AddMessage(message, field, factory); auto result = ProtoMessageFromValueImpl(value, pool, factory, well_known_types, element); diff --git a/conformance/BUILD b/conformance/BUILD index 726a11b0b..abc0d918a 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -203,15 +203,9 @@ _TESTS_TO_SKIP = [ "conversions/string/double_hard", # Recent changes - "proto2/set_null/repeated_field_timestamp_null_pruned", - "proto2/set_null/repeated_field_duration_null_pruned", - "proto2/set_null/repeated_field_wrapper_null_pruned", "proto2/set_null/map_timestamp_null_pruned", "proto2/set_null/map_duration_null_pruned", "proto2/set_null/map_wrapper_null_pruned", - "proto3/set_null/repeated_field_timestamp_null_pruned", - "proto3/set_null/repeated_field_duration_null_pruned", - "proto3/set_null/repeated_field_wrapper_null_pruned", "proto3/set_null/map_timestamp_null_pruned", "proto3/set_null/map_duration_null_pruned", "proto3/set_null/map_wrapper_null_pruned", From 4749cf81003d9264fd87ca8b0640b5189bcc2b9e Mon Sep 17 00:00:00 2001 From: Justin King Date: Tue, 12 May 2026 16:00:59 -0700 Subject: [PATCH 11/87] Fix scientific notation and fixed point formatting for int and uint PiperOrigin-RevId: 914525320 --- conformance/BUILD | 4 ---- extensions/formatting.cc | 6 ++++++ extensions/formatting_test.cc | 12 ++++++++++++ 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/conformance/BUILD b/conformance/BUILD index abc0d918a..4f9232ab6 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -209,10 +209,6 @@ _TESTS_TO_SKIP = [ "proto3/set_null/map_timestamp_null_pruned", "proto3/set_null/map_duration_null_pruned", "proto3/set_null/map_wrapper_null_pruned", - "string_ext/format/default precision for fixed-point clause with int", - "string_ext/format/default precision for fixed-point clause with uint", - "string_ext/format/default precision for scientific notation with int", - "string_ext/format/default precision for scientific notation with uint", "namespace/namespace_shadowing/basic", "namespace/namespace_shadowing/comprehension_shadowing_namespaced_selector_disambiguation", ] diff --git a/extensions/formatting.cc b/extensions/formatting.cc index 935815569..252fdc7bd 100644 --- a/extensions/formatting.cc +++ b/extensions/formatting.cc @@ -419,6 +419,12 @@ absl::StatusOr GetDouble(const Value& value, std::string& scratch) { str)); } } + if (value.kind() == ValueKind::kInt) { + return static_cast(value.GetInt().NativeValue()); + } + if (value.kind() == ValueKind::kUint) { + return static_cast(value.GetUint().NativeValue()); + } if (value.kind() != ValueKind::kDouble) { return absl::InvalidArgumentError( absl::StrCat("expected a double but got a ", value.GetTypeName())); diff --git a/extensions/formatting_test.cc b/extensions/formatting_test.cc index b80fe9bc0..6a7fb300b 100644 --- a/extensions/formatting_test.cc +++ b/extensions/formatting_test.cc @@ -553,6 +553,18 @@ INSTANTIATE_TEST_SUITE_P( .format_args = "2.71828", .expected = "2.718280e+00", }, + { + .name = "FixedPointClauseWithInt", + .format = "%f", + .format_args = "3", + .expected = "3.000000", + }, + { + .name = "ScientificNotationWithUint", + .format = "%e", + .format_args = "uint(3)", + .expected = "3.000000e+00", + }, { .name = "NaNSupportForFixedPoint", .format = "%f", From 352666fba7822dd0d1f54dc00b332cc527aa81b1 Mon Sep 17 00:00:00 2001 From: Justin King Date: Wed, 13 May 2026 13:07:09 -0700 Subject: [PATCH 12/87] Fix map field value null pruning for proto2/proto3 PiperOrigin-RevId: 915015044 --- common/values/struct_value_builder.cc | 23 +++++++++++++++++++ conformance/BUILD | 6 ----- .../structs/proto_message_type_adapter.cc | 16 +++++++++++++ 3 files changed, 39 insertions(+), 6 deletions(-) diff --git a/common/values/struct_value_builder.cc b/common/values/struct_value_builder.cc index c342d6478..446b18421 100644 --- a/common/values/struct_value_builder.cc +++ b/common/values/struct_value_builder.cc @@ -956,6 +956,19 @@ class MessageValueBuilderImpl { if (error_value) { return false; } + if (map_value_field->cpp_type() == + google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE && + entry_value.IsNull()) { + auto well_known_type = + map_value_field->message_type()->well_known_type(); + if (well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_ANY && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE && + well_known_type != + google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT) { + return true; + } + } google::protobuf::MapValueRef proto_value; extensions::protobuf_internal::InsertOrLookupMapValue( *reflection_, message_, *field, proto_key, &proto_value); @@ -989,6 +1002,16 @@ class MessageValueBuilderImpl { CEL_RETURN_IF_ERROR(list_value->ForEach( [this, field, accessor, &error_value](const Value& element) -> absl::StatusOr { + if (field->message_type() != nullptr && element.IsNull()) { + auto well_known_type = field->message_type()->well_known_type(); + if (well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_ANY && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE && + well_known_type != + google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT) { + return true; + } + } CEL_ASSIGN_OR_RETURN(error_value, (*accessor)(descriptor_pool_, message_factory_, &well_known_types_, reflection_, diff --git a/conformance/BUILD b/conformance/BUILD index 4f9232ab6..ccd2844c9 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -203,12 +203,6 @@ _TESTS_TO_SKIP = [ "conversions/string/double_hard", # Recent changes - "proto2/set_null/map_timestamp_null_pruned", - "proto2/set_null/map_duration_null_pruned", - "proto2/set_null/map_wrapper_null_pruned", - "proto3/set_null/map_timestamp_null_pruned", - "proto3/set_null/map_duration_null_pruned", - "proto3/set_null/map_wrapper_null_pruned", "namespace/namespace_shadowing/basic", "namespace/namespace_shadowing/comprehension_shadowing_namespaced_selector_disambiguation", ] diff --git a/eval/public/structs/proto_message_type_adapter.cc b/eval/public/structs/proto_message_type_adapter.cc index a351890c2..6a3417ba3 100644 --- a/eval/public/structs/proto_message_type_adapter.cc +++ b/eval/public/structs/proto_message_type_adapter.cc @@ -582,6 +582,19 @@ absl::Status ProtoMessageTypeAdapter::SetField( ValidateSetFieldOp(value_field_descriptor != nullptr, field->name(), "failed to find value field descriptor")); + bool prune_when_null = false; + if (value_field_descriptor->cpp_type() == + google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + auto well_known_type = + value_field_descriptor->message_type()->well_known_type(); + if (well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_ANY && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE && + well_known_type != google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT) { + prune_when_null = true; + } + } + CEL_ASSIGN_OR_RETURN(const CelList* key_list, cel_map->ListKeys(arena)); for (int i = 0; i < key_list->size(); i++) { CelValue key = (*key_list).Get(arena, i); @@ -589,6 +602,9 @@ absl::Status ProtoMessageTypeAdapter::SetField( auto value = (*cel_map).Get(arena, key); CEL_RETURN_IF_ERROR(ValidateSetFieldOp(value.has_value(), field->name(), "error serializing CelMap")); + if (prune_when_null && value->IsNull()) { + continue; + } Message* entry_msg = message->GetReflection()->AddMessage(message, field); CEL_RETURN_IF_ERROR(internal::SetValueToSingleField( key, key_field_descriptor, entry_msg, arena)); From 037e0bb42339376640024de353451e372bb47820 Mon Sep 17 00:00:00 2001 From: Muhammad Askri Date: Wed, 13 May 2026 13:33:46 -0700 Subject: [PATCH 13/87] Adding a TypeSpec to Type resolver. PiperOrigin-RevId: 915029748 --- common/BUILD | 32 ++++ common/type_spec_resolver.cc | 182 +++++++++++++++++++++ common/type_spec_resolver.h | 37 +++++ common/type_spec_resolver_test.cc | 257 ++++++++++++++++++++++++++++++ 4 files changed, 508 insertions(+) create mode 100644 common/type_spec_resolver.cc create mode 100644 common/type_spec_resolver.h create mode 100644 common/type_spec_resolver_test.cc diff --git a/common/BUILD b/common/BUILD index 0ead8b15a..ffc4ae1e9 100644 --- a/common/BUILD +++ b/common/BUILD @@ -46,6 +46,38 @@ cc_test( ], ) +cc_library( + name = "type_spec_resolver", + srcs = ["type_spec_resolver.cc"], + hdrs = ["type_spec_resolver.h"], + deps = [ + ":ast", + ":type", + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "type_spec_resolver_test", + srcs = ["type_spec_resolver_test.cc"], + deps = [ + ":ast", + ":type", + ":type_kind", + ":type_spec_resolver", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_protobuf//:protobuf", + ], +) + cc_library( name = "expr", srcs = ["expr.cc"], diff --git a/common/type_spec_resolver.cc b/common/type_spec_resolver.cc new file mode 100644 index 000000000..97451f390 --- /dev/null +++ b/common/type_spec_resolver.cc @@ -0,0 +1,182 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/type_spec_resolver.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/type.h" +#include "internal/status_macros.h" + +namespace cel { + +absl::StatusOr ConvertTypeSpecToType(const TypeSpec& type_spec, + google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool& pool) { + if (type_spec.has_null()) return Type(NullType{}); + if (type_spec.has_dyn()) return Type(DynType{}); + + if (type_spec.has_primitive()) { + switch (type_spec.primitive()) { + case PrimitiveType::kBool: + return Type(BoolType{}); + case PrimitiveType::kInt64: + return Type(IntType{}); + case PrimitiveType::kUint64: + return Type(UintType{}); + case PrimitiveType::kDouble: + return Type(DoubleType{}); + case PrimitiveType::kString: + return Type(StringType{}); + case PrimitiveType::kBytes: + return Type(BytesType{}); + default: + return absl::InvalidArgumentError("Unsupported primitive type"); + } + } + + if (type_spec.has_well_known()) { + switch (type_spec.well_known()) { + case WellKnownTypeSpec::kAny: + return Type(AnyType{}); + case WellKnownTypeSpec::kTimestamp: + return Type(TimestampType{}); + case WellKnownTypeSpec::kDuration: + return Type(DurationType{}); + default: + return absl::InvalidArgumentError("Unsupported well-known type"); + } + } + + if (type_spec.has_wrapper()) { + switch (type_spec.wrapper()) { + case PrimitiveType::kBool: + return Type(BoolWrapperType{}); + case PrimitiveType::kInt64: + return Type(IntWrapperType{}); + case PrimitiveType::kUint64: + return Type(UintWrapperType{}); + case PrimitiveType::kDouble: + return Type(DoubleWrapperType{}); + case PrimitiveType::kString: + return Type(StringWrapperType{}); + case PrimitiveType::kBytes: + return Type(BytesWrapperType{}); + default: + return absl::InvalidArgumentError("Unsupported wrapper type"); + } + } + + if (type_spec.has_list_type()) { + CEL_ASSIGN_OR_RETURN( + auto elem_type, + ConvertTypeSpecToType(type_spec.list_type().elem_type(), arena, pool)); + return Type(ListType(arena, elem_type)); + } + + if (type_spec.has_map_type()) { + CEL_ASSIGN_OR_RETURN( + auto key_type, + ConvertTypeSpecToType(type_spec.map_type().key_type(), arena, pool)); + CEL_ASSIGN_OR_RETURN( + auto value_type, + ConvertTypeSpecToType(type_spec.map_type().value_type(), arena, pool)); + return Type(MapType(arena, key_type, value_type)); + } + + if (type_spec.has_function()) { + const auto& func_spec = type_spec.function(); + CEL_ASSIGN_OR_RETURN( + auto result_type, + ConvertTypeSpecToType(func_spec.result_type(), arena, pool)); + std::vector arg_types; + for (const auto& arg_spec : func_spec.arg_types()) { + CEL_ASSIGN_OR_RETURN(auto arg_type, + ConvertTypeSpecToType(arg_spec, arena, pool)); + arg_types.push_back(std::move(arg_type)); + } + return Type(FunctionType(arena, result_type, arg_types)); + } + + if (type_spec.has_type_param()) { + const std::string& name = type_spec.type_param().type(); + auto* allocated_name = google::protobuf::Arena::Create(arena, name); + return Type(TypeParamType(absl::string_view(*allocated_name))); + } + + if (type_spec.has_message_type()) { + const std::string& name = type_spec.message_type().type(); + const google::protobuf::Descriptor* descriptor = pool.FindMessageTypeByName(name); + if (descriptor == nullptr) { + return absl::InvalidArgumentError(absl::StrCat( + "Message type '", name, "' not found in descriptor pool")); + } + return Type::Message(descriptor); + } + + if (type_spec.has_abstract_type()) { + const std::string& name = type_spec.abstract_type().name(); + + // Check if it's a message type in the pool + const google::protobuf::Descriptor* descriptor = pool.FindMessageTypeByName(name); + if (descriptor != nullptr) { + if (!type_spec.abstract_type().parameter_types().empty()) { + return absl::InvalidArgumentError(absl::StrCat( + "Message type '", name, "' cannot have type parameters")); + } + return Type::Message(descriptor); + } + + // Check if it's an enum type in the pool + const google::protobuf::EnumDescriptor* enum_descriptor = + pool.FindEnumTypeByName(name); + if (enum_descriptor != nullptr) { + if (!type_spec.abstract_type().parameter_types().empty()) { + return absl::InvalidArgumentError( + absl::StrCat("Enum type '", name, "' cannot have type parameters")); + } + return Type::Enum(enum_descriptor); + } + + // Otherwise fallback to OpaqueType + std::vector params; + for (const auto& param_spec : type_spec.abstract_type().parameter_types()) { + CEL_ASSIGN_OR_RETURN(auto param, + ConvertTypeSpecToType(param_spec, arena, pool)); + params.push_back(std::move(param)); + } + auto* allocated_name = google::protobuf::Arena::Create(arena, name); + return Type(OpaqueType(arena, absl::string_view(*allocated_name), params)); + } + + if (type_spec.has_type()) { + CEL_ASSIGN_OR_RETURN(auto contained_type, + ConvertTypeSpecToType(type_spec.type(), arena, pool)); + return Type(TypeType(arena, contained_type)); + } + + if (type_spec.has_error()) { + return Type(ErrorType{}); + } + + return absl::InvalidArgumentError("Unknown TypeSpec kind"); +} + +} // namespace cel diff --git a/common/type_spec_resolver.h b/common/type_spec_resolver.h new file mode 100644 index 000000000..44e1e088f --- /dev/null +++ b/common/type_spec_resolver.h @@ -0,0 +1,37 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPE_SPEC_RESOLVER_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPE_SPEC_RESOLVER_H_ + +#include "absl/status/statusor.h" +#include "common/ast.h" +#include "common/type.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +// Resolves a `cel::TypeSpec` to a `cel::Type`. +// +// TypeSpec only specifies a type while Type provides support for inspecting +// properties of the type when used in CEL. Returns a status with code +// `InvalidArgument` if the input cannot be resolved to a type. +absl::StatusOr ConvertTypeSpecToType(const TypeSpec& type_spec, + google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool& pool); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPE_SPEC_RESOLVER_H_ diff --git a/common/type_spec_resolver_test.cc b/common/type_spec_resolver_test.cc new file mode 100644 index 000000000..c7fbb2cf8 --- /dev/null +++ b/common/type_spec_resolver_test.cc @@ -0,0 +1,257 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/type_spec_resolver.h" + +#include +#include +#include +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::internal::GetTestingDescriptorPool; +using ::testing::HasSubstr; +using ::testing::TestWithParam; +using ::testing::Values; + +google::protobuf::Arena* GetTestArena() { + static absl::NoDestructor arena; + return &*arena; +} + +TEST(TypeSpecResolverTest, NullTypeSpec) { + TypeSpec spec(NullTypeSpec{}); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsNull()); +} + +TEST(TypeSpecResolverTest, DynTypeSpec) { + TypeSpec spec(DynTypeSpec{}); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsDyn()); +} + +using ConversionTest = testing::TestWithParam>; + +TEST_P(ConversionTest, TestTypeSpecConversion) { + ASSERT_OK_AND_ASSIGN( + auto t, ConvertTypeSpecToType(std::get<0>(GetParam()), GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_EQ(t.kind(), std::get<1>(GetParam())); +} + +INSTANTIATE_TEST_SUITE_P( + TypeSpecResolverTest, ConversionTest, + testing::Values( + std::make_tuple(TypeSpec(PrimitiveType::kBool), TypeKind::kBool), + std::make_tuple(TypeSpec(PrimitiveType::kInt64), TypeKind::kInt), + std::make_tuple(TypeSpec(PrimitiveType::kUint64), TypeKind::kUint), + std::make_tuple(TypeSpec(PrimitiveType::kDouble), TypeKind::kDouble), + std::make_tuple(TypeSpec(PrimitiveType::kString), TypeKind::kString), + std::make_tuple(TypeSpec(PrimitiveType::kBytes), TypeKind::kBytes), + std::make_tuple(TypeSpec(WellKnownTypeSpec::kAny), TypeKind::kAny), + std::make_tuple(TypeSpec(WellKnownTypeSpec::kTimestamp), + TypeKind::kTimestamp), + std::make_tuple(TypeSpec(WellKnownTypeSpec::kDuration), + TypeKind::kDuration), + std::make_tuple(TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool)), + TypeKind::kBoolWrapper), + std::make_tuple(TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64)), + TypeKind::kIntWrapper), + std::make_tuple(TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64)), + TypeKind::kUintWrapper), + std::make_tuple(TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)), + TypeKind::kDoubleWrapper), + std::make_tuple(TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kString)), + TypeKind::kStringWrapper), + std::make_tuple(TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBytes)), + TypeKind::kBytesWrapper))); + +TEST(TypeSpecResolverTest, ListTypeConversion) { + auto elem = std::make_unique(PrimitiveType::kInt64); + TypeSpec spec(ListTypeSpec(std::move(elem))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsList()); + EXPECT_TRUE(t->GetList().element().IsInt()); +} + +TEST(TypeSpecResolverTest, MapTypeConversion) { + auto key = std::make_unique(PrimitiveType::kString); + auto val = std::make_unique(PrimitiveType::kBytes); + TypeSpec spec(MapTypeSpec(std::move(key), std::move(val))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsMap()); + EXPECT_TRUE(t->GetMap().key().IsString()); + EXPECT_TRUE(t->GetMap().value().IsBytes()); +} + +TEST(TypeSpecResolverTest, FunctionTypeConversion) { + auto result = std::make_unique(PrimitiveType::kBool); + std::vector args; + args.push_back(TypeSpec(PrimitiveType::kString)); + TypeSpec spec(FunctionTypeSpec(std::move(result), std::move(args))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsFunction()); + EXPECT_EQ(t->GetFunction().args().size(), 1); + EXPECT_TRUE(t->GetFunction().result().IsBool()); +} + +TEST(TypeSpecResolverTest, TypeParamConversion) { + TypeSpec spec(ParamTypeSpec("T")); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsTypeParam()); + EXPECT_EQ(t->GetTypeParam().name(), "T"); +} + +TEST(TypeSpecResolverTest, MessageTypeConversion) { + TypeSpec spec( + AbstractType("cel.expr.conformance.proto3.TestAllTypes", /*params=*/{})); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsMessage()); + EXPECT_EQ(t->name(), "cel.expr.conformance.proto3.TestAllTypes"); +} + +TEST(TypeSpecResolverTest, MessageTypeWithParamsError) { + std::vector params; + params.push_back(TypeSpec(PrimitiveType::kInt64)); + TypeSpec spec(AbstractType("cel.expr.conformance.proto3.TestAllTypes", + std::move(params))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + EXPECT_THAT(t, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("cannot have type parameters"))); +} + +TEST(TypeSpecResolverTest, UnresolvedAbstractTypeFallbackToOpaque) { + std::vector params; + params.push_back(TypeSpec(PrimitiveType::kInt64)); + TypeSpec spec(AbstractType("my.custom.OpaqueType", std::move(params))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsOpaque()); + EXPECT_EQ(t->name(), "my.custom.OpaqueType"); + EXPECT_EQ(t->GetParameters().size(), 1); + EXPECT_TRUE(t->GetParameters()[0].IsInt()); +} + +TEST(TypeSpecResolverTest, OptionalType) { + std::vector params; + params.push_back(TypeSpec(PrimitiveType::kInt64)); + TypeSpec spec(AbstractType("optional_type", std::move(params))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsOpaque()); + EXPECT_EQ(t->name(), "optional_type"); + EXPECT_EQ(t->GetParameters().size(), 1); + EXPECT_TRUE(t->GetParameters()[0].IsInt()); + EXPECT_TRUE(t->IsOptional()); +} + +TEST(TypeSpecResolverTest, TypeTypeConversion) { + auto nested = std::make_unique(PrimitiveType::kInt64); + TypeSpec spec(std::move(nested)); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsType()); + EXPECT_TRUE(t->GetType().GetType().IsInt()); +} + +TEST(TypeSpecResolverTest, ErrorTypeConversion) { + TypeSpec spec(ErrorTypeSpec::kValue); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsError()); +} + +TEST(TypeSpecResolverTest, MessageTypeSpecConversion) { + TypeSpec spec(MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsMessage()); + EXPECT_EQ(t->name(), "cel.expr.conformance.proto3.TestAllTypes"); +} + +TEST(TypeSpecResolverTest, MessageTypeSpecNotFoundError) { + TypeSpec spec(MessageTypeSpec("cel.expr.conformance.proto3.NonExistentType")); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + EXPECT_THAT(t, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("not found in descriptor pool"))); +} + +TEST(TypeSpecResolverTest, EnumTypeConversion) { + TypeSpec spec(AbstractType( + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum", /*params=*/{})); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t, IsOk()); + EXPECT_TRUE(t->IsEnum()); + EXPECT_EQ(t->name(), "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"); +} + +TEST(TypeSpecResolverTest, EnumTypeWithParamsError) { + std::vector params; + params.push_back(TypeSpec(PrimitiveType::kInt64)); + TypeSpec spec( + AbstractType("cel.expr.conformance.proto3.TestAllTypes.NestedEnum", + std::move(params))); + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + EXPECT_THAT(t, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("cannot have type parameters"))); +} + +TEST(TypeSpecResolverTest, UnknownTypeSpecKindError) { + TypeSpec spec; + auto t = + ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); + EXPECT_THAT(t, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Unknown TypeSpec kind"))); +} + +} // namespace +} // namespace cel From ad18948079b2d3d8b9e62a202889076f872992e7 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 13 May 2026 22:06:31 -0700 Subject: [PATCH 14/87] No public description PiperOrigin-RevId: 915223448 --- eval/public/ast_rewrite.cc | 2 +- eval/public/ast_traverse.cc | 2 +- eval/public/cel_attribute.cc | 4 ++-- eval/public/equality_function_registrar_test.cc | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/eval/public/ast_rewrite.cc b/eval/public/ast_rewrite.cc index 3c210e607..87c667eb5 100644 --- a/eval/public/ast_rewrite.cc +++ b/eval/public/ast_rewrite.cc @@ -68,7 +68,7 @@ struct ExprRecord { }; using StackRecordKind = - absl::variant; + std::variant; struct StackRecord { public: diff --git a/eval/public/ast_traverse.cc b/eval/public/ast_traverse.cc index a86923c67..c18b806b9 100644 --- a/eval/public/ast_traverse.cc +++ b/eval/public/ast_traverse.cc @@ -67,7 +67,7 @@ struct ExprRecord { }; using StackRecordKind = - absl::variant; + std::variant; struct StackRecord { public: diff --git a/eval/public/cel_attribute.cc b/eval/public/cel_attribute.cc index 015289bed..70525a04d 100644 --- a/eval/public/cel_attribute.cc +++ b/eval/public/cel_attribute.cc @@ -76,8 +76,8 @@ CelAttributeQualifier CreateCelAttributeQualifier(const CelValue& value) { CelAttributePattern CreateCelAttributePattern( absl::string_view variable, - std::initializer_list> + std::initializer_list> path_spec) { std::vector path; path.reserve(path_spec.size()); diff --git a/eval/public/equality_function_registrar_test.cc b/eval/public/equality_function_registrar_test.cc index 772ddfeba..577c4be22 100644 --- a/eval/public/equality_function_registrar_test.cc +++ b/eval/public/equality_function_registrar_test.cc @@ -86,7 +86,7 @@ MATCHER_P2(DefinesHomogenousOverload, name, argument_type, struct EqualityTestCase { enum class ErrorKind { kMissingOverload, kMissingIdentifier }; absl::string_view expr; - absl::variant result; + std::variant result; CelValue lhs = CelValue::CreateNull(); CelValue rhs = CelValue::CreateNull(); }; From 1f5a7e62900ae2ad1021228df04f2a950744c001 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 13 May 2026 22:10:11 -0700 Subject: [PATCH 15/87] No public description PiperOrigin-RevId: 915224564 --- common/ast/constant_proto.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/ast/constant_proto.cc b/common/ast/constant_proto.cc index c0fe1c9f6..1982c05b4 100644 --- a/common/ast/constant_proto.cc +++ b/common/ast/constant_proto.cc @@ -35,7 +35,7 @@ using ConstantProto = cel::expr::Constant; absl::Status ConstantToProto(const Constant& constant, ConstantProto* absl_nonnull proto) { return absl::visit(absl::Overload( - [proto](absl::monostate) -> absl::Status { + [proto](std::monostate) -> absl::Status { proto->clear_constant_kind(); return absl::OkStatus(); }, From 6d311f704ade7aea062dd1091dfe3e683938fc78 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 13 May 2026 22:12:46 -0700 Subject: [PATCH 16/87] No public description PiperOrigin-RevId: 915225296 --- internal/json.cc | 2 +- internal/message_equality.cc | 8 ++++---- internal/well_known_types.cc | 2 +- internal/well_known_types_test.cc | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/internal/json.cc b/internal/json.cc index 630ceb267..cdd4c1a5d 100644 --- a/internal/json.cc +++ b/internal/json.cc @@ -1417,7 +1417,7 @@ class JsonMapIterator final { } private: - absl::variant variant_; + std::variant variant_; }; class JsonAccessor { diff --git a/internal/message_equality.cc b/internal/message_equality.cc index 945cca8df..33ef78089 100644 --- a/internal/message_equality.cc +++ b/internal/message_equality.cc @@ -86,10 +86,10 @@ class EquatableMessage final }; using EquatableValue = - absl::variant; + std::variant; struct NullValueEqualer { bool operator()(std::nullptr_t, std::nullptr_t) const { return true; } diff --git a/internal/well_known_types.cc b/internal/well_known_types.cc index dee029534..02e50c3e3 100644 --- a/internal/well_known_types.cc +++ b/internal/well_known_types.cc @@ -2174,7 +2174,7 @@ absl::StatusOr AdaptFromMessage( if (adapted) { return adapted; } - return absl::monostate{}; + return std::monostate{}; } } diff --git a/internal/well_known_types_test.cc b/internal/well_known_types_test.cc index 0d2c9fe33..afc8ce396 100644 --- a/internal/well_known_types_test.cc +++ b/internal/well_known_types_test.cc @@ -806,7 +806,7 @@ TEST_F(AdaptFromMessageTest, Struct) { TEST_F(AdaptFromMessageTest, TestAllTypesProto3) { auto message = DynamicParseTextProto(R"pb()pb"); EXPECT_THAT(AdaptFromMessage(*message), - IsOkAndHolds(VariantWith(absl::monostate()))); + IsOkAndHolds(VariantWith(std::monostate()))); } TEST_F(AdaptFromMessageTest, Any_BoolValue) { From fb51dcdfd1082e67d209c1ba0c84e58b577c378a Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 14 May 2026 00:13:04 -0700 Subject: [PATCH 17/87] No public description PiperOrigin-RevId: 915268033 --- runtime/internal/convert_constant.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runtime/internal/convert_constant.cc b/runtime/internal/convert_constant.cc index a9effd229..33f382858 100644 --- a/runtime/internal/convert_constant.cc +++ b/runtime/internal/convert_constant.cc @@ -33,7 +33,7 @@ using ::cel::Constant; struct ConvertVisitor { Allocator<> allocator; - absl::StatusOr operator()(absl::monostate) { + absl::StatusOr operator()(std::monostate) { return absl::InvalidArgumentError("unspecified constant"); } absl::StatusOr operator()(std::nullptr_t) { return NullValue(); } From 877239571674284da3d22bcc7ccfe2e175643de7 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 14 May 2026 00:15:52 -0700 Subject: [PATCH 18/87] No public description PiperOrigin-RevId: 915269088 --- common/values/message_value.cc | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/common/values/message_value.cc b/common/values/message_value.cc index e06206407..66dfd9511 100644 --- a/common/values/message_value.cc +++ b/common/values/message_value.cc @@ -46,7 +46,7 @@ const google::protobuf::Descriptor* absl_nonnull MessageValue::GetDescriptor() c ABSL_CHECK(*this); // Crash OK return absl::visit( absl::Overload( - [](absl::monostate) -> const google::protobuf::Descriptor* absl_nonnull { + [](std::monostate) -> const google::protobuf::Descriptor* absl_nonnull { ABSL_UNREACHABLE(); }, [](const ParsedMessageValue& alternative) @@ -58,7 +58,7 @@ const google::protobuf::Descriptor* absl_nonnull MessageValue::GetDescriptor() c std::string MessageValue::DebugString() const { return absl::visit( - absl::Overload([](absl::monostate) -> std::string { return "INVALID"; }, + absl::Overload([](std::monostate) -> std::string { return "INVALID"; }, [](const ParsedMessageValue& alternative) -> std::string { return alternative.DebugString(); }), @@ -68,7 +68,7 @@ std::string MessageValue::DebugString() const { bool MessageValue::IsZeroValue() const { ABSL_DCHECK(*this); return absl::visit( - absl::Overload([](absl::monostate) -> bool { return true; }, + absl::Overload([](std::monostate) -> bool { return true; }, [](const ParsedMessageValue& alternative) -> bool { return alternative.IsZeroValue(); }), @@ -81,7 +81,7 @@ absl::Status MessageValue::SerializeTo( google::protobuf::io::ZeroCopyOutputStream* absl_nonnull output) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError( "unexpected attempt to invoke `ConvertToJson` on " "an invalid `MessageValue`"); @@ -99,7 +99,7 @@ absl::Status MessageValue::ConvertToJson( google::protobuf::Message* absl_nonnull json) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError( "unexpected attempt to invoke `ConvertToJson` on " "an invalid `MessageValue`"); @@ -117,7 +117,7 @@ absl::Status MessageValue::ConvertToJsonObject( google::protobuf::Message* absl_nonnull json) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError( "unexpected attempt to invoke `ConvertToJsonObject` on " "an invalid `MessageValue`"); @@ -136,7 +136,7 @@ absl::Status MessageValue::Equal( google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError( "unexpected attempt to invoke `Equal` on " "an invalid `MessageValue`"); @@ -155,7 +155,7 @@ absl::Status MessageValue::GetFieldByName( google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError( "unexpected attempt to invoke `GetFieldByName` on " "an invalid `MessageValue`"); @@ -175,7 +175,7 @@ absl::Status MessageValue::GetFieldByNumber( google::protobuf::Arena* absl_nonnull arena, Value* absl_nonnull result) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError( "unexpected attempt to invoke `GetFieldByNumber` on " "an invalid `MessageValue`"); @@ -192,7 +192,7 @@ absl::StatusOr MessageValue::HasFieldByName( absl::string_view name) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::StatusOr { + [](std::monostate) -> absl::StatusOr { return absl::InternalError( "unexpected attempt to invoke `HasFieldByName` on " "an invalid `MessageValue`"); @@ -206,7 +206,7 @@ absl::StatusOr MessageValue::HasFieldByName( absl::StatusOr MessageValue::HasFieldByNumber(int64_t number) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::StatusOr { + [](std::monostate) -> absl::StatusOr { return absl::InternalError( "unexpected attempt to invoke `HasFieldByNumber` on " "an invalid `MessageValue`"); @@ -224,7 +224,7 @@ absl::Status MessageValue::ForEachField( google::protobuf::Arena* absl_nonnull arena) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError( "unexpected attempt to invoke `ForEachField` on " "an invalid `MessageValue`"); @@ -244,7 +244,7 @@ absl::Status MessageValue::Qualify( int* absl_nonnull count) const { return absl::visit( absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError( "unexpected attempt to invoke `Qualify` on " "an invalid `MessageValue`"); From ff45a7c2a096ed1d38e6ed4d80a7180be32874b7 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 14 May 2026 03:39:55 -0700 Subject: [PATCH 19/87] No public description PiperOrigin-RevId: 915340723 --- common/ast_rewrite.cc | 2 +- common/ast_traverse.cc | 2 +- common/decl_proto.cc | 2 +- common/decl_proto_test.cc | 4 ++-- common/decl_proto_v1alpha1.cc | 2 +- common/type.cc | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/common/ast_rewrite.cc b/common/ast_rewrite.cc index 14582f44f..b61e1fab6 100644 --- a/common/ast_rewrite.cc +++ b/common/ast_rewrite.cc @@ -54,7 +54,7 @@ struct ExprRecord { }; using StackRecordKind = - absl::variant; + std::variant; struct StackRecord { public: diff --git a/common/ast_traverse.cc b/common/ast_traverse.cc index a6ba0d1ba..fb4f9731e 100644 --- a/common/ast_traverse.cc +++ b/common/ast_traverse.cc @@ -53,7 +53,7 @@ struct ExprRecord { }; using StackRecordKind = - absl::variant; + std::variant; struct StackRecord { public: diff --git a/common/decl_proto.cc b/common/decl_proto.cc index 89f7f4453..098c5068c 100644 --- a/common/decl_proto.cc +++ b/common/decl_proto.cc @@ -69,7 +69,7 @@ absl::StatusOr FunctionDeclFromProto( return decl; } -absl::StatusOr> DeclFromProto( +absl::StatusOr> DeclFromProto( const cel::expr::Decl& decl, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::Arena* absl_nonnull arena) { diff --git a/common/decl_proto_test.cc b/common/decl_proto_test.cc index 62215f07f..d72d97e09 100644 --- a/common/decl_proto_test.cc +++ b/common/decl_proto_test.cc @@ -49,7 +49,7 @@ TEST_P(DeclFromProtoTest, FromProtoWorks) { cel::expr::Decl decl_pb; ASSERT_TRUE( google::protobuf::TextFormat::ParseFromString(test_case.proto_decl, &decl_pb)); - absl::StatusOr> decl_or = + absl::StatusOr> decl_or = DeclFromProto(decl_pb, descriptor_pool, &arena); switch (test_case.decl_type) { case DeclType::kVariable: { @@ -79,7 +79,7 @@ TEST_P(DeclFromProtoTest, FromV1Alpha1ProtoWorks) { google::api::expr::v1alpha1::Decl decl_pb; ASSERT_TRUE( google::protobuf::TextFormat::ParseFromString(test_case.proto_decl, &decl_pb)); - absl::StatusOr> decl_or = + absl::StatusOr> decl_or = DeclFromV1Alpha1Proto(decl_pb, descriptor_pool, &arena); switch (test_case.decl_type) { case DeclType::kVariable: { diff --git a/common/decl_proto_v1alpha1.cc b/common/decl_proto_v1alpha1.cc index 2c6cfb6e4..a8d73e5c2 100644 --- a/common/decl_proto_v1alpha1.cc +++ b/common/decl_proto_v1alpha1.cc @@ -52,7 +52,7 @@ absl::StatusOr FunctionDeclFromV1Alpha1Proto( return FunctionDeclFromProto(name, unversioned, descriptor_pool, arena); } -absl::StatusOr> DeclFromV1Alpha1Proto( +absl::StatusOr> DeclFromV1Alpha1Proto( const google::api::expr::v1alpha1::Decl& decl, const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, google::protobuf::Arena* absl_nonnull arena) { diff --git a/common/type.cc b/common/type.cc index ce8c7a89a..f94e8bc52 100644 --- a/common/type.cc +++ b/common/type.cc @@ -97,7 +97,7 @@ static constexpr std::array kTypeToKindArray = { TypeKind::kUnknown}; static_assert(kTypeToKindArray.size() == - absl::variant_size(), + std::variant_size(), "Kind indexer must match variant declaration for cel::Type."); } // namespace From 366498bd3820ab8382282ca15279753e4789be31 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 14 May 2026 03:46:00 -0700 Subject: [PATCH 20/87] No public description PiperOrigin-RevId: 915342975 --- common/types/struct_type.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/common/types/struct_type.cc b/common/types/struct_type.cc index 4540cec9c..a1be1f786 100644 --- a/common/types/struct_type.cc +++ b/common/types/struct_type.cc @@ -27,7 +27,7 @@ namespace cel { absl::string_view StructType::name() const { ABSL_DCHECK(*this); return absl::visit( - absl::Overload([](absl::monostate) { return absl::string_view(); }, + absl::Overload([](std::monostate) { return absl::string_view(); }, [](const common_internal::BasicStructType& alt) { return alt.name(); }, @@ -39,7 +39,7 @@ TypeParameters StructType::GetParameters() const { ABSL_DCHECK(*this); return absl::visit( absl::Overload( - [](absl::monostate) { return TypeParameters(); }, + [](std::monostate) { return TypeParameters(); }, [](const common_internal::BasicStructType& alt) { return alt.GetParameters(); }, @@ -49,7 +49,7 @@ TypeParameters StructType::GetParameters() const { std::string StructType::DebugString() const { return absl::visit( - absl::Overload([](absl::monostate) { return std::string(); }, + absl::Overload([](std::monostate) { return std::string(); }, [](common_internal::BasicStructType alt) { return alt.DebugString(); }, @@ -72,7 +72,7 @@ MessageType StructType::GetMessage() const { common_internal::TypeVariant StructType::ToTypeVariant() const { return absl::visit( absl::Overload( - [](absl::monostate) { return common_internal::TypeVariant(); }, + [](std::monostate) { return common_internal::TypeVariant(); }, [](common_internal::BasicStructType alt) { return static_cast(alt) ? common_internal::TypeVariant(alt) : common_internal::TypeVariant(); From 33156b1e59b458ff6c24208dbcb66ace3186ab9b Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 14 May 2026 04:28:29 -0700 Subject: [PATCH 21/87] No public description PiperOrigin-RevId: 915359111 --- extensions/select_optimization.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extensions/select_optimization.cc b/extensions/select_optimization.cc index 0f09773ae..44da4c48a 100644 --- a/extensions/select_optimization.cc +++ b/extensions/select_optimization.cc @@ -92,7 +92,7 @@ struct SelectInstruction { // Represents a single qualifier in a traversal path. // TODO(uncreated-issue/51): support variable indexes. using QualifierInstruction = - absl::variant; + std::variant; struct SelectPath { Expr* operand; From a88afaca5106943d9f835cb622be9813b6bdee55 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 14 May 2026 05:07:40 -0700 Subject: [PATCH 22/87] No public description PiperOrigin-RevId: 915372200 --- runtime/memory_safety_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runtime/memory_safety_test.cc b/runtime/memory_safety_test.cc index 7e864ecf6..2a09be666 100644 --- a/runtime/memory_safety_test.cc +++ b/runtime/memory_safety_test.cc @@ -73,7 +73,7 @@ struct TestCase { std::string name; std::string expression; absl::flat_hash_map> + std::variant> activation; test::ValueMatcher expected_matcher; bool reference_resolver_enabled = false; From 513af3c2c338c0aabbbef21419018880dc9c23c4 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 14 May 2026 22:02:12 -0700 Subject: [PATCH 23/87] No public description PiperOrigin-RevId: 915787513 --- eval/internal/cel_value_equal_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eval/internal/cel_value_equal_test.cc b/eval/internal/cel_value_equal_test.cc index f52f38916..109a63795 100644 --- a/eval/internal/cel_value_equal_test.cc +++ b/eval/internal/cel_value_equal_test.cc @@ -67,7 +67,7 @@ using ::testing::ValuesIn; struct EqualityTestCase { enum class ErrorKind { kMissingOverload, kMissingIdentifier }; absl::string_view expr; - absl::variant result; + std::variant result; CelValue lhs = CelValue::CreateNull(); CelValue rhs = CelValue::CreateNull(); }; From f62419d04f3a4c12ecf2a802e95d26e33aa2b115 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 14 May 2026 22:02:24 -0700 Subject: [PATCH 24/87] No public description PiperOrigin-RevId: 915787600 --- tools/branch_coverage.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/branch_coverage.cc b/tools/branch_coverage.cc index 00ab7cb5a..b5bba3ffe 100644 --- a/tools/branch_coverage.cc +++ b/tools/branch_coverage.cc @@ -71,7 +71,7 @@ struct OtherNode { // Representation for coverage of an AST node. struct CoverageNode { int evaluate_count; - absl::variant kind; + std::variant kind; }; const Type* absl_nullable FindCheckerType(const CheckedExpr& expr, From 7820913cda14e09bbb667c10d65391c1a79fb95d Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 14 May 2026 22:16:10 -0700 Subject: [PATCH 25/87] No public description PiperOrigin-RevId: 915792480 --- common/value.cc | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/common/value.cc b/common/value.cc index 535ddead8..1cd3f54e1 100644 --- a/common/value.cc +++ b/common/value.cc @@ -115,7 +115,7 @@ Type Value::GetRuntimeType() const { namespace { template -struct IsMonostate : std::is_same, absl::monostate> {}; +struct IsMonostate : std::is_same, std::monostate> {}; } // namespace @@ -171,7 +171,7 @@ absl::Status Value::ConvertToJsonArray( google::protobuf::Descriptor::WELLKNOWNTYPE_LISTVALUE); return variant_.Visit(absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError("use of invalid Value"); }, [descriptor_pool, message_factory, json]( @@ -212,7 +212,7 @@ absl::Status Value::ConvertToJsonObject( google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT); return variant_.Visit(absl::Overload( - [](absl::monostate) -> absl::Status { + [](std::monostate) -> absl::Status { return absl::InternalError("use of invalid Value"); }, [descriptor_pool, message_factory, json]( @@ -1363,7 +1363,7 @@ Value Value::FromMessage( return absl::visit( absl::Overload(OwningWellKnownTypesValueVisitor{ /* .arena = */ arena, /* .scratch = */ &scratch}, - [&](absl::monostate) -> Value { + [&](std::monostate) -> Value { auto* cloned = message.New(arena); cloned->CopyFrom(message); return ParsedMessageValue(cloned, arena); @@ -1391,7 +1391,7 @@ Value Value::FromMessage( return absl::visit( absl::Overload(OwningWellKnownTypesValueVisitor{ /* .arena = */ arena, /* .scratch = */ &scratch}, - [&](absl::monostate) -> Value { + [&](std::monostate) -> Value { auto* cloned = message.New(arena); cloned->GetReflection()->Swap(cloned, &message); return ParsedMessageValue(cloned, arena); @@ -1422,7 +1422,7 @@ Value Value::WrapMessage( absl::Overload(BorrowingWellKnownTypesValueVisitor{ /* .message = */ message, /* .arena = */ arena, /* .scratch = */ &scratch}, - [&](absl::monostate) -> Value { + [&](std::monostate) -> Value { if (message->GetArena() != arena) { auto* cloned = message->New(arena); cloned->CopyFrom(*message); @@ -1456,7 +1456,7 @@ Value Value::WrapMessageUnsafe( absl::Overload(BorrowingWellKnownTypesValueVisitor{ /* .message = */ message, /* .arena = */ arena, /* .scratch = */ &scratch}, - [&](absl::monostate) -> Value { + [&](std::monostate) -> Value { if (message->GetArena() != arena) { return UnsafeParsedMessageValue(message); } From da45d34071c8fe9f77fcc17e33e518841d382cdc Mon Sep 17 00:00:00 2001 From: Antoine Pietri Date: Mon, 18 May 2026 08:43:19 -0700 Subject: [PATCH 26/87] Add missing include for `google/rpc/status.proto.h`. This code was relying on the transitive inclusion of third_party/cel/cpp/* to provide the type information for the Status proto. This makes the code brittle and prone to breakages when doing internal header refactors. PiperOrigin-RevId: 917253701 --- conformance/BUILD | 1 + conformance/service.cc | 1 + 2 files changed, 2 insertions(+) diff --git a/conformance/BUILD b/conformance/BUILD index ccd2844c9..0ca90a4bc 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -83,6 +83,7 @@ cc_library( "@com_google_googleapis//google/api/expr/conformance/v1alpha1:conformance_cc_proto", "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", "@com_google_googleapis//google/rpc:code_cc_proto", + "@com_google_googleapis//google/rpc:status_cc_proto", "@com_google_protobuf//:duration_cc_proto", "@com_google_protobuf//:empty_cc_proto", "@com_google_protobuf//:protobuf", diff --git a/conformance/service.cc b/conformance/service.cc index 463334bb5..7e3eded82 100644 --- a/conformance/service.cc +++ b/conformance/service.cc @@ -30,6 +30,7 @@ #include "google/protobuf/struct.pb.h" #include "google/protobuf/timestamp.pb.h" #include "google/rpc/code.pb.h" +#include "google/rpc/status.pb.h" #include "absl/log/absl_check.h" #include "absl/memory/memory.h" #include "absl/status/status.h" From d475cc6726ef85fefed557c8eb0e400119d13e95 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 20 May 2026 20:28:14 -0700 Subject: [PATCH 27/87] No public description PiperOrigin-RevId: 918791547 --- codelab/network_functions.cc | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/codelab/network_functions.cc b/codelab/network_functions.cc index f4f729827..64f199cb3 100644 --- a/codelab/network_functions.cc +++ b/codelab/network_functions.cc @@ -213,8 +213,7 @@ absl::Status NetworkAddressRepEqual( return absl::OkStatus(); } const NetworkAddressRep rep = content.To(); - absl::optional other_rep = - NetworkAddressRep::Unwrap(other); + std::optional other_rep = NetworkAddressRep::Unwrap(other); ABSL_DCHECK(other_rep.has_value()); *result = cel::BoolValue(rep.IsEqualTo(*other_rep)); return absl::OkStatus(); @@ -311,7 +310,7 @@ cel::Value parseAddress( google::protobuf::Arena* absl_nonnull arena) { std::string buf; absl::string_view addr = str.ToStringView(&buf); - absl::optional rep = NetworkAddressRep::Parse(addr); + std::optional rep = NetworkAddressRep::Parse(addr); if (!rep.has_value()) { return cel::ErrorValue(absl::InvalidArgumentError("invalid address")); } @@ -321,7 +320,7 @@ cel::Value parseAddress( cel::Value parseAddressOrZero(const cel::StringValue& str) { std::string buf; absl::string_view addr = str.ToStringView(&buf); - absl::optional rep = NetworkAddressRep::Parse(addr); + std::optional rep = NetworkAddressRep::Parse(addr); static const NetworkAddressRep kZero; if (!rep.has_value()) { return NetworkAddressRep::MakeValue(kZero); @@ -336,8 +335,7 @@ cel::Value parseAddressMatcher( google::protobuf::Arena* absl_nonnull arena) { std::string buf; absl::string_view addr = str.ToStringView(&buf); - absl::optional rep = - NetworkAddressMatcher::Parse(addr); + std::optional rep = NetworkAddressMatcher::Parse(addr); if (!rep.has_value()) { return cel::ErrorValue( absl::InvalidArgumentError("invalid address matcher")); @@ -365,7 +363,7 @@ cel::Value NetworkAddressRep::MakeValue(const NetworkAddressRep& rep) { cel::OpaqueValueContent::From(rep)); } -absl::optional NetworkAddressRep::Unwrap( +std::optional NetworkAddressRep::Unwrap( const cel::Value& value) { auto opaque = value.AsOpaque(); if (!opaque.has_value() || @@ -381,7 +379,7 @@ absl::optional NetworkAddressRep::Unwrap( return opaque->content().To(); } -absl::optional NetworkAddressRep::Parse( +std::optional NetworkAddressRep::Parse( absl::string_view str) { uint32_t ipv4 = 0; char ipv6[16]; @@ -418,7 +416,7 @@ bool NetworkAddressRep::IsLessThan(const NetworkAddressRep& other) const { return false; } -absl::optional NetworkAddressMatcher::Parse( +std::optional NetworkAddressMatcher::Parse( absl::string_view str) { // range style addr-addr int dash_pos = str.find('-'); From b7096df80e0d7b6facc2943326a1c04cde0f1d27 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 20 May 2026 20:28:15 -0700 Subject: [PATCH 28/87] No public description PiperOrigin-RevId: 918791557 --- eval/public/equality_function_registrar_test.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/eval/public/equality_function_registrar_test.cc b/eval/public/equality_function_registrar_test.cc index 577c4be22..a77a92734 100644 --- a/eval/public/equality_function_registrar_test.cc +++ b/eval/public/equality_function_registrar_test.cc @@ -204,7 +204,7 @@ std::string CelValueEqualTestName( } TEST_P(CelValueEqualImplTypesTest, Basic) { - absl::optional result = CelValueEqualImpl(lhs(), rhs()); + std::optional result = CelValueEqualImpl(lhs(), rhs()); if (lhs().IsNull() || rhs().IsNull()) { if (lhs().IsNull() && rhs().IsNull()) { @@ -286,7 +286,7 @@ const std::vector& NumericValuesNotEqualExample() { using NumericInequalityTest = testing::TestWithParam; TEST_P(NumericInequalityTest, NumericValues) { NumericInequalityTestCase test_case = GetParam(); - absl::optional result = CelValueEqualImpl(test_case.a, test_case.b); + std::optional result = CelValueEqualImpl(test_case.a, test_case.b); EXPECT_TRUE(result.has_value()); EXPECT_EQ(*result, false); } @@ -299,7 +299,7 @@ INSTANTIATE_TEST_SUITE_P( }); TEST(CelValueEqualImplTest, LossyNumericEquality) { - absl::optional result = CelValueEqualImpl( + std::optional result = CelValueEqualImpl( CelValue::CreateDouble( static_cast(std::numeric_limits::max()) - 1), CelValue::CreateInt64(std::numeric_limits::max())); From 719f3eed5919bc964b30c7e06a77a3a7eeb64953 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 20 May 2026 21:47:09 -0700 Subject: [PATCH 29/87] No public description PiperOrigin-RevId: 918818689 --- eval/tests/benchmark_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eval/tests/benchmark_test.cc b/eval/tests/benchmark_test.cc index fc0c39294..f188dc0b7 100644 --- a/eval/tests/benchmark_test.cc +++ b/eval/tests/benchmark_test.cc @@ -317,7 +317,7 @@ BENCHMARK(BM_PolicySymbolic); class RequestMap : public CelMap { public: - absl::optional operator[](CelValue key) const override { + std::optional operator[](CelValue key) const override { if (!key.IsString()) { return {}; } From a552526a3c58346438cec05cbfe3afeb20657ed6 Mon Sep 17 00:00:00 2001 From: Muhammad Askri Date: Thu, 21 May 2026 12:28:31 -0700 Subject: [PATCH 30/87] Add functions to parse type and function signatures into cel types. PiperOrigin-RevId: 919194450 --- common/internal/BUILD | 8 +- common/internal/signature.cc | 390 +++++++++++++++++++++++- common/internal/signature.h | 21 ++ common/internal/signature_test.cc | 489 ++++++++++++++++++++++++++++-- 4 files changed, 889 insertions(+), 19 deletions(-) diff --git a/common/internal/BUILD b/common/internal/BUILD index 10084b685..48a8dfe8b 100644 --- a/common/internal/BUILD +++ b/common/internal/BUILD @@ -143,14 +143,17 @@ cc_library( srcs = ["signature.cc"], hdrs = ["signature.h"], deps = [ + "//common:ast", "//common:type", "//common:type_kind", + "//common:type_spec_resolver", "//internal:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", ], ) @@ -159,11 +162,14 @@ cc_test( srcs = ["signature_test.cc"], deps = [ ":signature", + "//common:ast", "//common:type", + "//common:type_kind", "//internal:testing", "//internal:testing_descriptor_pool", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", "@com_google_protobuf//:protobuf", ], diff --git a/common/internal/signature.cc b/common/internal/signature.cc index f63049878..5c75225f9 100644 --- a/common/internal/signature.cc +++ b/common/internal/signature.cc @@ -15,20 +15,30 @@ #include "common/internal/signature.h" #include +#include +#include #include #include +#include #include #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/types/optional.h" +#include "common/ast.h" #include "common/type.h" #include "common/type_kind.h" +#include "common/type_spec_resolver.h" #include "internal/status_macros.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" namespace cel::common_internal { +// Signature generator helper functions. namespace { void AppendEscaped(std::string* result, std::string_view str, bool escape_dot) { @@ -58,7 +68,7 @@ absl::Status AppendTypeParameters(std::string* result, const Type& type); // Recursively appends a string representation of the given `type` to `result`. // Type parameters are enclosed in angle brackets and separated by commas. - +// // Grammar: // TypeDesc = NamespaceIdentifier [ "<" TypeList ">" ] ; // NamespaceIdentifier = [ "." ] Identifier { "." Identifier } ; @@ -208,4 +218,382 @@ absl::StatusOr MakeOverloadSignature( return result; } + +// Signature parser helper functions. +namespace { + +std::string StripUnescapedWhitespace(std::string_view str) { + std::string result; + result.reserve(str.size()); + bool escaped = false; + for (char c : str) { + if (escaped) { + result.push_back(c); + escaped = false; + continue; + } + if (c == '\\') { + result.push_back(c); + escaped = true; + continue; + } + if (c == ' ' || c == '\t' || c == '\n' || c == '\r') { + continue; + } + result.push_back(c); + } + return result; +} + +absl::optional ParseBuiltinOrWrapper(std::string_view name_str) { + if (name_str == "null") return TypeSpec(NullTypeSpec()); + if (name_str == "bool") return TypeSpec(PrimitiveType::kBool); + if (name_str == "int") return TypeSpec(PrimitiveType::kInt64); + if (name_str == "uint") return TypeSpec(PrimitiveType::kUint64); + if (name_str == "double") return TypeSpec(PrimitiveType::kDouble); + if (name_str == "string") return TypeSpec(PrimitiveType::kString); + if (name_str == "bytes") return TypeSpec(PrimitiveType::kBytes); + if (name_str == "any" || name_str == "google.protobuf.Any") + return TypeSpec(WellKnownTypeSpec::kAny); + if (name_str == "timestamp" || name_str == "google.protobuf.Timestamp") + return TypeSpec(WellKnownTypeSpec::kTimestamp); + if (name_str == "duration" || name_str == "google.protobuf.Duration") + return TypeSpec(WellKnownTypeSpec::kDuration); + if (name_str == "dyn" || name_str == "google.protobuf.Value") + return TypeSpec(DynTypeSpec()); + + // Handle standard Protobuf well-known wrapper types to preserve + // backward compatibility for users migrating YAML configuration files. + if (name_str == "bool_wrapper" || name_str == "google.protobuf.BoolValue") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool)); + if (name_str == "int_wrapper" || name_str == "google.protobuf.Int64Value" || + name_str == "google.protobuf.Int32Value") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64)); + if (name_str == "uint_wrapper" || name_str == "google.protobuf.UInt64Value" || + name_str == "google.protobuf.UInt32Value") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64)); + if (name_str == "double_wrapper" || + name_str == "google.protobuf.DoubleValue" || + name_str == "google.protobuf.FloatValue") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)); + if (name_str == "string_wrapper" || name_str == "google.protobuf.StringValue") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kString)); + if (name_str == "bytes_wrapper" || name_str == "google.protobuf.BytesValue") + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBytes)); + + if (name_str == "google.protobuf.ListValue") { + return TypeSpec(ListTypeSpec(std::make_unique(DynTypeSpec()))); + } + if (name_str == "google.protobuf.Struct") { + return TypeSpec( + MapTypeSpec(std::make_unique(PrimitiveType::kString), + std::make_unique(DynTypeSpec()))); + } + + return absl::nullopt; +} + +std::string Unescape(std::string_view str) { + size_t first_escape = str.find('\\'); + if (first_escape == std::string_view::npos) { + return std::string(str); + } + std::string result; + result.reserve(str.size()); + result.append(str.substr(0, first_escape)); + bool escaped = false; + for (size_t i = first_escape; i < str.size(); ++i) { + char c = str[i]; + if (escaped) { + result.push_back(c); + escaped = false; + } else if (c == '\\') { + escaped = true; + } else { + result.push_back(c); + } + } + if (escaped) { + result.push_back('\\'); + } + return result; +} + +class SignatureScanner { + public: + explicit SignatureScanner(std::string_view input, + std::string_view error_prefix = "Invalid signature") + : input_(input), error_prefix_(error_prefix) {} + + absl::StatusOr FindTopLevelChar(char target, bool find_last = false) { + size_t found_idx = std::string_view::npos; + int nesting = 0; + bool escaped = false; + // Scanning str for delimiter boundaries while ensuring + // brackets are balanced and escape backslashes are bypassed. + for (size_t i = 0; i < input_.size(); ++i) { + char c = input_[i]; + if (escaped) { + escaped = false; + continue; + } + if (c == '\\') { + escaped = true; + continue; + } + if (c == target && nesting == 0) { + if (find_last || found_idx == std::string_view::npos) { + found_idx = i; + } + } + if (c == '<') { + nesting++; + } else if (c == '>') { + nesting--; + if (nesting < 0) { + return absl::InvalidArgumentError( + absl::StrCat(error_prefix_, ": mismatched brackets")); + } + } + } + if (nesting != 0) { + return absl::InvalidArgumentError( + absl::StrCat(error_prefix_, ": mismatched brackets")); + } + return found_idx; + } + + absl::StatusOr> SplitTopLevel(char delimiter) { + std::vector result; + int nesting = 0; + bool escaped = false; + size_t start = 0; + // Scanning str for delimiter while ensuring brackets are balanced and + // escape backslashes are bypassed. + for (size_t i = 0; i < input_.size(); ++i) { + char c = input_[i]; + if (escaped) { + escaped = false; + continue; + } + if (c == '\\') { + escaped = true; + continue; + } + if (c == delimiter && nesting == 0) { + result.push_back(input_.substr(start, i - start)); + start = i + 1; + } + if (c == '<') { + nesting++; + } else if (c == '>') { + nesting--; + if (nesting < 0) { + return absl::InvalidArgumentError( + absl::StrCat(error_prefix_, ": mismatched brackets")); + } + } + } + if (nesting != 0) { + return absl::InvalidArgumentError( + absl::StrCat(error_prefix_, ": mismatched brackets")); + } + result.push_back(input_.substr(start)); + return result; + } + + private: + std::string_view input_; + std::string_view error_prefix_; +}; + +absl::StatusOr> SplitTypeList( + std::string_view params) { + return SignatureScanner(params, "Invalid type signature").SplitTopLevel(','); +} + +absl::StatusOr ParseTypeSignature(std::string_view signature) { + if (signature.empty()) { + return absl::InvalidArgumentError("Empty type signature"); + } + + if (signature[0] == '~') { + std::string_view param_name = signature.substr(1); + if (param_name.empty()) { + return absl::InvalidArgumentError( + "Invalid type signature: invalid type parameter name"); + } + CEL_ASSIGN_OR_RETURN(size_t less_idx, + SignatureScanner(param_name) + .FindTopLevelChar('<', /*find_last=*/false)); + CEL_ASSIGN_OR_RETURN(size_t comma_idx, + SignatureScanner(param_name) + .FindTopLevelChar(',', /*find_last=*/false)); + if (less_idx != std::string_view::npos || + comma_idx != std::string_view::npos) { + return absl::InvalidArgumentError( + "Invalid type signature: invalid type parameter name"); + } + return TypeSpec(ParamTypeSpec(Unescape(param_name))); + } + + CEL_ASSIGN_OR_RETURN(size_t less_idx, + SignatureScanner(signature, "Invalid type signature") + .FindTopLevelChar('<', /*find_last=*/false)); + + std::string name_str; + std::vector params; + + if (less_idx != std::string_view::npos) { + // If the signature contains a '<', it must also contain a matching '>'. + if (signature.back() != '>') { + return absl::InvalidArgumentError( + "Invalid type signature: missing closing >"); + } + name_str = Unescape(signature.substr(0, less_idx)); + std::string_view params_str = + signature.substr(less_idx + 1, signature.size() - less_idx - 2); + CEL_ASSIGN_OR_RETURN(auto param_list, SplitTypeList(params_str)); + for (std::string_view param_str : param_list) { + CEL_ASSIGN_OR_RETURN(auto param, ParseTypeSignature(param_str)); + params.push_back(std::move(param)); + } + } else { + name_str = Unescape(signature); + } + + auto read_param_or_dyn = [¶ms](size_t index) { + auto spec = std::make_unique(DynTypeSpec()); + if (params.size() > index) { + *spec = std::move(params[index]); + } + return spec; + }; + + if (!params.empty()) { + if (ParseBuiltinOrWrapper(name_str).has_value()) { + return absl::InvalidArgumentError( + absl::StrCat("Invalid type signature: ", name_str, + " cannot have type parameters")); + } + } else { + if (auto builtin = ParseBuiltinOrWrapper(name_str); builtin.has_value()) { + return *builtin; + } + } + + if (name_str == "type") { + if (params.size() > 1) { + return absl::InvalidArgumentError( + "Invalid type signature: type expects at most 1 parameter"); + } + return TypeSpec(read_param_or_dyn(0)); + } + + if (name_str == "list") { + if (params.size() > 1) { + return absl::InvalidArgumentError( + "Invalid type signature: list expects at most 1 parameter"); + } + return TypeSpec(ListTypeSpec(read_param_or_dyn(0))); + } + + if (name_str == "map") { + if (!params.empty() && params.size() != 2) { + return absl::InvalidArgumentError( + "Invalid type signature: map expects 0 or 2 parameters"); + } + auto key = read_param_or_dyn(0); + auto value = read_param_or_dyn(1); + return TypeSpec(MapTypeSpec(std::move(key), std::move(value))); + } + + if (name_str == "function") { + auto result_type = read_param_or_dyn(0); + std::vector arg_types; + for (size_t i = 1; i < params.size(); ++i) { + arg_types.push_back(std::move(params[i])); + } + return TypeSpec( + FunctionTypeSpec(std::move(result_type), std::move(arg_types))); + } + + if (name_str.empty() || absl::StrContains(name_str, "..")) { + return absl::InvalidArgumentError( + "Invalid type signature: invalid identifier"); + } + + return TypeSpec(AbstractType(name_str, std::move(params))); +} + +} // namespace + +absl::StatusOr ParseFunctionSignature( + std::string_view signature) { + std::string stripped_sig = StripUnescapedWhitespace(signature); + if (stripped_sig.empty()) { + return absl::InvalidArgumentError("Empty function signature"); + } + + CEL_ASSIGN_OR_RETURN( + size_t paren_idx, + SignatureScanner(stripped_sig, "Invalid function signature") + .FindTopLevelChar('(', /*find_last=*/false)); + + if (paren_idx == std::string_view::npos || stripped_sig.back() != ')') { + return absl::InvalidArgumentError("Invalid function signature"); + } + + std::string_view prefix = std::string_view(stripped_sig).substr(0, paren_idx); + std::string_view args_str = + std::string_view(stripped_sig) + .substr(paren_idx + 1, stripped_sig.size() - paren_idx - 2); + + std::vector arg_types; + ParsedFunctionOverload out; + + CEL_ASSIGN_OR_RETURN(size_t dot_idx, + SignatureScanner(prefix, "Invalid function signature") + .FindTopLevelChar('.', /*find_last=*/true)); + + if (dot_idx != std::string_view::npos) { + out.is_member = true; + std::string_view receiver_str = prefix.substr(0, dot_idx); + std::string_view func_str = prefix.substr(dot_idx + 1); + + CEL_ASSIGN_OR_RETURN(auto receiver_param, ParseTypeSignature(receiver_str)); + arg_types.push_back(std::move(receiver_param)); + out.function_name = Unescape(func_str); + } else { + out.is_member = false; + out.function_name = Unescape(prefix); + } + + if (out.function_name.empty()) { + return absl::InvalidArgumentError( + "Invalid function signature: empty function name"); + } + + if (!args_str.empty()) { + CEL_ASSIGN_OR_RETURN(auto arg_list, SplitTypeList(args_str)); + for (std::string_view arg_str : arg_list) { + CEL_ASSIGN_OR_RETURN(auto arg_param, ParseTypeSignature(arg_str)); + arg_types.push_back(std::move(arg_param)); + } + } + + auto result_type = std::make_unique(DynTypeSpec()); + out.signature_type = + TypeSpec(FunctionTypeSpec(std::move(result_type), std::move(arg_types))); + + return out; +} + +absl::StatusOr ParseType(std::string_view signature, google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool& pool) { + std::string stripped_sig = StripUnescapedWhitespace(signature); + CEL_ASSIGN_OR_RETURN(auto type_spec, ParseTypeSignature(stripped_sig)); + return cel::ConvertTypeSpecToType(type_spec, arena, pool); +} + } // namespace cel::common_internal diff --git a/common/internal/signature.h b/common/internal/signature.h index 3f31d8fd1..3fdba4b2e 100644 --- a/common/internal/signature.h +++ b/common/internal/signature.h @@ -20,7 +20,10 @@ #include #include "absl/status/statusor.h" +#include "common/ast.h" #include "common/type.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" namespace cel::common_internal { @@ -56,6 +59,24 @@ absl::StatusOr MakeOverloadSignature( std::string_view function_name, const std::vector& args, bool is_member); +// Parses a string type signature directly into a `cel::Type`. +absl::StatusOr ParseType(std::string_view signature, google::protobuf::Arena* arena, + const google::protobuf::DescriptorPool& pool); + +// A parsed function overload signature with the function name, flag for member +// function, and the function signature type. +struct ParsedFunctionOverload { + std::string function_name; + bool is_member = false; + // The function signature type, configured as a `FunctionTypeSpec`. + TypeSpec signature_type; +}; + +// Parses a string function overload signature directly into a +// `cel::TypeSpec` configured as a `FunctionTypeSpec`. +absl::StatusOr ParseFunctionSignature( + std::string_view signature); + } // namespace cel::common_internal #endif // THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_SIGNATURE_H_ diff --git a/common/internal/signature_test.cc b/common/internal/signature_test.cc index 8e41c70fb..765055f75 100644 --- a/common/internal/signature_test.cc +++ b/common/internal/signature_test.cc @@ -13,13 +13,17 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include "absl/base/no_destructor.h" #include "absl/status/status.h" +#include "absl/status/status_matchers.h" #include "absl/status/statusor.h" +#include "common/ast.h" #include "common/type.h" +#include "common/type_kind.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "google/protobuf/arena.h" @@ -38,6 +42,101 @@ google::protobuf::Arena* GetTestArena() { return &*arena; } +void VerifyParsedMatchesType(const TypeSpec& parsed, const Type& original) { + switch (original.kind()) { + case TypeKind::kDyn: + EXPECT_TRUE(parsed.has_dyn()); + break; + case TypeKind::kNull: + EXPECT_TRUE(parsed.has_null()); + break; + case TypeKind::kBool: + EXPECT_EQ(parsed.primitive(), PrimitiveType::kBool); + break; + case TypeKind::kInt: + EXPECT_EQ(parsed.primitive(), PrimitiveType::kInt64); + break; + case TypeKind::kUint: + EXPECT_EQ(parsed.primitive(), PrimitiveType::kUint64); + break; + case TypeKind::kDouble: + EXPECT_EQ(parsed.primitive(), PrimitiveType::kDouble); + break; + case TypeKind::kString: + EXPECT_EQ(parsed.primitive(), PrimitiveType::kString); + break; + case TypeKind::kBytes: + EXPECT_EQ(parsed.primitive(), PrimitiveType::kBytes); + break; + case TypeKind::kAny: + EXPECT_EQ(parsed.well_known(), WellKnownTypeSpec::kAny); + break; + case TypeKind::kTimestamp: + EXPECT_EQ(parsed.well_known(), WellKnownTypeSpec::kTimestamp); + break; + case TypeKind::kDuration: + EXPECT_EQ(parsed.well_known(), WellKnownTypeSpec::kDuration); + break; + case TypeKind::kList: + EXPECT_TRUE(parsed.has_list_type()); + if (!original.GetParameters().empty()) { + VerifyParsedMatchesType(parsed.list_type().elem_type(), + original.GetParameters()[0]); + } + break; + case TypeKind::kMap: + EXPECT_TRUE(parsed.has_map_type()); + if (!original.GetParameters().empty()) { + VerifyParsedMatchesType(parsed.map_type().key_type(), + original.GetParameters()[0]); + } + if (original.GetParameters().size() > 1) { + VerifyParsedMatchesType(parsed.map_type().value_type(), + original.GetParameters()[1]); + } + break; + case TypeKind::kBoolWrapper: + case TypeKind::kIntWrapper: + case TypeKind::kUintWrapper: + case TypeKind::kDoubleWrapper: + case TypeKind::kStringWrapper: + case TypeKind::kBytesWrapper: + EXPECT_TRUE(parsed.has_wrapper()); + break; + case TypeKind::kType: + EXPECT_TRUE(parsed.has_type()); + if (!original.GetParameters().empty()) { + VerifyParsedMatchesType(parsed.type(), original.GetParameters()[0]); + } + break; + case TypeKind::kTypeParam: + EXPECT_TRUE(parsed.has_type_param()); + break; + default: + EXPECT_TRUE(parsed.has_abstract_type()); + break; + } +} + +void VerifyTypesEqual(const Type& lhs, const Type& rhs) { + EXPECT_EQ(lhs.kind(), rhs.kind()); + if (lhs.kind() != rhs.kind()) return; + + if (lhs.kind() == TypeKind::kOpaque || lhs.kind() == TypeKind::kStruct || + lhs.kind() == TypeKind::kTypeParam) { + EXPECT_EQ(lhs.name(), rhs.name()); + } + + const auto& lhs_params = lhs.GetParameters(); + const auto& rhs_params = rhs.GetParameters(); + EXPECT_EQ(lhs_params.size(), rhs_params.size()); + if (lhs_params.size() == rhs_params.size()) { + for (size_t i = 0; i < lhs_params.size(); ++i) { + VerifyTypesEqual(lhs_params[i], rhs_params[i]); + } + } +} + struct TypeSignatureTestCase { Type type; std::string expected_signature; @@ -73,10 +172,18 @@ std::vector GetTypeSignatureTestCases() { .type = ListType(GetTestArena(), StringType{}), .expected_signature = "list", }, + { + .type = TypeType(GetTestArena(), IntType{}), + .expected_signature = "type", + }, { .type = ListType(GetTestArena(), TypeParamType("A")), .expected_signature = "list<~A>", }, + { + .type = ListType(GetTestArena(), TypeParamType("A GetTypeSignatureTestCases() { .expected_signature = "map<~B,~C>", }, { - .type = OpaqueType( - GetTestArena(), "bar", - {FunctionType(GetTestArena(), TypeParamType("D"), {})}), - .expected_signature = "bar>", + .type = OpaqueType(GetTestArena(), "bar", + {FunctionType(GetTestArena(), TypeParamType("D"), + {StringType{}, BoolType{}})}), + .expected_signature = "bar>", }, { .type = AnyType{}, @@ -104,10 +211,18 @@ std::vector GetTypeSignatureTestCases() { .type = TimestampType{}, .expected_signature = "timestamp", }, + { + .type = BoolWrapperType{}, + .expected_signature = "bool_wrapper", + }, { .type = IntWrapperType{}, .expected_signature = "int_wrapper", }, + { + .type = UintWrapperType{}, + .expected_signature = "uint_wrapper", + }, { .type = MessageType(GetTestingDescriptorPool()->FindMessageTypeByName( "cel.expr.conformance.proto3.TestAllTypes")), @@ -117,22 +232,32 @@ std::vector GetTypeSignatureTestCases() { .type = ListType(GetTestArena(), TypeParamType(R"(a,b..(d)\e)")), .expected_signature = R"(list<~a\,b\.\\.\(d\)\\e>)", }, - { - .type = UnknownType{}, - .expected_error = - "Type kind: *unknown* is not supported in CEL declarations", - }, - { - .type = ErrorType{}, - .expected_error = - "Type kind: *error* is not supported in CEL declarations", - }, }; } +TEST(TypeSignatureTest, UnsupportedTypes) { + EXPECT_THAT(common_internal::MakeTypeSignature(UnknownType{}), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Type kind: *unknown* is not supported"))); + + EXPECT_THAT(common_internal::MakeTypeSignature(ErrorType{}), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Type kind: *error* is not supported"))); +} + INSTANTIATE_TEST_SUITE_P(TypeIdTest, TypeSignatureTest, ValuesIn(GetTypeSignatureTestCases())); +TEST_P(TypeSignatureTest, ParseTypeCheck) { + const auto& param = GetParam(); + if (!param.expected_signature.empty() && param.expected_error.empty()) { + auto parsed = ParseType(param.expected_signature, GetTestArena(), + *GetTestingDescriptorPool()); + ASSERT_THAT(parsed, ::absl_testing::IsOk()); + VerifyTypesEqual(*parsed, param.type); + } +} + struct OverloadSignatureTestCase { std::string function_name = "hello"; std::vector args; @@ -202,10 +327,18 @@ std::vector GetOverloadSignatureTestCases() { .args = {TimestampType{}}, .expected_signature = "hello(timestamp)", }, + { + .args = {BoolWrapperType{}}, + .expected_signature = "hello(bool_wrapper)", + }, { .args = {IntWrapperType{}}, .expected_signature = "hello(int_wrapper)", }, + { + .args = {UintWrapperType{}}, + .expected_signature = "hello(uint_wrapper)", + }, { .args = {MessageType( GetTestingDescriptorPool()->FindMessageTypeByName( @@ -213,9 +346,6 @@ std::vector GetOverloadSignatureTestCases() { .expected_signature = "hello(cel.expr.conformance.proto3.TestAllTypes)", }, - {.args = {}, - .is_member = true, - .expected_error = "Member function with no receiver"}, { .args = {StringType{}}, .is_member = true, @@ -231,6 +361,18 @@ std::vector GetOverloadSignatureTestCases() { .is_member = true, .expected_signature = "string.hello(bool,dyn)", }, + { + .function_name = "hello", + .args = {OpaqueType(GetTestArena(), "bar", + {TypeParamType("dummy.type")})}, + .is_member = true, + .expected_signature = R"(bar<~dummy\.type>.hello())", + }, + { + .function_name = "inspect", + .args = {Type(TypeType(GetTestArena(), StringType{}))}, + .expected_signature = "inspect(type)", + }, { .function_name = R"(h.(e),l\o)", .args = {StringType{}, @@ -242,8 +384,321 @@ std::vector GetOverloadSignatureTestCases() { }; } +TEST(OverloadSignatureTest, MemberFunctionNoReceiverError) { + auto signature = common_internal::MakeOverloadSignature("hello", {}, true); + EXPECT_THAT(signature, + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Member function with no receiver"))); +} + INSTANTIATE_TEST_SUITE_P(OverloadIdTest, OverloadSignatureTest, ValuesIn(GetOverloadSignatureTestCases())); +TEST_P(OverloadSignatureTest, ExhaustiveFunctionParseCheck) { + const auto& param = GetParam(); + if (!param.expected_signature.empty()) { + auto parsed = ParseFunctionSignature(param.expected_signature); + ASSERT_THAT(parsed, ::absl_testing::IsOk()); + EXPECT_EQ(parsed->function_name, param.function_name); + EXPECT_EQ(parsed->is_member, param.is_member); + EXPECT_TRUE(parsed->signature_type.has_function()); + const auto& func = parsed->signature_type.function(); + for (size_t i = 0; i < param.args.size(); ++i) { + VerifyParsedMatchesType(func.arg_types()[i], param.args[i]); + } + } +} + +TEST(ParseSignatureTest, ProtoParsing) { + ASSERT_OK_AND_ASSIGN( + auto t1, ParseType("int", GetTestArena(), *GetTestingDescriptorPool())); + EXPECT_TRUE(t1.IsInt()); + + ASSERT_OK_AND_ASSIGN(auto t2, ParseType("list<~A>", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(t2.IsList()); + + ASSERT_OK_AND_ASSIGN(auto t3, ParseType(R"(~abc\)", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(t3.IsTypeParam()); + EXPECT_EQ(t3.GetTypeParam().name(), R"(abc\)"); + + ASSERT_OK_AND_ASSIGN(auto w1, + ParseType("google.protobuf.BoolValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w1.IsBoolWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w2, + ParseType("google.protobuf.Int64Value", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w2.IsIntWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w3, + ParseType("google.protobuf.Int32Value", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w3.IsIntWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w4, + ParseType("google.protobuf.UInt64Value", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w4.IsUintWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w5, + ParseType("google.protobuf.UInt32Value", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w5.IsUintWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w6, + ParseType("google.protobuf.DoubleValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w6.IsDoubleWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w7, + ParseType("google.protobuf.FloatValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w7.IsDoubleWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w8, + ParseType("google.protobuf.StringValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w8.IsStringWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w9, + ParseType("google.protobuf.BytesValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w9.IsBytesWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w10, ParseType("string_wrapper", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w10.IsStringWrapper()); + + ASSERT_OK_AND_ASSIGN(auto w11, ParseType("bytes_wrapper", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(w11.IsBytesWrapper()); + + ASSERT_OK_AND_ASSIGN(auto gp_any, + ParseType("google.protobuf.Any", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(gp_any.IsAny()); + + ASSERT_OK_AND_ASSIGN(auto gp_timestamp, + ParseType("google.protobuf.Timestamp", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(gp_timestamp.IsTimestamp()); + + ASSERT_OK_AND_ASSIGN(auto gp_duration, + ParseType("google.protobuf.Duration", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(gp_duration.IsDuration()); + + ASSERT_OK_AND_ASSIGN(auto gp_value, + ParseType("google.protobuf.Value", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(gp_value.IsDyn()); + + ASSERT_OK_AND_ASSIGN(auto gp_list_value, + ParseType("google.protobuf.ListValue", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(gp_list_value.IsList()); + + ASSERT_OK_AND_ASSIGN(auto gp_struct, + ParseType("google.protobuf.Struct", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(gp_struct.IsMap()); + + // Legal whitespace handling tests + ASSERT_OK_AND_ASSIGN(auto ws_type1, + ParseType("map < int , string > ", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(ws_type1.IsMap()); + + ASSERT_OK_AND_ASSIGN(auto ws_type2, + ParseType("map\t<\nint\r,\tstring\n>\r", GetTestArena(), + *GetTestingDescriptorPool())); + EXPECT_TRUE(ws_type2.IsMap()); +} + +TEST(ParseSignatureTest, FunctionParsing) { + ASSERT_OK_AND_ASSIGN(auto f1, ParseFunctionSignature("hello(string)")); + EXPECT_TRUE(f1.signature_type.has_function()); + EXPECT_EQ(f1.signature_type.function().arg_types().size(), 1); + + // Legal whitespace handling tests + ASSERT_OK_AND_ASSIGN(auto ws_func1, + ParseFunctionSignature(" hello ( string ) ")); + EXPECT_TRUE(ws_func1.signature_type.has_function()); + EXPECT_EQ(ws_func1.signature_type.function().arg_types().size(), 1); + + ASSERT_OK_AND_ASSIGN(auto ws_func2, + ParseFunctionSignature("\thello\n(\rstring\t)\n\r")); + EXPECT_TRUE(ws_func2.signature_type.has_function()); + EXPECT_EQ(ws_func2.signature_type.function().arg_types().size(), 1); + + ASSERT_OK_AND_ASSIGN(auto f2, ParseFunctionSignature("a.b.c()")); + EXPECT_TRUE(f2.is_member); + EXPECT_EQ(f2.function_name, "c"); +} + +TEST(ParseSignatureTest, ParsingErrors) { + // Mismatched template brackets and parentheses. + EXPECT_THAT( + ParseType("list>", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT( + ParseType("list", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT(ParseType("list><", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT(ParseFunctionSignature("hello(list>)"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT(ParseFunctionSignature("hello(list)"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT(ParseFunctionSignature("foo", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("list expects at most 1 parameter"))); + EXPECT_THAT( + ParseType("map", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("map expects 0 or 2 parameters"))); + EXPECT_THAT(ParseType("map", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("map expects 0 or 2 parameters"))); + + // Enforcing valid function and identifier names. + EXPECT_THAT(ParseFunctionSignature("()"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("empty function name"))); + EXPECT_THAT(ParseFunctionSignature("string.()"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("empty function name"))); + + // Missing closing operators and boundary checks. + EXPECT_THAT( + ParseType("listfoo", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("missing closing >"))); + + EXPECT_THAT(ParseFunctionSignature("hello>(string)"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + EXPECT_THAT( + ParseType("list<", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + + EXPECT_THAT(ParseType("map", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + + EXPECT_THAT(ParseType("map int, string>", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); + + EXPECT_THAT(ParseType("list", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid type signature"))); + + EXPECT_THAT(ParseFunctionSignature("a..b.c()"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid type signature"))); + EXPECT_THAT( + ParseType("list", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Empty type signature"))); + + EXPECT_THAT( + ParseType("~list", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid type signature"))); + + // Checks that builtin types cannot have type parameters. + EXPECT_THAT( + ParseType("int", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("cannot have type parameters"))); +} + +TEST(ParseSignatureTest, MessageTypeWithParamsError) { + EXPECT_THAT(ParseType("cel.expr.conformance.proto3.TestAllTypes", + GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("cannot have type parameters"))); +} + +TEST(ParseSignatureTest, MissingClosingParenthesisError) { + EXPECT_THAT(ParseFunctionSignature("hello(string"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid function signature"))); + EXPECT_THAT(ParseFunctionSignature("hello)"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid function signature"))); +} + +TEST(ParseSignatureTest, NestedDotsNonMember) { + auto f1 = ParseFunctionSignature( + "my_opaque()"); + ASSERT_THAT(f1, ::absl_testing::IsOk()); + EXPECT_FALSE(f1->is_member); + EXPECT_EQ(f1->function_name, + "my_opaque"); +} + +TEST(ParseSignatureTest, OverlyComplexSignatures) { + auto t1 = ParseType("map>,map>>", + GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t1, ::absl_testing::IsOk()); + EXPECT_TRUE(t1->IsMap()); + + auto t2 = ParseType(R"(~abc\\)", GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t2, ::absl_testing::IsOk()); + EXPECT_TRUE(t2->IsTypeParam()); + EXPECT_EQ(t2->GetTypeParam().name(), R"(abc\)"); + + auto t3 = + ParseType(R"(~abc\\\\)", GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(t3, ::absl_testing::IsOk()); + EXPECT_TRUE(t3->IsTypeParam()); + EXPECT_EQ(t3->GetTypeParam().name(), R"(abc\\)"); + + auto f1 = ParseFunctionSignature( + "bar>,map>.func(string)"); + ASSERT_THAT(f1, ::absl_testing::IsOk()); + EXPECT_TRUE(f1->is_member); + EXPECT_EQ(f1->function_name, "func"); + EXPECT_TRUE(f1->signature_type.has_function()); + EXPECT_EQ(f1->signature_type.function().arg_types().size(), 2); +} + +TEST(ParseSignatureTest, EmptyOrWhitespaceErrors) { + EXPECT_THAT(ParseType("", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Empty type signature"))); + EXPECT_THAT(ParseFunctionSignature(""), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Empty function signature"))); + EXPECT_THAT(ParseType("list>", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Empty type signature"))); +} + } // namespace } // namespace cel::common_internal From 6fd7030b954562c5e5c1c1185066e80cad29fd25 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 21 May 2026 18:17:46 -0700 Subject: [PATCH 31/87] not yet exported PiperOrigin-RevId: 919358445 --- common/expr_factory.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/common/expr_factory.h b/common/expr_factory.h index 773217ad9..5607d8deb 100644 --- a/common/expr_factory.h +++ b/common/expr_factory.h @@ -32,6 +32,7 @@ namespace cel { class MacroExprFactory; class ParserMacroExprFactory; +class OptimizerExprFactory; class ExprFactory { protected: @@ -378,6 +379,7 @@ class ExprFactory { private: friend class MacroExprFactory; friend class ParserMacroExprFactory; + friend class OptimizerExprFactory; ExprFactory() : accu_var_(kAccumulatorVariableName) {} From ec82288de1338c6d7763fd722d52c3636965ca1e Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 21 May 2026 20:41:46 -0700 Subject: [PATCH 32/87] No public description PiperOrigin-RevId: 919408845 --- .../descriptor_pool_type_introspector.cc | 12 +++--- .../descriptor_pool_type_introspector_test.cc | 4 +- checker/internal/type_check_env.cc | 10 ++--- checker/internal/type_checker_builder_impl.cc | 6 +-- .../type_checker_builder_impl_test.cc | 2 +- checker/internal/type_checker_impl.cc | 20 +++++----- checker/internal/type_inference_context.cc | 8 ++-- .../internal/type_inference_context_test.cc | 40 +++++++++---------- 8 files changed, 51 insertions(+), 51 deletions(-) diff --git a/checker/internal/descriptor_pool_type_introspector.cc b/checker/internal/descriptor_pool_type_introspector.cc index f6001e947..da4f4430b 100644 --- a/checker/internal/descriptor_pool_type_introspector.cc +++ b/checker/internal/descriptor_pool_type_introspector.cc @@ -35,7 +35,7 @@ namespace { // Standard implementation for field lookups. // Avoids building a FieldTable and just checks the DescriptorPool directly. -absl::StatusOr> +absl::StatusOr> FindStructTypeFieldByNameDirectly( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, absl::string_view type, absl::string_view name) { @@ -60,7 +60,7 @@ FindStructTypeFieldByNameDirectly( // Standard implementation for listing fields. // Avoids building a FieldTable and just checks the DescriptorPool directly. absl::StatusOr< - absl::optional>> + std::optional>> ListStructTypeFieldsDirectly( const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, absl::string_view type) { @@ -88,7 +88,7 @@ ListStructTypeFieldsDirectly( using Field = DescriptorPoolTypeIntrospector::Field; -absl::StatusOr> +absl::StatusOr> DescriptorPoolTypeIntrospector::FindTypeImpl(absl::string_view name) const { const google::protobuf::Descriptor* absl_nullable descriptor = descriptor_pool_->FindMessageTypeByName(name); @@ -103,7 +103,7 @@ DescriptorPoolTypeIntrospector::FindTypeImpl(absl::string_view name) const { return absl::nullopt; } -absl::StatusOr> +absl::StatusOr> DescriptorPoolTypeIntrospector::FindEnumConstantImpl( absl::string_view type, absl::string_view value) const { const google::protobuf::EnumDescriptor* absl_nullable enum_descriptor = @@ -124,7 +124,7 @@ DescriptorPoolTypeIntrospector::FindEnumConstantImpl( return absl::nullopt; } -absl::StatusOr> +absl::StatusOr> DescriptorPoolTypeIntrospector::FindStructTypeFieldByNameImpl( absl::string_view type, absl::string_view name) const { if (!use_json_name_) { @@ -151,7 +151,7 @@ DescriptorPoolTypeIntrospector::FindStructTypeFieldByNameImpl( } absl::StatusOr< - absl::optional>> + std::optional>> DescriptorPoolTypeIntrospector::ListFieldsForStructTypeImpl( absl::string_view type) const { if (!use_json_name_) { diff --git a/checker/internal/descriptor_pool_type_introspector_test.cc b/checker/internal/descriptor_pool_type_introspector_test.cc index e2fdc9d40..456798744 100644 --- a/checker/internal/descriptor_pool_type_introspector_test.cc +++ b/checker/internal/descriptor_pool_type_introspector_test.cc @@ -117,7 +117,7 @@ TEST(DescriptorPoolTypeIntrospectorTest, internal::GetTestingDescriptorPool()); introspector.set_use_json_name(true); - absl::StatusOr> field = + absl::StatusOr> field = introspector.FindStructTypeFieldByName( "cel.expr.conformance.proto3.TestAllTypes", "singleInt64"); @@ -132,7 +132,7 @@ TEST(DescriptorPoolTypeIntrospectorTest, ListFieldsForStructType) { DescriptorPoolTypeIntrospector introspector( internal::GetTestingDescriptorPool()); absl::StatusOr< - absl::optional>> + std::optional>> fields = introspector.ListFieldsForStructType( "cel.expr.conformance.proto3.TestAllTypes"); ASSERT_THAT(fields, IsOkAndHolds(Optional(SizeIs(260)))); diff --git a/checker/internal/type_check_env.cc b/checker/internal/type_check_env.cc index c080326cb..763d9ba46 100644 --- a/checker/internal/type_check_env.cc +++ b/checker/internal/type_check_env.cc @@ -48,7 +48,7 @@ const FunctionDecl* absl_nullable TypeCheckEnv::LookupFunction( return nullptr; } -absl::StatusOr> TypeCheckEnv::LookupTypeName( +absl::StatusOr> TypeCheckEnv::LookupTypeName( absl::string_view name) const { for (auto iter = type_providers_.begin(); iter != type_providers_.end(); ++iter) { @@ -60,7 +60,7 @@ absl::StatusOr> TypeCheckEnv::LookupTypeName( return absl::nullopt; } -absl::StatusOr> TypeCheckEnv::LookupEnumConstant( +absl::StatusOr> TypeCheckEnv::LookupEnumConstant( absl::string_view type, absl::string_view value) const { for (auto iter = type_providers_.begin(); iter != type_providers_.end(); ++iter) { @@ -77,9 +77,9 @@ absl::StatusOr> TypeCheckEnv::LookupEnumConstant( return absl::nullopt; } -absl::StatusOr> TypeCheckEnv::LookupTypeConstant( +absl::StatusOr> TypeCheckEnv::LookupTypeConstant( google::protobuf::Arena* absl_nonnull arena, absl::string_view name) const { - CEL_ASSIGN_OR_RETURN(absl::optional type, LookupTypeName(name)); + CEL_ASSIGN_OR_RETURN(std::optional type, LookupTypeName(name)); if (type.has_value()) { return MakeVariableDecl(type->name(), TypeType(arena, *type)); } @@ -94,7 +94,7 @@ absl::StatusOr> TypeCheckEnv::LookupTypeConstant( return absl::nullopt; } -absl::StatusOr> TypeCheckEnv::LookupStructField( +absl::StatusOr> TypeCheckEnv::LookupStructField( absl::string_view type_name, absl::string_view field_name) const { // Check the type providers in registration order. // Note: this doesn't allow for shadowing a type with a subset type of the diff --git a/checker/internal/type_checker_builder_impl.cc b/checker/internal/type_checker_builder_impl.cc index 94a05602e..85b581e83 100644 --- a/checker/internal/type_checker_builder_impl.cc +++ b/checker/internal/type_checker_builder_impl.cc @@ -158,8 +158,8 @@ absl::StatusOr MergeFunctionDecls( return merged_decl; } -absl::optional FilterDecl(FunctionDecl decl, - const TypeCheckerSubset& subset) { +std::optional FilterDecl(FunctionDecl decl, + const TypeCheckerSubset& subset) { FunctionDecl filtered; std::string name = decl.release_name(); std::vector overloads = decl.release_overloads(); @@ -283,7 +283,7 @@ absl::Status TypeCheckerBuilderImpl::ApplyConfig( for (FunctionDeclRecord& fn : config.functions) { FunctionDecl decl = std::move(fn.decl); if (subset != nullptr) { - absl::optional filtered = + std::optional filtered = FilterDecl(std::move(decl), *subset); if (!filtered.has_value()) { continue; diff --git a/checker/internal/type_checker_builder_impl_test.cc b/checker/internal/type_checker_builder_impl_test.cc index f7a3dff97..494e7e440 100644 --- a/checker/internal/type_checker_builder_impl_test.cc +++ b/checker/internal/type_checker_builder_impl_test.cc @@ -144,7 +144,7 @@ TEST(ContextDeclsTest, CustomStructNotSupported) { {}); class MyTypeProvider : public cel::TypeIntrospector { public: - absl::StatusOr> FindTypeImpl( + absl::StatusOr> FindTypeImpl( absl::string_view name) const override { if (name == "com.example.MyStruct") { return common_internal::MakeBasicStructType("com.example.MyStruct"); diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 2472d7def..1ce871255 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -379,7 +379,7 @@ class ResolveVisitor : public AstVisitorBase { // Lookup message type by name to support WellKnownType creation. CEL_ASSIGN_OR_RETURN( - absl::optional field_info, + std::optional field_info, env_->LookupStructField(resolved_name, field.name())); if (!field_info.has_value()) { ReportUndefinedField(field.id(), field.name(), resolved_name); @@ -405,8 +405,8 @@ class ResolveVisitor : public AstVisitorBase { return absl::OkStatus(); } - absl::optional CheckFieldType(int64_t expr_id, const Type& operand_type, - absl::string_view field_name); + std::optional CheckFieldType(int64_t expr_id, const Type& operand_type, + absl::string_view field_name); void HandleOptSelect(const Expr& expr); void HandleBlockIndex(const Expr* expr); @@ -919,7 +919,7 @@ void ResolveVisitor::ResolveFunctionOverloads(const Expr& expr, arg_types.push_back(GetDeducedType(&expr.call_expr().args()[i])); } - absl::optional resolution = + std::optional resolution = inference_context_->ResolveOverload(decl, arg_types, is_receiver); if (!resolution.has_value()) { @@ -968,7 +968,7 @@ const VariableDecl* absl_nullable ResolveVisitor::LookupGlobalIdentifier( if (const VariableDecl* decl = env_->LookupVariable(name); decl != nullptr) { return decl; } - absl::StatusOr> constant = + absl::StatusOr> constant = env_->LookupTypeConstant(arena_, name); if (!constant.ok()) { @@ -1079,9 +1079,9 @@ void ResolveVisitor::ResolveQualifiedIdentifier( } } -absl::optional ResolveVisitor::CheckFieldType(int64_t id, - const Type& operand_type, - absl::string_view field) { +std::optional ResolveVisitor::CheckFieldType(int64_t id, + const Type& operand_type, + absl::string_view field) { if (operand_type.kind() == TypeKind::kDyn || operand_type.kind() == TypeKind::kAny) { return DynType(); @@ -1137,7 +1137,7 @@ void ResolveVisitor::ResolveSelectOperation(const Expr& expr, const Expr& operand) { const Type& operand_type = GetDeducedType(&operand); - absl::optional result_type; + std::optional result_type; int64_t id = expr.id(); // Support short-hand optional chaining. if (operand_type.IsOptional()) { @@ -1184,7 +1184,7 @@ void ResolveVisitor::HandleOptSelect(const Expr& expr) { operand_type = operand_type.GetOptional().GetParameter(); } - absl::optional field_type = CheckFieldType( + std::optional field_type = CheckFieldType( expr.id(), operand_type, field->const_expr().string_value()); if (!field_type.has_value()) { types_[&expr] = ErrorType(); diff --git a/checker/internal/type_inference_context.cc b/checker/internal/type_inference_context.cc index 96d985071..5b909d982 100644 --- a/checker/internal/type_inference_context.cc +++ b/checker/internal/type_inference_context.cc @@ -133,7 +133,7 @@ FunctionOverloadInstance InstantiateFunctionOverload( // Converts a wrapper type to its corresponding primitive type. // Returns nullopt if the type is not a wrapper type. -absl::optional WrapperToPrimitive(const Type& t) { +std::optional WrapperToPrimitive(const Type& t) { switch (t.kind()) { case TypeKind::kBoolWrapper: return BoolType(); @@ -286,7 +286,7 @@ bool TypeInferenceContext::IsAssignableInternal( } // Type is as concrete as it can be under current substitutions. - if (absl::optional wrapped_type = WrapperToPrimitive(to_subs); + if (std::optional wrapped_type = WrapperToPrimitive(to_subs); wrapped_type.has_value()) { return from_subs.IsNull() || IsAssignableInternal(*wrapped_type, from_subs, @@ -531,11 +531,11 @@ bool TypeInferenceContext::IsAssignableWithConstraints( return false; } -absl::optional +std::optional TypeInferenceContext::ResolveOverload(const FunctionDecl& decl, absl::Span argument_types, bool is_receiver) { - absl::optional result_type; + std::optional result_type; std::vector matching_overloads; for (const auto& ovl : decl.overloads()) { diff --git a/checker/internal/type_inference_context_test.cc b/checker/internal/type_inference_context_test.cc index d1bf7fa6d..458d08ff1 100644 --- a/checker/internal/type_inference_context_test.cc +++ b/checker/internal/type_inference_context_test.cc @@ -291,7 +291,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadBasic) { MakeOverloadDecl("add_double", DoubleType(), DoubleType(), DoubleType()))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {IntType(), IntType()}, false); ASSERT_TRUE(resolution.has_value()); EXPECT_THAT(resolution->result_type, IsTypeKind(TypeKind::kInt)); @@ -309,7 +309,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadFails) { MakeOverloadDecl("add_double", DoubleType(), DoubleType(), DoubleType()))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {IntType(), DoubleType()}, false); ASSERT_FALSE(resolution.has_value()); } @@ -324,7 +324,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithParamsNoMatch) { "_==_", MakeOverloadDecl("equals", BoolType(), TypeParamType("A"), TypeParamType("A")))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {IntType(), DoubleType()}, false); ASSERT_FALSE(resolution.has_value()); } @@ -341,7 +341,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithMixedParamsMatch) { "_==_", MakeOverloadDecl("equals", BoolType(), TypeParamType("A"), TypeParamType("A")))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {list_of_a, list_of_a}, false); ASSERT_TRUE(resolution.has_value()) << context.DebugString(); } @@ -359,7 +359,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithMixedParamsMatch2) { "_==_", MakeOverloadDecl("equals", BoolType(), TypeParamType("A"), TypeParamType("A")))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {list_of_a, list_of_int}, false); ASSERT_TRUE(resolution.has_value()) << context.DebugString(); EXPECT_THAT(resolution->overloads, ElementsAre(IsOverloadDecl("equals"))); @@ -375,7 +375,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithParamsMatches) { "_==_", MakeOverloadDecl("equals", BoolType(), TypeParamType("A"), TypeParamType("A")))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {IntType(), IntType()}, false); ASSERT_TRUE(resolution.has_value()); EXPECT_TRUE(resolution->result_type.IsBool()); @@ -394,7 +394,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithNestedParamsMatch) { Type list_of_a_instance = context.InstantiateTypeParams(list_of_a); - absl::optional resolution = + std::optional resolution = context.ResolveOverload( decl, {list_of_a_instance, ListType(&arena, IntType())}, false); ASSERT_TRUE(resolution.has_value()); @@ -407,7 +407,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithNestedParamsMatch) { EXPECT_THAT(resolution->overloads, ElementsAre(IsOverloadDecl("add_list"))); - absl::optional resolution2 = + std::optional resolution2 = context.ResolveOverload( decl, {ListType(&arena, IntType()), list_of_a_instance}, false); ASSERT_TRUE(resolution2.has_value()); @@ -433,7 +433,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithNestedParamsNoMatch) { Type list_of_a_instance = context.InstantiateTypeParams(list_of_a); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {list_of_a_instance, IntType()}, false); EXPECT_FALSE(resolution.has_value()); } @@ -450,13 +450,13 @@ TEST(TypeInferenceContextTest, InferencesAccumulate) { Type list_of_a_instance = context.InstantiateTypeParams(list_of_a); - absl::optional resolution1 = + std::optional resolution1 = context.ResolveOverload(decl, {list_of_a_instance, list_of_a_instance}, false); ASSERT_TRUE(resolution1.has_value()); EXPECT_TRUE(resolution1->result_type.IsList()); - absl::optional resolution2 = + std::optional resolution2 = context.ResolveOverload( decl, {resolution1->result_type, ListType(&arena, IntType())}, false); ASSERT_TRUE(resolution2.has_value()); @@ -480,7 +480,7 @@ TEST(TypeInferenceContextTest, DebugString) { MakeFunctionDecl("_+_", MakeOverloadDecl("add_list", list_of_a, list_of_a, list_of_a))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {list_of_int, list_of_int}, false); ASSERT_TRUE(resolution.has_value()); EXPECT_TRUE(resolution->result_type.IsList()); @@ -517,7 +517,7 @@ class TypeInferenceContextWrapperTypesTest TEST_P(TypeInferenceContextWrapperTypesTest, ResolvePrimitiveArg) { const TypeInferenceContextWrapperTypesTestCase& test_case = GetParam(); - absl::optional resolution = + std::optional resolution = context_.ResolveOverload(ternary_decl_, {BoolType(), test_case.wrapper_type, test_case.wrapped_primitive_type}, @@ -534,7 +534,7 @@ TEST_P(TypeInferenceContextWrapperTypesTest, ResolvePrimitiveArg) { TEST_P(TypeInferenceContextWrapperTypesTest, ResolveWrapperArg) { const TypeInferenceContextWrapperTypesTestCase& test_case = GetParam(); - absl::optional resolution = + std::optional resolution = context_.ResolveOverload( ternary_decl_, {BoolType(), test_case.wrapper_type, test_case.wrapper_type}, false); @@ -550,7 +550,7 @@ TEST_P(TypeInferenceContextWrapperTypesTest, ResolveWrapperArg) { TEST_P(TypeInferenceContextWrapperTypesTest, ResolveNullArg) { const TypeInferenceContextWrapperTypesTestCase& test_case = GetParam(); - absl::optional resolution = + std::optional resolution = context_.ResolveOverload(ternary_decl_, {BoolType(), test_case.wrapper_type, NullType()}, false); @@ -566,7 +566,7 @@ TEST_P(TypeInferenceContextWrapperTypesTest, ResolveNullArg) { TEST_P(TypeInferenceContextWrapperTypesTest, NullWidens) { const TypeInferenceContextWrapperTypesTestCase& test_case = GetParam(); - absl::optional resolution = + std::optional resolution = context_.ResolveOverload(ternary_decl_, {BoolType(), NullType(), test_case.wrapper_type}, false); @@ -582,7 +582,7 @@ TEST_P(TypeInferenceContextWrapperTypesTest, NullWidens) { TEST_P(TypeInferenceContextWrapperTypesTest, PrimitiveWidens) { const TypeInferenceContextWrapperTypesTestCase& test_case = GetParam(); - absl::optional resolution = + std::optional resolution = context_.ResolveOverload(ternary_decl_, {BoolType(), test_case.wrapped_primitive_type, test_case.wrapper_type}, @@ -622,7 +622,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithUnionTypePromotion) { /*result_type=*/TypeParamType("A"), BoolType(), TypeParamType("A"), TypeParamType("A")))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {BoolType(), NullType(), IntWrapperType()}, false); ASSERT_TRUE(resolution.has_value()); @@ -648,7 +648,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithTypeType) { TypeType(&arena, TypeParamType("A")), TypeParamType("A")))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(decl, {StringType()}, false); ASSERT_TRUE(resolution.has_value()); @@ -680,7 +680,7 @@ TEST(TypeInferenceContextTest, ResolveOverloadWithInferredTypeType) { BoolType(), TypeParamType("A"), TypeParamType("A")))); - absl::optional resolution = + std::optional resolution = context.ResolveOverload(to_type_decl, {StringType()}, false); ASSERT_TRUE(resolution.has_value()); From 833bb0c8fe93dc9bf5e2971c38254f80d3a42c1f Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 21 May 2026 20:46:26 -0700 Subject: [PATCH 33/87] No public description PiperOrigin-RevId: 919410392 --- testutil/test_macros.cc | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/testutil/test_macros.cc b/testutil/test_macros.cc index 158135762..672439dc5 100644 --- a/testutil/test_macros.cc +++ b/testutil/test_macros.cc @@ -37,9 +37,8 @@ bool IsCelNamespace(const Expr& target) { return target.has_ident_expr() && target.ident_expr().name() == "cel"; } -absl::optional CelBlockMacroExpander(MacroExprFactory& factory, - Expr& target, - absl::Span args) { +std::optional CelBlockMacroExpander(MacroExprFactory& factory, + Expr& target, absl::Span args) { if (!IsCelNamespace(target)) { return absl::nullopt; } @@ -51,9 +50,8 @@ absl::optional CelBlockMacroExpander(MacroExprFactory& factory, return factory.NewCall("cel.@block", args); } -absl::optional CelIndexMacroExpander(MacroExprFactory& factory, - Expr& target, - absl::Span args) { +std::optional CelIndexMacroExpander(MacroExprFactory& factory, + Expr& target, absl::Span args) { if (!IsCelNamespace(target)) { return absl::nullopt; } @@ -70,9 +68,9 @@ absl::optional CelIndexMacroExpander(MacroExprFactory& factory, return factory.NewIdent(absl::StrCat("@index", index)); } -absl::optional CelIterVarMacroExpander(MacroExprFactory& factory, - Expr& target, - absl::Span args) { +std::optional CelIterVarMacroExpander(MacroExprFactory& factory, + Expr& target, + absl::Span args) { if (!IsCelNamespace(target)) { return absl::nullopt; } @@ -94,9 +92,9 @@ absl::optional CelIterVarMacroExpander(MacroExprFactory& factory, unique_arg.const_expr().int_value())); } -absl::optional CelAccuVarMacroExpander(MacroExprFactory& factory, - Expr& target, - absl::Span args) { +std::optional CelAccuVarMacroExpander(MacroExprFactory& factory, + Expr& target, + absl::Span args) { if (!IsCelNamespace(target)) { return absl::nullopt; } From e00189c323d42b0cacafc320aa890eb4d630d394 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Sat, 23 May 2026 02:26:10 -0700 Subject: [PATCH 34/87] No public description PiperOrigin-RevId: 920098804 --- runtime/activation_test.cc | 4 ++-- runtime/function_registry.cc | 5 ++--- runtime/function_registry_test.cc | 4 ++-- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/runtime/activation_test.cc b/runtime/activation_test.cc index e6a74f027..4303116a3 100644 --- a/runtime/activation_test.cc +++ b/runtime/activation_test.cc @@ -326,7 +326,7 @@ TEST_F(ActivationTest, MoveAssignment) { "val_provided", [](absl::string_view name, const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull) - -> absl::StatusOr> { return IntValue(42); })); + -> absl::StatusOr> { return IntValue(42); })); moved_from.SetUnknownPatterns( {AttributePattern("var1", {AttributeQualifierPattern::OfString("field1")}), @@ -377,7 +377,7 @@ TEST_F(ActivationTest, MoveCtor) { "val_provided", [](absl::string_view name, const google::protobuf::DescriptorPool* absl_nonnull, google::protobuf::MessageFactory* absl_nonnull, google::protobuf::Arena* absl_nonnull) - -> absl::StatusOr> { return IntValue(42); })); + -> absl::StatusOr> { return IntValue(42); })); moved_from.SetUnknownPatterns( {AttributePattern("var1", {AttributeQualifierPattern::OfString("field1")}), diff --git a/runtime/function_registry.cc b/runtime/function_registry.cc index ac1e53eb5..59f267255 100644 --- a/runtime/function_registry.cc +++ b/runtime/function_registry.cc @@ -44,14 +44,13 @@ class ActivationFunctionProviderImpl public: ActivationFunctionProviderImpl() = default; - absl::StatusOr> GetFunction( + absl::StatusOr> GetFunction( const cel::FunctionDescriptor& descriptor, const cel::ActivationInterface& activation) const override { std::vector overloads = activation.FindFunctionOverloads(descriptor.name()); - absl::optional matching_overload = - absl::nullopt; + std::optional matching_overload = absl::nullopt; for (const auto& overload : overloads) { if (overload.descriptor.ShapeMatches(descriptor)) { diff --git a/runtime/function_registry_test.cc b/runtime/function_registry_test.cc index af7f5bc06..53916777a 100644 --- a/runtime/function_registry_test.cc +++ b/runtime/function_registry_test.cc @@ -120,7 +120,7 @@ TEST(FunctionRegistryTest, DefaultLazyProviderNoOverloadFound) { ASSERT_THAT(providers, SizeIs(1)); const FunctionProvider& provider = providers[0].provider; ASSERT_OK_AND_ASSIGN( - absl::optional func, + std::optional func, provider.GetFunction({"LazyFunc", false, {cel::Kind::kInt64}}, activation)); @@ -146,7 +146,7 @@ TEST(FunctionRegistryTest, DefaultLazyProviderReturnsImpl) { ASSERT_THAT(providers, SizeIs(1)); const FunctionProvider& provider = providers[0].provider; ASSERT_OK_AND_ASSIGN( - absl::optional func, + std::optional func, provider.GetFunction( FunctionDescriptor("LazyFunction", false, {Kind::kInt}), activation)); From 267f4de9814c320e61b20cdd8fbe6580f2e57ac3 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Sat, 23 May 2026 02:51:01 -0700 Subject: [PATCH 35/87] No public description PiperOrigin-RevId: 920105999 --- extensions/select_optimization.cc | 16 ++++++++-------- extensions/select_optimization_test.cc | 5 +++-- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/extensions/select_optimization.cc b/extensions/select_optimization.cc index 44da4c48a..42cad0f92 100644 --- a/extensions/select_optimization.cc +++ b/extensions/select_optimization.cc @@ -153,7 +153,7 @@ Expr MakeSelectPathExpr( // Returns a single select operation based on the inferred type of the operand // and the field name. If the operand type doesn't define the field, returns // nullopt. -absl::optional GetSelectInstruction( +std::optional GetSelectInstruction( const StructType& runtime_type, PlannerContext& planner_context, absl::string_view field_name) { auto field_or = planner_context.type_reflector() @@ -407,13 +407,13 @@ class RewriterImpl : public AstRewriterBase { // support message traversal. const TypeSpec checker_type = ast_.GetTypeOrDyn(operand.id()); - absl::optional rt_type = + std::optional rt_type = (checker_type.has_message_type()) ? GetRuntimeType(checker_type.message_type().type()) : absl::nullopt; if (rt_type.has_value() && (*rt_type).Is()) { const StructType& runtime_type = rt_type->GetStruct(); - absl::optional field_or = + std::optional field_or = GetSelectInstruction(runtime_type, planner_context_, field_name); if (field_or.has_value()) { candidates_[&expr] = std::move(field_or).value(); @@ -538,7 +538,7 @@ class RewriterImpl : public AstRewriterBase { return candidates_.find(operand) != candidates_.end(); } - absl::optional GetRuntimeType(absl::string_view type_name) { + std::optional GetRuntimeType(absl::string_view type_name) { return planner_context_.type_reflector().FindType(type_name).value_or( absl::nullopt); } @@ -582,14 +582,14 @@ class OptimizedSelectImpl { AttributeTrail GetAttributeTrail(const AttributeTrail& operand_trail) const; - absl::optional attribute() const { return attribute_; } + std::optional attribute() const { return attribute_; } const std::vector& qualifiers() const { return qualifiers_; } private: - absl::optional attribute_; + std::optional attribute_; std::vector select_path_; std::vector qualifiers_; bool presence_test_; @@ -597,7 +597,7 @@ class OptimizedSelectImpl { }; // Check for unknowns or missing attributes. -absl::StatusOr> CheckForMarkedAttributes( +absl::StatusOr> CheckForMarkedAttributes( ExecutionFrameBase& frame, const AttributeTrail& attribute_trail) { if (attribute_trail.empty()) { return absl::nullopt; @@ -715,7 +715,7 @@ absl::Status StackMachineImpl::Evaluate(ExecutionFrame* frame) const { // select arguments. // TODO(uncreated-issue/51): add support variable qualifiers attribute_trail = GetAttributeTrail(frame); - CEL_ASSIGN_OR_RETURN(absl::optional value, + CEL_ASSIGN_OR_RETURN(std::optional value, CheckForMarkedAttributes(*frame, attribute_trail)); if (value.has_value()) { frame->value_stack().Pop(kStackInputs); diff --git a/extensions/select_optimization_test.cc b/extensions/select_optimization_test.cc index c07f4c6ad..9d4024098 100644 --- a/extensions/select_optimization_test.cc +++ b/extensions/select_optimization_test.cc @@ -254,8 +254,9 @@ class MockAccessApis : public LegacyTypeInfoApis, public LegacyTypeAccessApis { return nullptr; } - absl::optional FindFieldByName( - absl::string_view field_name) const override { + std::optional< + google::api::expr::runtime::LegacyTypeInfoApis::FieldDescription> + FindFieldByName(absl::string_view field_name) const override { return absl::nullopt; } From b5db1b3dffb7b890f1a05110ad5833ebe0ebdbee Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Sat, 23 May 2026 23:10:55 -0700 Subject: [PATCH 36/87] No public description PiperOrigin-RevId: 920403526 --- eval/compiler/flat_expr_builder.cc | 10 +++++----- eval/compiler/flat_expr_builder_extensions.cc | 2 +- eval/compiler/qualified_reference_resolver.cc | 10 +++++----- eval/compiler/regex_precompilation_optimization.cc | 6 +++--- eval/compiler/resolver.cc | 6 +++--- 5 files changed, 17 insertions(+), 17 deletions(-) diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index e38c912c0..8558c7007 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -840,7 +840,7 @@ class FlatExprVisitor : public cel::AstVisitor { // Attempt to resolve a select expression as a namespaced identifier for an // enum or type constant value. - absl::optional const_value; + std::optional const_value; int64_t select_root_id = -1; std::string path_candidate; @@ -1080,7 +1080,7 @@ class FlatExprVisitor : public cel::AstVisitor { // Returns the maximum recursion depth of the current program if it is // eligible for recursion, or nullopt if it is not. - absl::optional RecursionEligible() { + std::optional RecursionEligible() { if (!PlanRecursiveProgram() || program_builder_.current() == nullptr) { return absl::nullopt; } @@ -1525,7 +1525,7 @@ class FlatExprVisitor : public cel::AstVisitor { } } } - if (absl::optional depth = RecursionEligible(); depth.has_value()) { + if (std::optional depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != list_expr.elements().size()) { SetProgressStatusError(absl::InternalError( @@ -1855,7 +1855,7 @@ class FlatExprVisitor : public cel::AstVisitor { int64_t expr_id) { absl::string_view ast_name = create_struct_expr.name(); - absl::optional> type; + std::optional> type; CEL_ASSIGN_OR_RETURN(type, resolver_.FindType(ast_name, expr_id)); if (!type.has_value()) { @@ -1932,7 +1932,7 @@ class FlatExprVisitor : public cel::AstVisitor { IndexManager index_manager_; bool enable_optional_types_; - absl::optional block_; + std::optional block_; int max_recursion_depth_ = 0; }; diff --git a/eval/compiler/flat_expr_builder_extensions.cc b/eval/compiler/flat_expr_builder_extensions.cc index 463b48425..e51b64023 100644 --- a/eval/compiler/flat_expr_builder_extensions.cc +++ b/eval/compiler/flat_expr_builder_extensions.cc @@ -98,7 +98,7 @@ size_t Subexpression::ComputeSize() const { return size; } -absl::optional Subexpression::RecursiveDependencyDepth() const { +std::optional Subexpression::RecursiveDependencyDepth() const { auto* tree = absl::get_if(&program_); int depth = 0; if (tree == nullptr) { diff --git a/eval/compiler/qualified_reference_resolver.cc b/eval/compiler/qualified_reference_resolver.cc index 67f86ebb6..67c14d9b2 100644 --- a/eval/compiler/qualified_reference_resolver.cc +++ b/eval/compiler/qualified_reference_resolver.cc @@ -81,9 +81,9 @@ bool OverloadExists(const Resolver& resolver, absl::string_view name, // Return the qualified name of the most qualified matching overload, or // nullopt if no matches are found. -absl::optional BestOverloadMatch(const Resolver& resolver, - absl::string_view base_name, - int argument_count) { +std::optional BestOverloadMatch(const Resolver& resolver, + absl::string_view base_name, + int argument_count) { if (IsSpecialFunction(base_name)) { return std::string(base_name); } @@ -262,8 +262,8 @@ class ReferenceResolver : public cel::AstRewriterBase { // Convert a select expr sub tree into a namespace name if possible. // If any operand of the top element is a not a select or an ident node, // return nullopt. - absl::optional ToNamespace(const Expr& expr) { - absl::optional maybe_parent_namespace; + std::optional ToNamespace(const Expr& expr) { + std::optional maybe_parent_namespace; if (rewritten_reference_.find(expr.id()) != rewritten_reference_.end()) { // The target expr matches a reference (resolved to an ident decl). // This should not be treated as a function qualifier. diff --git a/eval/compiler/regex_precompilation_optimization.cc b/eval/compiler/regex_precompilation_optimization.cc index b94cae383..455796131 100644 --- a/eval/compiler/regex_precompilation_optimization.cc +++ b/eval/compiler/regex_precompilation_optimization.cc @@ -145,7 +145,7 @@ class RegexPrecompilationOptimization : public ProgramOptimizer { // Try to check if the regex is valid, whether or not we can actually update // the plan. - absl::optional pattern = + std::optional pattern = GetConstantString(context, subexpression, node, pattern_expr); if (!pattern.has_value()) { return absl::OkStatus(); @@ -168,7 +168,7 @@ class RegexPrecompilationOptimization : public ProgramOptimizer { } private: - absl::optional GetConstantString( + std::optional GetConstantString( PlannerContext& context, ProgramBuilder::Subexpression* absl_nullable subexpression, const Expr& call_expr, const Expr& re_expr) const { @@ -180,7 +180,7 @@ class RegexPrecompilationOptimization : public ProgramOptimizer { // Already modified, can't recover the input pattern. return absl::nullopt; } - absl::optional constant; + std::optional constant; if (subexpression->IsRecursive()) { const auto& program = subexpression->recursive_program(); auto deps = program.step->GetDependencies(); diff --git a/eval/compiler/resolver.cc b/eval/compiler/resolver.cc index 4e3fa3841..17f60eaad 100644 --- a/eval/compiler/resolver.cc +++ b/eval/compiler/resolver.cc @@ -102,8 +102,8 @@ absl::Span Resolver::GetPrefixesFor( return namespace_prefixes_; } -absl::optional Resolver::FindConstant(absl::string_view name, - int64_t expr_id) const { +std::optional Resolver::FindConstant(absl::string_view name, + int64_t expr_id) const { auto prefixes = GetPrefixesFor(name); for (const auto& prefix : prefixes) { std::string qualified_name = absl::StrCat(prefix, name); @@ -205,7 +205,7 @@ std::vector Resolver::FindLazyOverloads( return funcs; } -absl::StatusOr>> +absl::StatusOr>> Resolver::FindType(absl::string_view name, int64_t expr_id) const { auto prefixes = GetPrefixesFor(name); for (auto& prefix : prefixes) { From 9911e079ac31ac087e954339170fdd4bc05d3206 Mon Sep 17 00:00:00 2001 From: Muhammad Askri Date: Tue, 26 May 2026 11:21:49 -0700 Subject: [PATCH 37/87] Map well-known protobuf types to their CEL equivalents in LegacyRuntimeType. PiperOrigin-RevId: 921579763 --- common/BUILD | 1 - common/type.cc | 55 ++++++++++++++++++++++++++++++++++++++------- common/type_test.cc | 34 ++++++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 9 deletions(-) diff --git a/common/BUILD b/common/BUILD index ffc4ae1e9..a016d2cb5 100644 --- a/common/BUILD +++ b/common/BUILD @@ -583,7 +583,6 @@ cc_library( "//internal:string_pool", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", diff --git a/common/type.cc b/common/type.cc index f94e8bc52..684c5ba09 100644 --- a/common/type.cc +++ b/common/type.cc @@ -640,8 +640,6 @@ constexpr absl::string_view kUInt64TypeName = "uint"; constexpr absl::string_view kDoubleTypeName = "double"; constexpr absl::string_view kStringTypeName = "string"; constexpr absl::string_view kBytesTypeName = "bytes"; -constexpr absl::string_view kDurationTypeName = "google.protobuf.Duration"; -constexpr absl::string_view kTimestampTypeName = "google.protobuf.Timestamp"; constexpr absl::string_view kListTypeName = "list"; constexpr absl::string_view kMapTypeName = "map"; constexpr absl::string_view kCelTypeTypeName = "type"; @@ -670,12 +668,6 @@ Type LegacyRuntimeType(absl::string_view name) { if (name == kBytesTypeName) { return BytesType{}; } - if (name == kDurationTypeName) { - return DurationType{}; - } - if (name == kTimestampTypeName) { - return TimestampType{}; - } if (name == kListTypeName) { return ListType{}; } @@ -685,6 +677,53 @@ Type LegacyRuntimeType(absl::string_view name) { if (name == kCelTypeTypeName) { return TypeType{}; } + if (cel::IsWellKnownMessageType(name)) { + if (name == "google.protobuf.Any") { + return AnyType(); + } + if (name == "google.protobuf.BoolValue") { + return BoolWrapperType(); + } + if (name == "google.protobuf.BytesValue") { + return BytesWrapperType(); + } + if (name == "google.protobuf.DoubleValue") { + return DoubleWrapperType(); + } + if (name == "google.protobuf.Duration") { + return DurationType(); + } + if (name == "google.protobuf.FloatValue") { + return DoubleWrapperType(); + } + if (name == "google.protobuf.Int32Value") { + return IntWrapperType(); + } + if (name == "google.protobuf.Int64Value") { + return IntWrapperType(); + } + if (name == "google.protobuf.ListValue") { + return ListType(); + } + if (name == "google.protobuf.StringValue") { + return StringWrapperType(); + } + if (name == "google.protobuf.Struct") { + return JsonMapType(); + } + if (name == "google.protobuf.Timestamp") { + return TimestampType(); + } + if (name == "google.protobuf.UInt32Value") { + return UintWrapperType(); + } + if (name == "google.protobuf.UInt64Value") { + return UintWrapperType(); + } + if (name == "google.protobuf.Value") { + return DynType(); + } + } return common_internal::MakeBasicStructType(name); } diff --git a/common/type_test.cc b/common/type_test.cc index 2cebf27ba..d6a613c3c 100644 --- a/common/type_test.cc +++ b/common/type_test.cc @@ -638,5 +638,39 @@ TEST(Type, Wrap) { EXPECT_EQ(Type(AnyType()).Wrap(), AnyType()); } +TEST(Type, LegacyRuntimeType) { + EXPECT_EQ(common_internal::LegacyRuntimeType("bool"), BoolType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.Any"), + AnyType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.BoolValue"), + BoolWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.BytesValue"), + BytesWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.DoubleValue"), + DoubleWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.Duration"), + DurationType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.FloatValue"), + DoubleWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.Int32Value"), + IntWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.Int64Value"), + IntWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.ListValue"), + ListType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.StringValue"), + StringWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.Struct"), + JsonMapType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.Timestamp"), + TimestampType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.UInt32Value"), + UintWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.UInt64Value"), + UintWrapperType()); + EXPECT_EQ(common_internal::LegacyRuntimeType("google.protobuf.Value"), + DynType()); +} + } // namespace } // namespace cel From e8f6c48fd01acc29dc67b8b0d2d6ce68ea36c94c Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Tue, 26 May 2026 15:05:42 -0700 Subject: [PATCH 38/87] Repo move announcement PiperOrigin-RevId: 921705653 --- README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/README.md b/README.md index 23afe2b00..41b44388d 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,12 @@ # C++ Implementations of the Common Expression Language +> [!WARNING] +> **On June 16, 2026, this repository will move to +> github.com/cel-expr/cel-cpp!** +> +> Please update your links and dependencies. See the [pinned +> issue](https://github.com/google/cel-cpp/issues/2029) for details. + For background on the Common Expression Language see the [cel-spec][1] repo. This is a C++ implementation of a [Common Expression Language][1] runtime, From 2f0944e9f25cbe23b202c280f48f0e47f638668c Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Wed, 27 May 2026 10:42:49 -0700 Subject: [PATCH 39/87] Remove constraint that block initializer is non-empty. PiperOrigin-RevId: 922222081 --- eval/compiler/flat_expr_builder.cc | 10 ++++------ eval/compiler/flat_expr_builder_test.cc | 7 +++---- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index 8558c7007..1e3f4ecd3 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -1047,11 +1047,7 @@ class FlatExprVisitor : public cel::AstVisitor { } const auto& list_expr = call_expr.args().front().list_expr(); block.size = list_expr.elements().size(); - if (block.size == 0) { - SetProgressStatusError(absl::InvalidArgumentError( - "malformed cel.@block: list of bound expressions is empty")); - return; - } + block.bindings_set.reserve(block.size); for (const auto& list_expr_element : list_expr.elements()) { if (list_expr_element.optional()) { @@ -2052,7 +2048,9 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleBlock( } // Otherwise, iterative plan. - AddStep(CreateClearSlotsStep(block.index, block.slot_count, expr.id())); + if (block.slot_count > 0) { + AddStep(CreateClearSlotsStep(block.index, block.slot_count, expr.id())); + } return CallHandlerResult::kIntercepted; } diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index 5fc20f01e..d84007485 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -2820,6 +2820,7 @@ TEST(FlatExprBuilderTest, BlockNotListOfBoundExpressions) { TEST(FlatExprBuilderTest, BlockEmptyListOfBoundExpressions) { ParsedExpr parsed_expr; + // Allowed, but degenerate case. ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( R"pb( expr: { @@ -2835,10 +2836,8 @@ TEST(FlatExprBuilderTest, BlockEmptyListOfBoundExpressions) { CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); EXPECT_THAT( builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), - StatusIs( - absl::StatusCode::kInvalidArgument, - HasSubstr( - "malformed cel.@block: list of bound expressions is empty"))); + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("invalid @index greater than number of bindings:"))); } TEST(FlatExprBuilderTest, BlockOptionalListOfBoundExpressions) { From 9fb8e1068781eeb988cf5fcd33aaa05a0c8356b6 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 27 May 2026 12:56:46 -0700 Subject: [PATCH 40/87] Return the first match from `ResolveStatic()` PiperOrigin-RevId: 922297232 --- eval/eval/function_step.cc | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/eval/eval/function_step.cc b/eval/eval/function_step.cc index 2a10e9674..fcf429378 100644 --- a/eval/eval/function_step.cc +++ b/eval/eval/function_step.cc @@ -286,20 +286,12 @@ absl::Status AbstractFunctionStep::Evaluate(ExecutionFrame* frame) const { absl::StatusOr ResolveStatic( absl::Span input_args, absl::Span overloads) { - ResolveResult result = absl::nullopt; - for (const auto& overload : overloads) { if (ArgumentKindsMatch(overload.descriptor, input_args)) { - // More than one overload matches our arguments. - if (result.has_value()) { - return absl::Status(absl::StatusCode::kInternal, - "Cannot resolve overloads"); - } - - result.emplace(overload); + return overload; } } - return result; + return absl::nullopt; } absl::StatusOr ResolveLazy( @@ -315,7 +307,7 @@ absl::StatusOr ResolveLazy( input_args.begin(), input_args.end(), arg_types.begin(), [](const cel::Value& value) { return ValueKindToKind(value->kind()); }); - cel::FunctionDescriptor matcher{name, receiver_style, arg_types}; + cel::FunctionDescriptor matcher{name, receiver_style, std::move(arg_types)}; const cel::ActivationInterface& activation = frame.activation(); for (auto provider : providers) { From 83ce6bd890aef1c2ae41d8179490dd735d2549f2 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 28 May 2026 11:15:52 -0700 Subject: [PATCH 41/87] Add benchmark test for field access implementation. Adds a `cc_test` target for benchmarking `field_access_impl.cc`. PiperOrigin-RevId: 922875336 --- eval/public/structs/BUILD | 20 ++ .../field_access_impl_benchmark_test.cc | 239 ++++++++++++++++++ 2 files changed, 259 insertions(+) create mode 100644 eval/public/structs/field_access_impl_benchmark_test.cc diff --git a/eval/public/structs/BUILD b/eval/public/structs/BUILD index d301ff0ca..d722559e3 100644 --- a/eval/public/structs/BUILD +++ b/eval/public/structs/BUILD @@ -442,3 +442,23 @@ cc_test( "@com_google_protobuf//:protobuf", ], ) + +cc_test( + name = "field_access_impl_benchmark_test", + srcs = ["field_access_impl_benchmark_test.cc"], + tags = [ + "benchmark", + "manual", + ], + deps = [ + ":cel_proto_wrapper", + ":field_access_impl", + "//eval/public:cel_options", + "//eval/public:cel_value", + "//extensions/protobuf/internal:map_reflection", + "//internal:benchmark", + "//internal:testing", + "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/eval/public/structs/field_access_impl_benchmark_test.cc b/eval/public/structs/field_access_impl_benchmark_test.cc new file mode 100644 index 000000000..888e424b1 --- /dev/null +++ b/eval/public/structs/field_access_impl_benchmark_test.cc @@ -0,0 +1,239 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "eval/public/structs/cel_proto_wrapper.h" +#include "eval/public/structs/field_access_impl.h" +#include "extensions/protobuf/internal/map_reflection.h" +#include "internal/benchmark.h" +#include "cel/expr/conformance/proto3/test_all_types.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/map_field.h" +#include "google/protobuf/message.h" + +namespace google::api::expr::runtime::internal { +namespace { + +using ::cel::expr::conformance::proto3::TestAllTypes; + +void BM_CreateValueFromSingleField_Int64(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + msg.set_single_int64(42); + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("single_int64"); + + for (auto _ : state) { + auto value = CreateValueFromSingleField( + &msg, desc, ProtoWrapperTypeOptions::kUnsetProtoDefault, + &CelProtoWrapper::InternalWrapMessage, &arena); + benchmark::DoNotOptimize(value); + } +} +BENCHMARK(BM_CreateValueFromSingleField_Int64); + +void BM_CreateValueFromSingleField_String(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + msg.set_single_string("hello world"); + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("single_string"); + + for (auto _ : state) { + auto value = CreateValueFromSingleField( + &msg, desc, ProtoWrapperTypeOptions::kUnsetProtoDefault, + &CelProtoWrapper::InternalWrapMessage, &arena); + benchmark::DoNotOptimize(value); + } +} +BENCHMARK(BM_CreateValueFromSingleField_String); + +void BM_CreateValueFromSingleField_Message(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + msg.mutable_standalone_message()->set_bb(123); + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("standalone_message"); + + for (auto _ : state) { + auto value = CreateValueFromSingleField( + &msg, desc, ProtoWrapperTypeOptions::kUnsetProtoDefault, + &CelProtoWrapper::InternalWrapMessage, &arena); + benchmark::DoNotOptimize(value); + } +} +BENCHMARK(BM_CreateValueFromSingleField_Message); + +void BM_CreateValueFromRepeatedField_Int64(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + msg.add_repeated_int64(42); + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("repeated_int64"); + + for (auto _ : state) { + auto value = CreateValueFromRepeatedField( + &msg, desc, 0, &CelProtoWrapper::InternalWrapMessage, &arena); + benchmark::DoNotOptimize(value); + } +} +BENCHMARK(BM_CreateValueFromRepeatedField_Int64); + +void BM_CreateValueFromRepeatedField_String(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + msg.add_repeated_string("hello world"); + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("repeated_string"); + + for (auto _ : state) { + auto value = CreateValueFromRepeatedField( + &msg, desc, 0, &CelProtoWrapper::InternalWrapMessage, &arena); + benchmark::DoNotOptimize(value); + } +} +BENCHMARK(BM_CreateValueFromRepeatedField_String); + +void BM_CreateValueFromMapValue_Int64(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + (*msg.mutable_map_int64_int64())[42] = 100; + const google::protobuf::FieldDescriptor* map_desc = + TestAllTypes::descriptor()->FindFieldByName("map_int64_int64"); + const google::protobuf::FieldDescriptor* value_desc = + map_desc->message_type()->FindFieldByName("value"); + + google::protobuf::ConstMapIterator iter = + cel::extensions::protobuf_internal::ConstMapBegin(*msg.GetReflection(), + msg, *map_desc); + google::protobuf::MapValueConstRef value_ref = iter.GetValueRef(); + + for (auto _ : state) { + auto value = + CreateValueFromMapValue(&msg, value_desc, &value_ref, + &CelProtoWrapper::InternalWrapMessage, &arena); + benchmark::DoNotOptimize(value); + } +} +BENCHMARK(BM_CreateValueFromMapValue_Int64); + +void BM_SetValueToSingleField_Int64(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("single_int64"); + CelValue val = CelValue::CreateInt64(42); + + for (auto _ : state) { + auto status = SetValueToSingleField(val, desc, &msg, &arena); + benchmark::DoNotOptimize(status); + } +} +BENCHMARK(BM_SetValueToSingleField_Int64); + +void BM_SetValueToSingleField_String(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("single_string"); + CelValue val = CelValue::CreateStringView("hello world"); + + for (auto _ : state) { + auto status = SetValueToSingleField(val, desc, &msg, &arena); + benchmark::DoNotOptimize(status); + } +} +BENCHMARK(BM_SetValueToSingleField_String); + +void BM_SetValueToSingleField_Message(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("standalone_message"); + + TestAllTypes::NestedMessage nested_msg; + nested_msg.set_bb(123); + CelValue val = CelProtoWrapper::CreateMessage(&nested_msg, &arena); + + for (auto _ : state) { + auto status = SetValueToSingleField(val, desc, &msg, &arena); + benchmark::DoNotOptimize(status); + } +} +BENCHMARK(BM_SetValueToSingleField_Message); + +void BM_AddValueToRepeatedField_Int64(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("repeated_int64"); + CelValue val = CelValue::CreateInt64(42); + + for (auto _ : state) { + msg.clear_repeated_int64(); + auto status = AddValueToRepeatedField(val, desc, &msg, &arena); + benchmark::DoNotOptimize(status); + } +} +BENCHMARK(BM_AddValueToRepeatedField_Int64); + +void BM_AddValueToRepeatedField_String(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("repeated_string"); + CelValue val = CelValue::CreateStringView("hello world"); + + for (auto _ : state) { + msg.clear_repeated_string(); + auto status = AddValueToRepeatedField(val, desc, &msg, &arena); + benchmark::DoNotOptimize(status); + } +} +BENCHMARK(BM_AddValueToRepeatedField_String); + +void BM_CreateValueFromRepeatedField_StringPiece(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + msg.add_repeated_string_piece("hello world"); + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("repeated_string_piece"); + + for (auto _ : state) { + auto value = CreateValueFromRepeatedField( + &msg, desc, 0, &CelProtoWrapper::InternalWrapMessage, &arena); + benchmark::DoNotOptimize(value); + } +} +BENCHMARK(BM_CreateValueFromRepeatedField_StringPiece); + +void BM_AddValueToRepeatedField_StringPiece(benchmark::State& state) { + google::protobuf::Arena arena; + TestAllTypes msg; + const google::protobuf::FieldDescriptor* desc = + TestAllTypes::descriptor()->FindFieldByName("repeated_string_piece"); + CelValue val = CelValue::CreateStringView("hello world"); + + for (auto _ : state) { + msg.clear_repeated_string_piece(); + auto status = AddValueToRepeatedField(val, desc, &msg, &arena); + benchmark::DoNotOptimize(status); + } +} +BENCHMARK(BM_AddValueToRepeatedField_StringPiece); + +} // namespace +} // namespace google::api::expr::runtime::internal From 5436842151a69b1dd3e0c4350c7406b4fba36805 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 1 Jun 2026 12:51:26 -0700 Subject: [PATCH 42/87] Update presubmit docker image. Use bazelisk to simplify testing bazel upgrades. PiperOrigin-RevId: 924861509 --- Dockerfile | 45 +++++++++++++++++++++++++++++---------------- cloudbuild.yaml | 4 ++-- 2 files changed, 31 insertions(+), 18 deletions(-) diff --git a/Dockerfile b/Dockerfile index c2c2915be..97611fc75 100644 --- a/Dockerfile +++ b/Dockerfile @@ -12,25 +12,25 @@ # # Run the following command from the root of the CEL repository: # -# gcloud builds submit --region=us -t gcr.io/cel-analysis/gcc9 . +# gcloud builds submit --region=us -t gcr.io/cel-analysis/cel-cpp/ubuntu_floor . # # Once complete get the sha256 digest from the output using the following # command: # -# gcloud artifacts versions list --package=gcc9 --repository=gcr.io \ +# gcloud artifacts versions list --package=cel-cpp/ubuntu_floor --repository=gcr.io \ # --location=us # # The cloudbuild.yaml file must be updated to use the new digest like so: # -# - name: 'gcr.io/cel-analysis/gcc9@' -FROM gcc:9 +# - name: 'gcr.io/cel-analysis/cel-cpp/ubuntu_floor@' +FROM gcr.io/cloud-marketplace/google/ubuntu2204:latest # Install Bazel prerequesites and required tools. # See https://docs.bazel.build/versions/master/install-ubuntu.html -RUN apt-get update && \ - apt-get upgrade -y && \ - apt-get install -y --no-install-recommends \ - ca-certificates \ +RUN apt-get update && apt-get upgrade -y && \ + DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + bash \ + ca-certificates \ git \ libssl-dev \ make \ @@ -41,16 +41,29 @@ RUN apt-get update && \ zip \ zlib1g-dev \ default-jdk-headless \ - clang-11 && \ - apt-get clean + clang-11 \ + gcc-9 g++-9 \ + tzdata \ + && apt-get clean -# Install Bazel. -# https://github.com/bazelbuild/bazel/releases -ARG BAZEL_VERSION="7.3.2" -ADD https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh /tmp/install_bazel.sh -RUN /bin/bash /tmp/install_bazel.sh && rm /tmp/install_bazel.sh +# Install Bazelisk. +# https://github.com/bazelbuild/bazelisk/releases +ARG BAZELISK_URL="https://github.com/bazelbuild/bazelisk/releases/download/v1.27.0/bazelisk-amd64.deb" +ARG BAZELISK_CHKSUM="d8b00ea975c823e15263c80200ac42979e17368547fbff4ab177af035badfa83" +ADD ${BAZELISK_URL} /tmp/bazelisk.deb + +ENV BAZELISK_CHKSUM=${BAZELISK_CHKSUM} +RUN echo "${BAZELISK_CHKSUM} */tmp/bazelisk.deb" | sha256sum --check + +RUN apt-get install /tmp/bazelisk.deb RUN mkdir -p /workspace RUN mkdir -p /bazel -ENTRYPOINT ["/usr/local/bin/bazel"] +RUN USE_BAZEL_VERSION=8.7.0 bazelisk help +RUN USE_BAZEL_VERSION=7.3.2 bazelisk help + +ENV CC=gcc-9 +ENV CXX=g++-9 + +ENTRYPOINT ["/usr/bin/bazelisk"] diff --git a/cloudbuild.yaml b/cloudbuild.yaml index 8272378f6..dec359f25 100644 --- a/cloudbuild.yaml +++ b/cloudbuild.yaml @@ -1,5 +1,5 @@ steps: -- name: 'gcr.io/cel-analysis/gcc9@sha256:4d5ff2e55224398807235a44b57e9c5793e922ac46e9ff428536bb8f8e5790ce' +- name: 'gcr.io/cel-analysis/cel-cpp/ubuntu_floor@sha256:211a0c505b361d987b3d8b08a5144a84e62cb95edc3f897fe46d5cd3f556f79d' args: - '--output_base=/bazel' # This is mandatory to avoid steps accidently sharing data. - 'test' @@ -16,7 +16,7 @@ steps: - '--google_default_credentials' id: gcc-9 waitFor: ['-'] -- name: 'gcr.io/cel-analysis/gcc9@sha256:4d5ff2e55224398807235a44b57e9c5793e922ac46e9ff428536bb8f8e5790ce' +- name: 'gcr.io/cel-analysis/cel-cpp/ubuntu_floor@sha256:211a0c505b361d987b3d8b08a5144a84e62cb95edc3f897fe46d5cd3f556f79d' env: - 'CC=clang-11' - 'CXX=clang++-11' From f1e5042c97e641457b92dcd4cdee57d1facf96fa Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 1 Jun 2026 14:06:53 -0700 Subject: [PATCH 43/87] Bump bazel version to 8.x PiperOrigin-RevId: 924902706 --- .bazelversion | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.bazelversion b/.bazelversion index eab246c06..df5119ec6 100644 --- a/.bazelversion +++ b/.bazelversion @@ -1 +1 @@ -7.3.2 +8.7.0 From dfea16a3fefaace88eafbdb294bc8ea376bb1c9b Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 1 Jun 2026 14:38:20 -0700 Subject: [PATCH 44/87] Update deps. Patches to antlr were upstreamed to bcr. PiperOrigin-RevId: 924919565 --- MODULE.bazel | 17 ++++------------- bazel/antlr.patch | 30 ------------------------------ 2 files changed, 4 insertions(+), 43 deletions(-) delete mode 100644 bazel/antlr.patch diff --git a/MODULE.bazel b/MODULE.bazel index 02404b645..43d0485d2 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -33,7 +33,7 @@ bazel_dep( ) bazel_dep( name = "protobuf", - version = "33.4", + version = "34.1", repo_name = "com_google_protobuf", ) bazel_dep( @@ -41,20 +41,16 @@ bazel_dep( version = "20260107.0", repo_name = "com_google_absl", ) - bazel_dep( name = "googletest", version = "1.17.0.bcr.2", - dev_dependency = True, repo_name = "com_google_googletest", ) bazel_dep( name = "google_benchmark", version = "1.9.2", - dev_dependency = True, repo_name = "com_github_google_benchmark", ) - bazel_dep( name = "re2", version = "2025-11-05.bcr.1", @@ -74,16 +70,9 @@ bazel_dep( name = "platforms", version = "1.0.0", ) - -ANTLR4_VERSION = "4.13.2" - bazel_dep( name = "antlr4-cpp-runtime", - version = ANTLR4_VERSION, -) -single_version_override( - module_name = "antlr4-cpp-runtime", - patches = ["//bazel:antlr.patch"], + version = "4.13.2.bcr.2", ) python = use_extension("@rules_python//python/extensions:python.bzl", "python") @@ -95,6 +84,8 @@ python.toolchain( http_jar = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_jar") +ANTLR4_VERSION = "4.13.2" + http_jar( name = "antlr4_jar", sha256 = "eae2dfa119a64327444672aff63e9ec35a20180dc5b8090b7a6ab85125df4d76", diff --git a/bazel/antlr.patch b/bazel/antlr.patch deleted file mode 100644 index c1aa9080c..000000000 --- a/bazel/antlr.patch +++ /dev/null @@ -1,30 +0,0 @@ ---- BUILD.bazel -+++ BUILD.bazel -@@ -17,21 +17,21 @@ - cc_library( - name = "antlr4-cpp-runtime", - srcs = glob(["runtime/src/**/*.cpp"]), - hdrs = ["runtime/src/antlr4-runtime.h"], - copts = ["-fexceptions"], -- defines = ["ANTLR4CPP_USING_ABSEIL"], -+ defines = ["ANTLR4CPP_USING_ABSEIL", "ANTLR4CPP_STATIC"], - features = ["-use_header_modules"], - includes = ["runtime/src"], - textual_hdrs = glob( - ["runtime/src/**/*.h"], - exclude = ["runtime/src/antlr4-runtime.h"], - ), - visibility = ["//visibility:public"], - deps = [ - "@com_google_absl//absl/base", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/synchronization", - ], - ) - ---- VERSION -+++ /dev/null -@@ -1,1 +1,0 @@ --4.13.2 \ No newline at end of file From db0725c120fab156cedf32776bba9d35d9406d85 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 2 Jun 2026 13:31:45 -0700 Subject: [PATCH 45/87] Add basic workflow for windows/bazel test (#2039) Add basic workflow for running conformance tests on windows. Only run on post merge or manually for now. Closes #2039 PiperOrigin-RevId: 925547950 --- .bazelrc | 2 + .github/workflows/windows_bazel_test.yml | 28 +++++++++++++ .../windows_bazel_test_post_merge.yml | 13 ++++++ conformance/BUILD | 1 + conformance/run.bzl | 2 +- conformance/run.cc | 7 ++-- internal/BUILD | 11 +++++ internal/runfiles.cc | 40 +++++++++++++++++++ internal/runfiles.h | 30 ++++++++++++++ 9 files changed, 130 insertions(+), 4 deletions(-) create mode 100644 .github/workflows/windows_bazel_test.yml create mode 100644 .github/workflows/windows_bazel_test_post_merge.yml create mode 100644 internal/runfiles.cc create mode 100644 internal/runfiles.h diff --git a/.bazelrc b/.bazelrc index 475706072..1246d336b 100644 --- a/.bazelrc +++ b/.bazelrc @@ -10,6 +10,8 @@ build:linux --copt=-Wno-deprecated-declarations # you will typically need to spell out the compiler for local dev # BAZEL_VC= # BAZEL_VC_FULL_VERSION=14.44.3520 +# Some dependencies rely on bash so you will likely need msys2 +# BAZEL_SH=C:\msys64\usr\bin\bash.exe build:msvc --cxxopt="-std:c++20" --cxxopt="-utf-8" --host_cxxopt="-std:c++20" build:msvc --define=protobuf_allow_msvc=true build:msvc --test_tag_filters=-benchmark,-notap,-no_test_msvc diff --git a/.github/workflows/windows_bazel_test.yml b/.github/workflows/windows_bazel_test.yml new file mode 100644 index 000000000..4ac7f2eec --- /dev/null +++ b/.github/workflows/windows_bazel_test.yml @@ -0,0 +1,28 @@ +name: Windows Bazel Test + +on: + workflow_call: + workflow_dispatch: + +jobs: + test: + name: Run Bazel Tests + runs-on: windows-latest + steps: + - name: Checkout Repository + uses: actions/checkout@v4 + + - name: Setup Bazel and Bazelisk + uses: bazel-contrib/setup-bazel@0.19.0 + with: + bazelisk-cache: true + disk-cache: ${{ github.workflow }} + repository-cache: true + + - name: Run Tests + # msys2 'bash' on Windows will try to 'fix' the label prefix to + # work as a directory. + # //... won't work. + shell: bash + run: | + bazelisk test --config=msvc conformance:all \ No newline at end of file diff --git a/.github/workflows/windows_bazel_test_post_merge.yml b/.github/workflows/windows_bazel_test_post_merge.yml new file mode 100644 index 000000000..11801011e --- /dev/null +++ b/.github/workflows/windows_bazel_test_post_merge.yml @@ -0,0 +1,13 @@ +name: Windows Bazel Test (Post-Merge) + +on: + push: + branches: + - master + +jobs: + trigger-test: + # This prevents the workflow from running automatically when someone + # pushes to their fork. + if: github.repository == 'google/cel-cpp' + uses: ./.github/workflows/windows_bazel_test.yml \ No newline at end of file diff --git a/conformance/BUILD b/conformance/BUILD index 0ca90a4bc..a6f25e001 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -99,6 +99,7 @@ cc_library( deps = [ ":service", ":utils", + "//internal:runfiles", "//internal:testing_no_main", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/log:absl_check", diff --git a/conformance/run.bzl b/conformance/run.bzl index 15850b0aa..2c0b51c0e 100644 --- a/conformance/run.bzl +++ b/conformance/run.bzl @@ -77,7 +77,7 @@ def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, def _conformance_test(name, data, modern, optimize, recursive, select_opt, skip_check, skip_tests, tags, dashboard): cc_test( name = _conformance_test_name(name, optimize, recursive), - args = _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard) + ["$(location " + test + ")" for test in data], + args = _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard) + ["$(rlocationpath {})".format(test) for test in data], env = select( { "@platforms//os:windows": {"CEL_SKIP_TESTS": ",".join(skip_tests + _TESTS_TO_SKIP_WINDOWS)}, diff --git a/conformance/run.cc b/conformance/run.cc index 80164d9a4..4a0493494 100644 --- a/conformance/run.cc +++ b/conformance/run.cc @@ -48,6 +48,7 @@ #include "absl/types/span.h" #include "conformance/service.h" #include "conformance/utils.h" +#include "internal/runfiles.h" #include "internal/testing.h" #include "cel/expr/conformance/test/simple.pb.h" #include "google/protobuf/io/zero_copy_stream_impl.h" @@ -68,8 +69,6 @@ ABSL_FLAG(bool, select_optimization, false, "Enable select optimization."); namespace { -using ::testing::IsEmpty; - using cel::expr::conformance::test::SimpleTest; using cel::expr::conformance::test::SimpleTestFile; using google::api::expr::conformance::v1alpha1::CheckRequest; @@ -78,6 +77,7 @@ using google::api::expr::conformance::v1alpha1::EvalRequest; using google::api::expr::conformance::v1alpha1::EvalResponse; using google::api::expr::conformance::v1alpha1::ParseRequest; using google::api::expr::conformance::v1alpha1::ParseResponse; +using ::testing::IsEmpty; google::rpc::Code ToGrpcCode(absl::StatusCode code) { return static_cast(code); @@ -282,8 +282,9 @@ int main(int argc, char** argv) { } } for (int argi = 1; argi < argc; argi++) { + std::string path = cel::internal::ResolveRunfilesPath(argv[argi]); ABSL_CHECK_OK(RegisterTestsFromFile(service, tests_to_skip, - absl::string_view(argv[argi]))); + absl::string_view(path))); } } int exit_code = RUN_ALL_TESTS(); diff --git a/internal/BUILD b/internal/BUILD index 3891c635d..0ac5c4e46 100644 --- a/internal/BUILD +++ b/internal/BUILD @@ -86,6 +86,17 @@ cc_library( ], ) +cc_library( + name = "runfiles", + srcs = ["runfiles.cc"], + hdrs = ["runfiles.h"], + deps = [ + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings", + "@rules_cc//cc/runfiles", + ], +) + cc_library( name = "status_builder", hdrs = ["status_builder.h"], diff --git a/internal/runfiles.cc b/internal/runfiles.cc new file mode 100644 index 000000000..259e2e7ca --- /dev/null +++ b/internal/runfiles.cc @@ -0,0 +1,40 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/runfiles.h" + +#include + +#include "rules_cc/cc/runfiles/runfiles.h" + +#include "absl/log/absl_check.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" + +namespace cel::internal { + +std::string ResolveRunfilesPath(absl::string_view path) { + using ::rules_cc::cc::runfiles::Runfiles; + static Runfiles* runfiles = []() { + std::string error; + auto runfiles = + Runfiles::CreateForTest(BAZEL_CURRENT_REPOSITORY, &error); + ABSL_QCHECK(runfiles != nullptr) + << absl::StrCat("failed to init runfiles", error); + return runfiles; + }(); + return runfiles->Rlocation(std::string(path)); +} + +} // namespace cel::internal diff --git a/internal/runfiles.h b/internal/runfiles.h new file mode 100644 index 000000000..643c677b4 --- /dev/null +++ b/internal/runfiles.h @@ -0,0 +1,30 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_RUNFILES_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_RUNFILES_H_ + +#include + +#include "absl/strings/string_view.h" + +namespace cel::internal { + +// Resolves a path relative to the runfiles directory. +// Intended for resolving test cases from cel-spec and cel-policy. +std::string ResolveRunfilesPath(absl::string_view path); + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_RUNFILES_H_ From e264932db519034136ba38cb7381405b32ee9da6 Mon Sep 17 00:00:00 2001 From: sahvx655-wq Date: Wed, 3 Jun 2026 23:04:15 +0530 Subject: [PATCH 46/87] fix undefined left shift of negative int in math.bitShiftLeft --- extensions/math_ext.cc | 6 +++++- extensions/math_ext_test.cc | 2 ++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/extensions/math_ext.cc b/extensions/math_ext.cc index 4d133d90c..a7773da19 100644 --- a/extensions/math_ext.cc +++ b/extensions/math_ext.cc @@ -266,7 +266,11 @@ Value BitShiftLeftInt(int64_t lhs, int64_t rhs) { if (rhs > 63) { return IntValue(0); } - return IntValue(lhs << static_cast(rhs)); + // Shift in the unsigned domain to avoid undefined behaviour when lhs is + // negative or the shift moves bits into the sign bit, matching the bit + // pattern semantics already used by bitShiftRight. + return IntValue(absl::bit_cast(absl::bit_cast(lhs) + << static_cast(rhs))); } Value BitShiftLeftUint(uint64_t lhs, int64_t rhs) { diff --git a/extensions/math_ext_test.cc b/extensions/math_ext_test.cc index 72605648f..ea9331970 100644 --- a/extensions/math_ext_test.cc +++ b/extensions/math_ext_test.cc @@ -563,6 +563,8 @@ INSTANTIATE_TEST_SUITE_P( {"math.bitNot(2) == -3"}, {"math.bitAnd(math.bitNot(0x3u), 0xFFu) == 0xFCu"}, {"math.bitShiftLeft(1, 1) == 2"}, + {"math.bitShiftLeft(-1, 1) == -2"}, + {"math.bitShiftLeft(-4, 2) == -16"}, {"math.bitShiftLeft(1u, 1) == 2u"}, {"math.bitShiftRight(4, 1) == 2"}, {"math.bitShiftRight(4u, 1) == 2u"}})); From ad137617e146bf2d050747ca3a7fedbcef0e4018 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Wed, 3 Jun 2026 14:00:23 -0700 Subject: [PATCH 47/87] Add memory test for managed constant folding arena PiperOrigin-RevId: 926250668 --- eval/tests/memory_safety_test.cc | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/eval/tests/memory_safety_test.cc b/eval/tests/memory_safety_test.cc index 9c0a683e4..a88844fed 100644 --- a/eval/tests/memory_safety_test.cc +++ b/eval/tests/memory_safety_test.cc @@ -51,7 +51,12 @@ struct TestCase { bool reference_resolver_enabled = false; }; -enum Options { kDefault, kExhaustive, kFoldConstants }; +enum Options { + kDefault, + kExhaustive, + kFoldConstants, + kFoldConstantsManagedArena +}; using ParamType = std::tuple; @@ -68,6 +73,9 @@ std::string TestCaseName(const testing::TestParamInfo& param_info) { case Options::kFoldConstants: opt = "opt"; break; + case Options::kFoldConstantsManagedArena: + opt = "opt_managed_arena"; + break; } return absl::StrCat(std::get<0>(param).name, "_", opt); @@ -110,6 +118,14 @@ class EvaluatorMemorySafetyTest : public testing::TestWithParam { options.enable_comprehension_vulnerability_check = false; options.short_circuiting = true; break; + case Options::kFoldConstantsManagedArena: + options.enable_regex_precompilation = true; + options.constant_folding = true; + options.enable_comprehension_list_append = true; + options.enable_comprehension_vulnerability_check = false; + options.short_circuiting = true; + options.constant_arena = nullptr; + break; } options.enable_qualified_identifier_rewrites = @@ -295,7 +311,8 @@ INSTANTIATE_TEST_SUITE_P( test::IsCelBool(true), }}), testing::Values(Options::kDefault, Options::kExhaustive, - Options::kFoldConstants)), + Options::kFoldConstants, + Options::kFoldConstantsManagedArena)), &TestCaseName); } // namespace From f26fd7d9ddf6cace625627e9b6d94d40507e823a Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 4 Jun 2026 11:19:41 -0700 Subject: [PATCH 48/87] Fix leak for optional.of(string). PiperOrigin-RevId: 926790246 --- common/values/optional_value.cc | 4 ++-- runtime/BUILD | 1 + runtime/memory_safety_test.cc | 10 ++++++++++ 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/common/values/optional_value.cc b/common/values/optional_value.cc index 688cf8fb0..7c214b9cb 100644 --- a/common/values/optional_value.cc +++ b/common/values/optional_value.cc @@ -345,7 +345,7 @@ OpaqueValue GenericOptionalValueClone( cel::Value* absl_nonnull result = ::new (arena->AllocateAligned(sizeof(cel::Value), alignof(cel::Value))) cel::Value(content.To().value->Clone(arena)); - if (!ArenaTraits<>::trivially_destructible(result)) { + if (!ArenaTraits<>::trivially_destructible(*result)) { arena->OwnDestructor(result); } return common_internal::MakeOptionalValue( @@ -395,7 +395,7 @@ OptionalValue OptionalValue::Of(cel::Value value, cel::Value* absl_nonnull result = ::new ( arena->AllocateAligned(sizeof(cel::Value), alignof(cel::Value))) cel::Value(std::move(value)); - if (!ArenaTraits<>::trivially_destructible(result)) { + if (!ArenaTraits<>::trivially_destructible(*result)) { arena->OwnDestructor(result); } return OptionalValue(&optional_value_dispatcher, diff --git a/runtime/BUILD b/runtime/BUILD index 776a8223d..34ff411a1 100644 --- a/runtime/BUILD +++ b/runtime/BUILD @@ -615,6 +615,7 @@ cc_test( ":activation", ":constant_folding", ":function_adapter", + ":optional_types", ":reference_resolver", ":regex_precompilation", ":runtime", diff --git a/runtime/memory_safety_test.cc b/runtime/memory_safety_test.cc index 2a09be666..a60b4ce60 100644 --- a/runtime/memory_safety_test.cc +++ b/runtime/memory_safety_test.cc @@ -45,6 +45,7 @@ #include "runtime/activation.h" #include "runtime/constant_folding.h" #include "runtime/function_adapter.h" +#include "runtime/optional_types.h" #include "runtime/reference_resolver.h" #include "runtime/regex_precompilation.h" #include "runtime/runtime.h" @@ -174,6 +175,7 @@ absl::StatusOr> ConfigureRuntimeImpl( if (resolve_references) { CEL_RETURN_IF_ERROR(EnableReferenceResolver( runtime_builder, ReferenceResolverEnabled::kAlways)); + CEL_RETURN_IF_ERROR(extensions::EnableOptionalTypes(runtime_builder)); } if (evaluation_options == Options::kFoldConstants) { CEL_RETURN_IF_ERROR(extensions::EnableConstantFolding(runtime_builder)); @@ -315,6 +317,14 @@ INSTANTIATE_TEST_SUITE_P( {{"condition", BoolValue(false)}}, test::StringValueIs("long_right_hand_string_0123456789"), }, + {"optional_of_long_const_string", + "condition ? optional.of('lhs_short') : " + "optional.of('long_right_hand_string_0123456789')", + {{"condition", BoolValue(false)}}, + test::OptionalValueIs( + test::StringValueIs("long_right_hand_string_0123456789")), + // optional.of is a namespaced function. + /*enable_reference_resolver=*/true}, { "computed_string", "(condition ? 'a.b' : 'b.c') + '.d.e.f'", From e4ff771e68f58433ca20f4471c4a5c9dd86923eb Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 4 Jun 2026 15:39:42 -0700 Subject: [PATCH 49/87] Add ProtoTypeMaskRegistry, ProtoTypeMask, and FieldPath classes. The ProtoTypeMaskRegistry has functions that can be used to validate the input field masks and to check whether a field is visible. PiperOrigin-RevId: 926925505 --- checker/internal/BUILD | 92 ++++ checker/internal/field_path.cc | 30 ++ checker/internal/field_path.h | 77 ++++ checker/internal/field_path_test.cc | 85 ++++ checker/internal/proto_type_mask.cc | 87 ++++ checker/internal/proto_type_mask.h | 111 +++++ checker/internal/proto_type_mask_registry.cc | 180 ++++++++ checker/internal/proto_type_mask_registry.h | 83 ++++ .../internal/proto_type_mask_registry_test.cc | 402 ++++++++++++++++++ checker/internal/proto_type_mask_test.cc | 143 +++++++ 10 files changed, 1290 insertions(+) create mode 100644 checker/internal/field_path.cc create mode 100644 checker/internal/field_path.h create mode 100644 checker/internal/field_path_test.cc create mode 100644 checker/internal/proto_type_mask.cc create mode 100644 checker/internal/proto_type_mask.h create mode 100644 checker/internal/proto_type_mask_registry.cc create mode 100644 checker/internal/proto_type_mask_registry.h create mode 100644 checker/internal/proto_type_mask_registry_test.cc create mode 100644 checker/internal/proto_type_mask_test.cc diff --git a/checker/internal/BUILD b/checker/internal/BUILD index f4c60f937..25550616a 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -310,3 +310,95 @@ cc_test( "@com_google_absl//absl/types:optional", ], ) + +cc_library( + name = "field_path", + srcs = ["field_path.cc"], + hdrs = ["field_path.h"], + deps = [ + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "field_path_test", + srcs = ["field_path_test.cc"], + deps = [ + ":field_path", + "//internal:testing", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "proto_type_mask", + srcs = ["proto_type_mask.cc"], + hdrs = ["proto_type_mask.h"], + deps = [ + ":field_path", + "//internal:status_macros", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "proto_type_mask_test", + srcs = ["proto_type_mask_test.cc"], + deps = [ + ":field_path", + ":proto_type_mask", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "proto_type_mask_registry", + srcs = ["proto_type_mask_registry.cc"], + hdrs = ["proto_type_mask_registry.h"], + deps = [ + ":field_path", + ":proto_type_mask", + "//common:type", + "//internal:status_macros", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "proto_type_mask_registry_test", + srcs = ["proto_type_mask_registry_test.cc"], + deps = [ + ":proto_type_mask", + ":proto_type_mask_registry", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + ], +) diff --git a/checker/internal/field_path.cc b/checker/internal/field_path.cc new file mode 100644 index 000000000..5ecc4219b --- /dev/null +++ b/checker/internal/field_path.cc @@ -0,0 +1,30 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/field_path.h" + +#include + +#include "absl/strings/str_join.h" +#include "absl/strings/substitute.h" + +namespace cel::checker_internal { + +std::string FieldPath::DebugString() const { + return absl::Substitute( + "FieldPath { field path: '$0', field selection: {'$1'} }", path_, + absl::StrJoin(field_selection_, "', '")); +} + +} // namespace cel::checker_internal diff --git a/checker/internal/field_path.h b/checker/internal/field_path.h new file mode 100644 index 000000000..d67d9b935 --- /dev/null +++ b/checker/internal/field_path.h @@ -0,0 +1,77 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_FIELD_PATH_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_FIELD_PATH_H_ + +#include +#include +#include + +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" + +namespace cel::checker_internal { + +// Represents a single path within a FieldMask. +class FieldPath { + public: + explicit FieldPath(std::string path) + : path_(std::move(path)), + field_selection_(absl::StrSplit(path_, kPathDelimiter)) {} + + // Returns the input path. + // For example: "f.b.d". + absl::string_view GetPath() const { return path_; } + + // Returns the list of nested field names in the path. + // For example: {"f", "b", "d"}. + absl::Span GetFieldSelection() const { + return field_selection_; + } + + // Returns the first field name in the path. + // For example: "f". + std::string GetFieldName() const { return field_selection_.front(); } + + template + friend void AbslStringify(Sink& sink, const FieldPath& field_path) { + sink.Append(field_path.DebugString()); + } + + private: + static constexpr char kPathDelimiter = '.'; + + std::string DebugString() const; + + // The input path. For example: "f.b.d". + std::string path_; + // The list of nested field names in the path. For example: {"f", "b", "d"}. + std::vector field_selection_; +}; + +inline bool operator==(const FieldPath& lhs, const FieldPath& rhs) { + return lhs.GetFieldSelection() == rhs.GetFieldSelection(); +} + +// Compares the field selections in the field paths. +// This is only intended as an arbitrary ordering for a set. +inline bool operator<(const FieldPath& lhs, const FieldPath& rhs) { + return lhs.GetFieldSelection() < rhs.GetFieldSelection(); +} + +} // namespace cel::checker_internal + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_FIELD_PATH_H_ diff --git a/checker/internal/field_path_test.cc b/checker/internal/field_path_test.cc new file mode 100644 index 000000000..9a1434954 --- /dev/null +++ b/checker/internal/field_path_test.cc @@ -0,0 +1,85 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/field_path.h" + +#include "absl/strings/str_cat.h" +#include "internal/testing.h" + +namespace cel::checker_internal { +namespace { + +using ::testing::ElementsAre; + +TEST(FieldPathTest, EmptyPathReturnsEmptyString) { + FieldPath field_path(""); + EXPECT_EQ(field_path.GetPath(), ""); + EXPECT_THAT(field_path.GetFieldSelection(), ElementsAre("")); + EXPECT_EQ(field_path.GetFieldName(), ""); +} + +TEST(FieldPathTest, DelimiterPathReturnsEmptyStrings) { + FieldPath field_path("."); + EXPECT_EQ(field_path.GetPath(), "."); + EXPECT_THAT(field_path.GetFieldSelection(), ElementsAre("", "")); + EXPECT_EQ(field_path.GetFieldName(), ""); +} + +TEST(FieldPathTest, FieldPathReturnsFields) { + FieldPath field_path("resource.name.other_field"); + EXPECT_EQ(field_path.GetPath(), "resource.name.other_field"); + EXPECT_THAT(field_path.GetFieldSelection(), + ElementsAre("resource", "name", "other_field")); + EXPECT_EQ(field_path.GetFieldName(), "resource"); +} + +TEST(FieldPathTest, AbslStringifyPrintsFieldSelection) { + FieldPath field_path("resource.name"); + EXPECT_EQ(absl::StrCat(field_path), + "FieldPath { field path: 'resource.name', field selection: " + "{'resource', 'name'} }"); +} + +TEST(FieldPathTest, EqualsComparesFieldSelectionAndReturnsTrue) { + FieldPath field_path_1("resource.name"); + FieldPath field_path_2("resource.name"); + EXPECT_TRUE(field_path_1 == field_path_2); +} + +TEST(FieldPathTest, EqualsComparesFieldSelectionAndReturnsFalse) { + FieldPath field_path_1("resource.name"); + FieldPath field_path_2("resource.type"); + EXPECT_FALSE(field_path_1 == field_path_2); +} + +TEST(FieldPathTest, LessThanComparesFieldSelectionAndReturnsTrue) { + FieldPath field_path_1("resource.name"); + FieldPath field_path_2("resource.type"); + EXPECT_TRUE(field_path_1 < field_path_2); +} + +TEST(FieldPathTest, LessThanComparesIdenticalFieldSelectionAndReturnsFalse) { + FieldPath field_path_1("resource.name"); + FieldPath field_path_2("resource.name"); + EXPECT_FALSE(field_path_1 < field_path_2); +} + +TEST(FieldPathTest, LessThanComparesFieldSelectionAndReturnsFalse) { + FieldPath field_path_1("resource.type"); + FieldPath field_path_2("resource.name"); + EXPECT_FALSE(field_path_1 < field_path_2); +} + +} // namespace +} // namespace cel::checker_internal diff --git a/checker/internal/proto_type_mask.cc b/checker/internal/proto_type_mask.cc new file mode 100644 index 000000000..85e39cb69 --- /dev/null +++ b/checker/internal/proto_type_mask.cc @@ -0,0 +1,87 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/proto_type_mask.h" + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/btree_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "checker/internal/field_path.h" +#include "internal/status_macros.h" +#include "google/protobuf/descriptor.h" + +namespace cel::checker_internal { + +using ::google::protobuf::Descriptor; +using ::google::protobuf::DescriptorPool; +using ::google::protobuf::FieldDescriptor; + +absl::StatusOr FindMessage( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + absl::string_view type_name) { + const Descriptor* descriptor = + descriptor_pool->FindMessageTypeByName(type_name); + if (descriptor == nullptr) { + return absl::InvalidArgumentError( + absl::Substitute("type '$0' not found", type_name)); + } + return descriptor; +} + +absl::StatusOr FindField(const Descriptor* descriptor, + absl::string_view field_name) { + const FieldDescriptor* field_descriptor = + descriptor->FindFieldByName(field_name); + if (field_descriptor == nullptr) { + return absl::InvalidArgumentError( + absl::Substitute("could not select field '$0' from type '$1'", + field_name, descriptor->full_name())); + } + return field_descriptor; +} + +absl::StatusOr> ProtoTypeMask::GetFieldNames( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool) const { + CEL_ASSIGN_OR_RETURN(const Descriptor* descriptor, + FindMessage(descriptor_pool, this->GetTypeName())); + absl::btree_set field_names; + for (const FieldPath& field_path : this->GetFieldPaths()) { + std::string field_name = field_path.GetFieldName(); + CEL_ASSIGN_OR_RETURN(const FieldDescriptor* field_descriptor, + FindField(descriptor, field_name)); + field_names.insert(field_descriptor->name()); + } + return field_names; +} + +std::string ProtoTypeMask::DebugString() const { + // Represent each FieldPath by its path because it is easiest to read. + std::vector paths; + paths.reserve(field_paths_.size()); + for (const FieldPath& field_path : field_paths_) { + paths.emplace_back(field_path.GetPath()); + } + return absl::Substitute( + "ProtoTypeMask { type name: '$0', field paths: { '$1' } }", type_name_, + absl::StrJoin(paths, "', '")); +} + +} // namespace cel::checker_internal diff --git a/checker/internal/proto_type_mask.h b/checker/internal/proto_type_mask.h new file mode 100644 index 000000000..f7d522cba --- /dev/null +++ b/checker/internal/proto_type_mask.h @@ -0,0 +1,111 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_H_ + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/btree_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/internal/field_path.h" +#include "google/protobuf/descriptor.h" + +namespace cel::checker_internal { + +// Returns a descriptor for the input type name. +// Returns an error if the type name is not found. +absl::StatusOr FindMessage( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + absl::string_view type_name); + +// Returns a field descriptor for the input field name. +// Returns an error if the field name is not found. +absl::StatusOr FindField( + const google::protobuf::Descriptor* descriptor, absl::string_view field_name); + +// Represents the fraction of a protobuf type's object graph that should be +// visible within CEL expressions. +class ProtoTypeMask { + public: + explicit ProtoTypeMask(std::string type_name, + const std::vector& field_paths) + : type_name_(std::move(type_name)) { + for (const std::string& field_path : field_paths) { + field_paths_.insert(FieldPath(field_path)); + } + } + + // Returns a set of field names. The set includes the first field name from + // each field path. We are able to return a set of absl::string_view because + // the result is backed by the descriptor pool. + absl::StatusOr> GetFieldNames( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool) const; + + // Returns the type's full name. + // For example: "google.rpc.context.AttributeContext". + absl::string_view GetTypeName() const { return type_name_; } + + // Returns a representation of the FieldMask, which is a set of field paths. + // For example: + // { + // FieldPath { + // field path: 'resource.name', + // field selection: {'resource', 'name'} + // }, + // FieldPath { + // field path: 'request.auth.claims', + // field selection: {'request', 'auth', 'claims'} + // } + // } + const absl::btree_set& GetFieldPaths() const { + return field_paths_; + } + + template + friend void AbslStringify(Sink& sink, const ProtoTypeMask& proto_type_mask) { + sink.Append(proto_type_mask.DebugString()); + } + + private: + std::string DebugString() const; + + // A type's full name. For example: "google.rpc.context.AttributeContext". + std::string type_name_; + // A representation of a FieldMask, which is a set of field paths. + // For example: + // { + // FieldPath { + // field path: 'resource.name', + // field selection: {'resource', 'name'} + // }, + // FieldPath { + // field path: 'request.auth.claims', + // field selection: {'request', 'auth', 'claims'} + // } + // } + // A FieldMask contains one or more paths which contain identifier characters + // that have been dot delimited, e.g. resource.name, request.auth.claims. + // For each path, all descendent fields after the last element in the path are + // visible. An empty set means all fields are hidden. + absl::btree_set field_paths_; +}; + +} // namespace cel::checker_internal + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_H_ diff --git a/checker/internal/proto_type_mask_registry.cc b/checker/internal/proto_type_mask_registry.cc new file mode 100644 index 000000000..9c50c9784 --- /dev/null +++ b/checker/internal/proto_type_mask_registry.cc @@ -0,0 +1,180 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/proto_type_mask_registry.h" + +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/btree_set.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "checker/internal/field_path.h" +#include "checker/internal/proto_type_mask.h" +#include "common/type.h" +#include "internal/status_macros.h" +#include "google/protobuf/descriptor.h" + +namespace cel::checker_internal { +namespace { + +using ::google::protobuf::Descriptor; +using ::google::protobuf::DescriptorPool; +using ::google::protobuf::FieldDescriptor; +using TypeMap = + absl::flat_hash_map>; + +// Returns a message type descriptor for the input field descriptor. +// Returns an error if the field is not a message type. +absl::StatusOr GetMessage( + const FieldDescriptor* field_descriptor) { + cel::MessageTypeField field(field_descriptor); + cel::Type type = field.GetType(); + absl::optional message_type = type.AsMessage(); + if (!message_type.has_value()) { + return absl::InvalidArgumentError(absl::Substitute( + "field '$0' is not a message type", field_descriptor->name())); + } + return &(*message_type.value()); +} + +// Inserts the type name with an empty set into types_and_visible_fields. +// Returns an error if the type name is already present with a non-empty set. +absl::Status AddAllHiddenFields(TypeMap& types_and_visible_fields, + absl::string_view type_name) { + auto result = types_and_visible_fields.find(type_name); + if (result != types_and_visible_fields.end()) { + if (!result->second.empty()) { + return absl::InvalidArgumentError( + absl::Substitute("cannot insert a proto type mask with all hidden " + "fields when type '$0' has already been inserted " + "with a proto type mask with a visible field", + type_name)); + } + return absl::OkStatus(); + } + types_and_visible_fields.insert({std::string(type_name), {}}); + return absl::OkStatus(); +} + +// Inserts the type name and field name into types_and_visible_fields. +// Returns an error if the type name is already present with an empty set. +absl::Status AddVisibleField(TypeMap& types_and_visible_fields, + absl::string_view type_name, + absl::string_view field_name) { + auto result = types_and_visible_fields.find(type_name); + if (result != types_and_visible_fields.end()) { + if (result->second.empty()) { + return absl::InvalidArgumentError(absl::Substitute( + "cannot insert a proto type mask with visible " + "field '$0' when type '$1' has already been inserted " + "with a proto type mask with all hidden fields", + field_name, type_name)); + } + result->second.insert(std::string(field_name)); + return absl::OkStatus(); + } + types_and_visible_fields.insert( + {std::string(type_name), {std::string(field_name)}}); + return absl::OkStatus(); +} + +// Processes the input proto type masks to create and return the +// types_and_visible_fields map. +// Returns an error if one of the proto type masks is not valid. For example, +// if a type is not found in the descriptor pool, if a field name is not +// found, or if a field is not a message type when we are expecting it to be. +// Returns an error if there is a conflict in field visibility when +// updating the map. +absl::StatusOr ComputeVisibleFieldsMap( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + const std::vector& proto_type_masks) { + TypeMap types_and_visible_fields; + for (const ProtoTypeMask& proto_type_mask : proto_type_masks) { + absl::string_view type_name = proto_type_mask.GetTypeName(); + CEL_ASSIGN_OR_RETURN(const Descriptor* descriptor, + FindMessage(descriptor_pool, type_name)); + const absl::btree_set& field_paths = + proto_type_mask.GetFieldPaths(); + if (field_paths.empty()) { + CEL_RETURN_IF_ERROR( + AddAllHiddenFields(types_and_visible_fields, type_name)); + } + for (const FieldPath& field_path : field_paths) { + const Descriptor* target_descriptor = descriptor; + absl::Span field_selection = + field_path.GetFieldSelection(); + for (auto iterator = field_selection.begin(); + iterator != field_selection.end(); ++iterator) { + CEL_ASSIGN_OR_RETURN(const FieldDescriptor* field_descriptor, + FindField(target_descriptor, *iterator)); + CEL_RETURN_IF_ERROR(AddVisibleField(types_and_visible_fields, + target_descriptor->full_name(), + *iterator)); + if (std::next(iterator) != field_selection.end()) { + CEL_ASSIGN_OR_RETURN(target_descriptor, GetMessage(field_descriptor)); + } + } + } + } + return types_and_visible_fields; +} + +} // namespace + +absl::StatusOr> +ProtoTypeMaskRegistry::Create( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + const std::vector& proto_type_masks) { + CEL_ASSIGN_OR_RETURN( + auto types_and_visible_fields, + ComputeVisibleFieldsMap(descriptor_pool, proto_type_masks)); + std::shared_ptr proto_type_mask_registry = + absl::WrapUnique(new ProtoTypeMaskRegistry(types_and_visible_fields)); + return proto_type_mask_registry; +} + +bool ProtoTypeMaskRegistry::FieldIsVisible(absl::string_view type_name, + absl::string_view field_name) const { + auto iterator = types_and_visible_fields_.find(type_name); + if (iterator != types_and_visible_fields_.end() && + !iterator->second.contains(field_name)) { + return false; + } + return true; +} + +std::string ProtoTypeMaskRegistry::DebugString() const { + std::string output = "ProtoTypeMaskRegistry { "; + for (auto& element : types_and_visible_fields_) { + absl::StrAppend(&output, "{type: '", element.first, "', visible_fields: '", + absl::StrJoin(element.second, "', '"), "'} "); + } + absl::StrAppend(&output, "}"); + return output; +} + +} // namespace cel::checker_internal diff --git a/checker/internal/proto_type_mask_registry.h b/checker/internal/proto_type_mask_registry.h new file mode 100644 index 000000000..338353e7d --- /dev/null +++ b/checker/internal/proto_type_mask_registry.h @@ -0,0 +1,83 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_REGISTRY_H_ +#define THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_REGISTRY_H_ + +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/internal/proto_type_mask.h" +#include "google/protobuf/descriptor.h" + +namespace cel::checker_internal { + +// Stores information related to ProtoTypeMasks. Visibility is defined per type, +// meaning that all messages of a type have the same visible fields. +class ProtoTypeMaskRegistry { + public: + // Processes the input proto type masks to create a ProtoTypeMaskRegistry. + // Returns an error if one of the proto type masks is not valid. For example, + // if a type is not found in the descriptor pool, if a field name is not + // found, or if a field is not a message type when we are expecting it to be. + // Returns an error if there is a conflict in field visibility when + // updating the map. + static absl::StatusOr> Create( + const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool, + const std::vector& proto_type_masks); + + const absl::flat_hash_map>& + GetTypesAndVisibleFields() const { + return types_and_visible_fields_; + } + + // Returns true when the field name is visible. A field is visible if: + // 1. The type name is not a key in the map. + // 2. The type name is a key in the map and the field name is in the set of + // field names that are visible for the type. + bool FieldIsVisible(absl::string_view type_name, + absl::string_view field_name) const; + + template + friend void AbslStringify( + Sink& sink, + const std::shared_ptr& proto_type_mask_registry) { + sink.Append(proto_type_mask_registry->DebugString()); + } + + private: + explicit ProtoTypeMaskRegistry( + absl::flat_hash_map> + types_and_visible_fields) + : types_and_visible_fields_(std::move(types_and_visible_fields)) {} + + std::string DebugString() const; + + // Map of types that have a field mask where the keys are + // fully qualified type names and the values are the set of field names that + // are visible for the type. + absl::flat_hash_map> + types_and_visible_fields_; +}; + +} // namespace cel::checker_internal + +#endif // THIRD_PARTY_CEL_CPP_CHECKER_PROTO_TYPE_MASK_REGISTRY_H_ diff --git a/checker/internal/proto_type_mask_registry_test.cc b/checker/internal/proto_type_mask_registry_test.cc new file mode 100644 index 000000000..3a73c8823 --- /dev/null +++ b/checker/internal/proto_type_mask_registry_test.cc @@ -0,0 +1,402 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/proto_type_mask_registry.h" + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "checker/internal/proto_type_mask.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" + +namespace cel::checker_internal { +namespace { + +using ::absl_testing::StatusIs; +using ::cel::internal::GetSharedTestingDescriptorPool; +using ::testing::AllOf; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::Pair; +using ::testing::UnorderedElementsAre; + +using TypeMap = + absl::flat_hash_map>; + +TEST(ProtoTypeMaskRegistryTest, + CreateWithEmptyInputSucceedsAndAllFieldsAreVisible) { + std::vector proto_type_masks = {}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), IsEmpty()); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithEmptyTypeReturnsError) { + std::vector proto_type_masks = {ProtoTypeMask("", {})}; + EXPECT_THAT(ProtoTypeMaskRegistry::Create( + GetSharedTestingDescriptorPool().get(), proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("type '' not found"))); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithUnknownTypeReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("com.example.UnknownType", {})}; + EXPECT_THAT(ProtoTypeMaskRegistry::Create( + GetSharedTestingDescriptorPool().get(), proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("type 'com.example.UnknownType' not found"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithEmptySetFieldPathSucceedsAndFieldsAreHidden) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", {})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.TestAllTypes", IsEmpty()))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDuplicateEmptySetFieldPathSucceedsAndFieldsAreHidden) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", {}), + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", {})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.TestAllTypes", IsEmpty()))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithEmptyFieldPathReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", {""})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("could not select field '' from type " + "'cel.expr.conformance.proto3.TestAllTypes'"))); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithDelimiterFieldPathReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", {"."})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("could not select field '' from type " + "'cel.expr.conformance.proto3.TestAllTypes'"))); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithUnknownFieldReturnsError) { + std::vector proto_type_masks = {ProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"unknown_field"})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("could not select field 'unknown_field' from type " + "'cel.expr.conformance.proto3.TestAllTypes'"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthOneNonMessageFieldsSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"single_int32", "single_any", "single_timestamp"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("single_int32", "single_any", + "single_timestamp")))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "single_int32")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "single_any")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "single_timestamp")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, CreateWithDepthTwoNonMessageFieldReturnsError) { + std::vector proto_type_masks; + proto_type_masks.push_back( + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"single_int32.any_field_name"})); + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("field 'single_int32' is not a message type"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthOneMessageFieldSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = {ProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"standalone_message"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT( + proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre(Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("standalone_message")))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "standalone_message")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthTwoMessageFieldSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = {ProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"standalone_message.bb"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("standalone_message")), + Pair("cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + UnorderedElementsAre("bb")))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "standalone_message")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", "bb")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthTwoMessageUnknownFieldReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"standalone_message.unknown_field"})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr( + "could not select field 'unknown_field' from type " + "'cel.expr.conformance.proto3.TestAllTypes.NestedMessage'"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthThreeMessageFieldSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.NestedTestAllTypes", + {"payload.standalone_message.bb"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.NestedTestAllTypes", + UnorderedElementsAre("payload")), + Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("standalone_message")), + Pair("cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + UnorderedElementsAre("bb")))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.NestedTestAllTypes", "payload")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.NestedTestAllTypes", "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "standalone_message")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", "bb")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthOneRepeatedMessageFieldSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = {ProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"repeated_nested_message"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT(proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("repeated_nested_message")))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "repeated_nested_message")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithDepthTwoRepeatedMessageFieldReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"repeated_nested_message.bb"})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr("field 'repeated_nested_message' is not a message type"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateWithListOfFieldPathsSucceedsAndFieldsAreVisible) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.NestedTestAllTypes", + {"payload.standalone_message.bb", "payload.single_int32"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT( + proto_type_mask_registry->GetTypesAndVisibleFields(), + UnorderedElementsAre( + Pair("cel.expr.conformance.proto3.NestedTestAllTypes", + UnorderedElementsAre("payload")), + Pair("cel.expr.conformance.proto3.TestAllTypes", + UnorderedElementsAre("standalone_message", "single_int32")), + Pair("cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + UnorderedElementsAre("bb")))); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible("any_type_name", + "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.NestedTestAllTypes", "payload")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.NestedTestAllTypes", "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "standalone_message")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "single_int32")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes", "any_field_name")); + EXPECT_TRUE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", "bb")); + EXPECT_FALSE(proto_type_mask_registry->FieldIsVisible( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + "any_field_name")); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateAddVisibleFieldThenAllHiddenFieldsReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"standalone_message.bb"}), + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + {})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr( + "cannot insert a proto type mask with all hidden fields when " + "type 'cel.expr.conformance.proto3.TestAllTypes.NestedMessage' " + "has already been inserted with a proto type mask with a visible " + "field"))); +} + +TEST(ProtoTypeMaskRegistryTest, + CreateAddAllHiddenThenVisibleFieldReturnsError) { + std::vector proto_type_masks = { + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes.NestedMessage", + {}), + ProtoTypeMask("cel.expr.conformance.proto3.TestAllTypes", + {"standalone_message.bb"})}; + EXPECT_THAT( + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr( + "cannot insert a proto type mask with visible field 'bb' when " + "type 'cel.expr.conformance.proto3.TestAllTypes.NestedMessage' " + "has already been inserted with a proto type mask with all " + "hidden fields"))); +} + +TEST(ProtoTypeMaskRegistryTest, AbslStringifyPrintsTypesAndVisibleFieldsMap) { + std::vector proto_type_masks = {ProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"standalone_message.bb"})}; + ASSERT_OK_AND_ASSIGN( + std::shared_ptr proto_type_mask_registry, + ProtoTypeMaskRegistry::Create(GetSharedTestingDescriptorPool().get(), + proto_type_masks)); + EXPECT_THAT( + absl::StrCat(proto_type_mask_registry), + AllOf(HasSubstr("ProtoTypeMaskRegistry {"), + HasSubstr("{type: 'cel.expr.conformance.proto3.TestAllTypes', " + "visible_fields: 'standalone_message'}"), + HasSubstr("{type: " + "'cel.expr.conformance.proto3.TestAllTypes.NestedMessage'" + ", visible_fields: 'bb'}"))); +} + +} // namespace +} // namespace cel::checker_internal diff --git a/checker/internal/proto_type_mask_test.cc b/checker/internal/proto_type_mask_test.cc new file mode 100644 index 000000000..0c534f8cf --- /dev/null +++ b/checker/internal/proto_type_mask_test.cc @@ -0,0 +1,143 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "checker/internal/proto_type_mask.h" + +#include +#include + +#include "absl/container/btree_set.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "checker/internal/field_path.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" + +namespace cel::checker_internal { +namespace { + +using ::absl_testing::StatusIs; +using ::cel::internal::GetSharedTestingDescriptorPool; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::UnorderedElementsAre; + +TEST(ProtoTypeMaskTest, EmptyTypeNameAndEmptyFieldPathsSucceeds) { + std::string type_name = ""; + std::vector field_paths; + ProtoTypeMask proto_type_mask(type_name, field_paths); + EXPECT_EQ(proto_type_mask.GetTypeName(), ""); + EXPECT_THAT(proto_type_mask.GetFieldPaths(), IsEmpty()); +} + +TEST(ProtoTypeMaskTest, NotEmptyTypeNameAndNotEmptyFieldPathsSucceeds) { + std::string type_name = "google.type.Expr"; + std::vector field_paths = {"resource.name", "resource.type"}; + ProtoTypeMask proto_type_mask(type_name, field_paths); + EXPECT_EQ(proto_type_mask.GetTypeName(), "google.type.Expr"); + EXPECT_THAT(proto_type_mask.GetFieldPaths(), + UnorderedElementsAre(FieldPath("resource.name"), + FieldPath("resource.type"))); +} + +TEST(ProtoTypeMaskTest, GetFieldNamesWithEmptyTypeReturnsError) { + ProtoTypeMask proto_type_mask("", {}); + EXPECT_THAT( + proto_type_mask.GetFieldNames(GetSharedTestingDescriptorPool().get()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("type '' not found"))); +} + +TEST(ProtoTypeMaskTest, GetFieldNamesWithUnknownTypeReturnsError) { + ProtoTypeMask proto_type_mask("com.example.UnknownType", {}); + EXPECT_THAT( + proto_type_mask.GetFieldNames(GetSharedTestingDescriptorPool().get()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("type 'com.example.UnknownType' not found"))); +} + +TEST(ProtoTypeMaskTest, + GetFieldNamesWithEmptySetFieldPathSucceedsAndReturnsEmptySet) { + ProtoTypeMask proto_type_mask("cel.expr.conformance.proto3.TestAllTypes", {}); + ASSERT_OK_AND_ASSIGN( + absl::btree_set field_names, + proto_type_mask.GetFieldNames(GetSharedTestingDescriptorPool().get())); + EXPECT_THAT(field_names, IsEmpty()); +} + +TEST(ProtoTypeMaskTest, GetFieldNamesWithEmptyFieldPathReturnsError) { + ProtoTypeMask proto_type_mask("cel.expr.conformance.proto3.TestAllTypes", + {""}); + EXPECT_THAT( + proto_type_mask.GetFieldNames(GetSharedTestingDescriptorPool().get()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("could not select field '' from type " + "'cel.expr.conformance.proto3.TestAllTypes'"))); +} + +TEST(ProtoTypeMaskTest, GetFieldNamesWithDelimiterFieldPathReturnsError) { + ProtoTypeMask proto_type_mask("cel.expr.conformance.proto3.TestAllTypes", + {"single_int32", "."}); + EXPECT_THAT( + proto_type_mask.GetFieldNames(GetSharedTestingDescriptorPool().get()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("could not select field '' from type " + "'cel.expr.conformance.proto3.TestAllTypes'"))); +} + +TEST(ProtoTypeMaskTest, GetFieldNamesWithUnknownFieldReturnsError) { + ProtoTypeMask proto_type_mask("cel.expr.conformance.proto3.TestAllTypes", + {"unknown_field"}); + EXPECT_THAT( + proto_type_mask.GetFieldNames(GetSharedTestingDescriptorPool().get()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("could not select field 'unknown_field' from type " + "'cel.expr.conformance.proto3.TestAllTypes'"))); +} + +TEST(ProtoTypeMaskTest, + GetFieldNamesWithValidFieldsSucceedsAndReturnsFieldNames) { + ProtoTypeMask proto_type_mask("cel.expr.conformance.proto3.TestAllTypes", + {"single_int32", "single_string"}); + ASSERT_OK_AND_ASSIGN( + absl::btree_set field_names, + proto_type_mask.GetFieldNames(GetSharedTestingDescriptorPool().get())); + EXPECT_THAT(field_names, + UnorderedElementsAre("single_int32", "single_string")); +} + +TEST(ProtoTypeMaskTest, + GetFieldNamesWithValidFieldPathsSucceedsAndReturnsFieldNames) { + ProtoTypeMask proto_type_mask( + "cel.expr.conformance.proto3.NestedTestAllTypes", + {"payload.standalone_message.bb", "payload.single_int32", + "child.any_field_name"}); + ASSERT_OK_AND_ASSIGN( + absl::btree_set field_names, + proto_type_mask.GetFieldNames(GetSharedTestingDescriptorPool().get())); + EXPECT_THAT(field_names, UnorderedElementsAre("payload", "child")); +} + +TEST(ProtoTypeMaskTest, AbslStringifyPrintsTypeNameAndFieldPaths) { + std::string type_name = "google.type.Expr"; + std::vector field_paths = {"resource.name", "resource.type"}; + ProtoTypeMask proto_type_mask(type_name, field_paths); + EXPECT_THAT(absl::StrCat(proto_type_mask), + HasSubstr("ProtoTypeMask { type name: 'google.type.Expr', field " + "paths: { 'resource.name', 'resource.type' } }")); +} + +} // namespace +} // namespace cel::checker_internal From dd4f36816db185068e3ec8ba3663c95eb3f930db Mon Sep 17 00:00:00 2001 From: Mike Kruskal Date: Thu, 4 Jun 2026 15:58:40 -0700 Subject: [PATCH 50/87] Fix misuses of __cxa_demangle length parameters PiperOrigin-RevId: 926935330 --- common/typeinfo.cc | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/common/typeinfo.cc b/common/typeinfo.cc index 86bae1934..b07275712 100644 --- a/common/typeinfo.cc +++ b/common/typeinfo.cc @@ -57,18 +57,13 @@ std::string TypeInfo::DebugString() const { } return std::string(demangled.get()); #else - size_t length = 0; int status = 0; std::unique_ptr demangled( - abi::__cxa_demangle(rep_->name(), nullptr, &length, &status)); + abi::__cxa_demangle(rep_->name(), nullptr, nullptr, &status)); if (status != 0 || demangled == nullptr) { return std::string(rep_->name()); } - while (length != 0 && demangled.get()[length - 1] == '\0') { - // length includes the null terminator, remove it. - --length; - } - return std::string(demangled.get(), length); + return std::string(demangled.get()); #endif #else return absl::StrCat("0x", absl::Hex(absl::bit_cast(rep_))); From 35957777f1d8c22f3c4ac5e9cc333c81f559fc8c Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Mon, 8 Jun 2026 14:25:17 -0700 Subject: [PATCH 51/87] Variadic logical operators PiperOrigin-RevId: 928770216 --- checker/internal/BUILD | 2 + checker/internal/type_checker_impl.cc | 7 +- checker/internal/type_checker_impl_test.cc | 88 ++++++ checker/internal/type_inference_context.cc | 16 +- eval/compiler/flat_expr_builder.cc | 203 +++++++------ eval/compiler/flat_expr_builder_test.cc | 319 ++++++++++++++++++--- parser/options.h | 4 + parser/parser.cc | 21 +- parser/parser_test.cc | 53 ++++ tools/cel_unparser.cc | 25 +- tools/cel_unparser_test.cc | 19 ++ 11 files changed, 613 insertions(+), 144 deletions(-) diff --git a/checker/internal/BUILD b/checker/internal/BUILD index 25550616a..26c7b543f 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -145,6 +145,7 @@ cc_library( "//common:container", "//common:decl", "//common:expr", + "//common:standard_definitions", "//common:type", "//common:type_kind", "//internal:lexis", @@ -238,6 +239,7 @@ cc_library( deps = [ ":format_type_name", "//common:decl", + "//common:standard_definitions", "//common:type", "//common:type_kind", "@com_google_absl//absl/container:flat_hash_map", diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 1ce871255..6b6b051b1 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -52,6 +52,7 @@ #include "common/constant.h" #include "common/decl.h" #include "common/expr.h" +#include "common/standard_definitions.h" #include "common/type.h" #include "common/type_kind.h" #include "internal/status_macros.h" @@ -894,8 +895,12 @@ const FunctionDecl* ResolveVisitor::ResolveFunctionCallShape( if (decl == nullptr) { return true; } + bool is_logical_op = (candidate == cel::StandardFunctions::kAnd || + candidate == cel::StandardFunctions::kOr) && + arg_count >= 2; for (const auto& ovl : decl->overloads()) { - if (ovl.member() == is_receiver && ovl.args().size() == arg_count) { + if (ovl.member() == is_receiver && + (ovl.args().size() == arg_count || is_logical_op)) { return false; } } diff --git a/checker/internal/type_checker_impl_test.cc b/checker/internal/type_checker_impl_test.cc index 893f0689d..61ef7d55b 100644 --- a/checker/internal/type_checker_impl_test.cc +++ b/checker/internal/type_checker_impl_test.cc @@ -55,6 +55,7 @@ #include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" namespace cel { namespace checker_internal { @@ -1471,6 +1472,93 @@ TEST(TypeCheckerImplTest, TypeInferredFromStructCreation) { std::make_unique(DynTypeSpec()))))))); } +struct VariadicLogicalCheckerTestCase { + std::string expr; +}; + +class VariadicLogicalCheckerTest + : public testing::TestWithParam {}; + +TEST_P(VariadicLogicalCheckerTest, Check) { + const auto& test_case = GetParam(); + + auto builder = cel::NewParserBuilder(); + builder->GetOptions().enable_variadic_logical_operators = true; + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource(test_case.expr)); + ASSERT_OK_AND_ASSIGN(auto parsed_ast, parser->Parse(*source)); + + google::protobuf::Arena arena; + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + TypeCheckerImpl impl(std::move(env)); + auto checker_builder = impl.ToBuilder(); + ASSERT_THAT(checker_builder->AddVariable(MakeVariableDecl("a", BoolType())), + IsOk()); + ASSERT_THAT(checker_builder->AddVariable(MakeVariableDecl("b", BoolType())), + IsOk()); + ASSERT_THAT(checker_builder->AddVariable(MakeVariableDecl("c", BoolType())), + IsOk()); + ASSERT_THAT(checker_builder->AddVariable(MakeVariableDecl("d", BoolType())), + IsOk()); + ASSERT_THAT(checker_builder->AddVariable(MakeVariableDecl("e", BoolType())), + IsOk()); + + ASSERT_OK_AND_ASSIGN(auto checker, checker_builder->Build()); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + checker->Check(std::move(parsed_ast))); + + ASSERT_TRUE(result.IsValid()) + << absl::StrJoin(result.GetIssues(), "\n", + [](std::string* out, const TypeCheckIssue& issue) { + absl::StrAppend(out, issue.message()); + }); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + EXPECT_THAT(checked_ast->type_map(), + Contains(Pair(checked_ast->root_expr().id(), + Eq(AstType(PrimitiveType::kBool))))); +} + +INSTANTIATE_TEST_SUITE_P( + VariadicLogicalChecker, VariadicLogicalCheckerTest, + testing::Values(VariadicLogicalCheckerTestCase{"true && false && true"}, + VariadicLogicalCheckerTestCase{"a && b && c && d"}, + VariadicLogicalCheckerTestCase{"a || b || c || d"}, + VariadicLogicalCheckerTestCase{"a && b && (c || d || e)"}, + VariadicLogicalCheckerTestCase{"a && b && c"}, + VariadicLogicalCheckerTestCase{"a || b || c"}, + VariadicLogicalCheckerTestCase{"[a, b, c].exists(x, x)"}, + VariadicLogicalCheckerTestCase{"[a, b, c].all(x, x)"})); + +TEST(TypeCheckerImplTest, VariadicLogicalOperatorsError) { + cel::expr::ParsedExpr parsed_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + expr { + call_expr { + function: "_&&_" + args { const_expr { bool_value: true } } + } + } + )pb", + &parsed_expr)); + ASSERT_OK_AND_ASSIGN(auto parsed_ast, + cel::CreateAstFromParsedExpr(parsed_expr)); + + google::protobuf::Arena arena; + TypeCheckEnv env(GetSharedTestingDescriptorPool()); + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + impl.Check(std::move(parsed_ast))); + + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT( + result.GetIssues(), + Contains(IsIssueWithSubstring(Severity::kError, "undeclared reference"))); +} + TEST(TypeCheckerImplTest, ExpectedTypeMatches) { google::protobuf::Arena arena; TypeCheckEnv env(GetSharedTestingDescriptorPool()); diff --git a/checker/internal/type_inference_context.cc b/checker/internal/type_inference_context.cc index 5b909d982..1a87d9e15 100644 --- a/checker/internal/type_inference_context.cc +++ b/checker/internal/type_inference_context.cc @@ -30,6 +30,7 @@ #include "absl/types/span.h" #include "checker/internal/format_type_name.h" #include "common/decl.h" +#include "common/standard_definitions.h" #include "common/type.h" #include "common/type_kind.h" @@ -537,21 +538,28 @@ TypeInferenceContext::ResolveOverload(const FunctionDecl& decl, bool is_receiver) { std::optional result_type; + bool is_logical_op = (decl.name() == cel::StandardFunctions::kAnd || + decl.name() == cel::StandardFunctions::kOr) && + argument_types.size() >= 2; + std::vector matching_overloads; for (const auto& ovl : decl.overloads()) { if (ovl.member() != is_receiver || - argument_types.size() != ovl.args().size()) { + (!is_logical_op && argument_types.size() != ovl.args().size())) { continue; } auto call_type_instance = InstantiateFunctionOverload(*this, ovl); - ABSL_DCHECK_EQ(argument_types.size(), - call_type_instance.param_types.size()); + if (!is_logical_op) { + ABSL_DCHECK_EQ(argument_types.size(), + call_type_instance.param_types.size()); + } bool is_match = true; AssignabilityContext assignability_context = CreateAssignabilityContext(); for (int i = 0; i < argument_types.size(); ++i) { + int param_index = is_logical_op ? 0 : i; if (!assignability_context.IsAssignable( - argument_types[i], call_type_instance.param_types[i])) { + argument_types[i], call_type_instance.param_types[param_index])) { is_match = false; break; } diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index 1e3f4ecd3..d6ccdf040 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -232,7 +232,7 @@ class BinaryCondVisitor : public CondVisitor { private: FlatExprVisitor* visitor_; const BinaryCond cond_; - Jump jump_step_; + std::vector jump_steps_; bool short_circuiting_; }; @@ -622,7 +622,7 @@ class FlatExprVisitor : public cel::AstVisitor { program_optimizers_) { absl::Status status = optimizer->OnPreVisit(extension_context_, expr); if (!status.ok()) { - SetProgressStatusError(status); + SetProgressStatusIfError(status); } } } @@ -639,7 +639,7 @@ class FlatExprVisitor : public cel::AstVisitor { program_optimizers_) { absl::Status status = optimizer->OnPostVisit(extension_context_, expr); if (!status.ok()) { - SetProgressStatusError(status); + SetProgressStatusIfError(status); return; } } @@ -657,7 +657,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (!comprehension_stack_.empty() && comprehension_stack_.back().is_optimizable_bind && (&comprehension_stack_.back().comprehension->accu_init() == &expr)) { - SetProgressStatusError( + SetProgressStatusIfError( MaybeExtractSubexpression(&expr, comprehension_stack_.back())); } @@ -666,7 +666,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (block.current_binding == &expr) { int index = program_builder_.ExtractSubexpression(&expr); if (index == -1) { - SetProgressStatusError( + SetProgressStatusIfError( absl::InvalidArgumentError("failed to extract subexpression")); return; } @@ -686,7 +686,7 @@ class FlatExprVisitor : public cel::AstVisitor { ConvertConstant(const_expr, cel::NewDeleteAllocator()); if (!converted_value.ok()) { - SetProgressStatusError(converted_value.status()); + SetProgressStatusIfError(converted_value.status()); return; } @@ -722,13 +722,13 @@ class FlatExprVisitor : public cel::AstVisitor { if (absl::ConsumePrefix(&index_suffix, "@index")) { size_t index; if (!absl::SimpleAtoi(index_suffix, &index)) { - SetProgressStatusError( + SetProgressStatusIfError( issue_collector_.AddIssue(RuntimeIssue::CreateError( absl::InvalidArgumentError("bad @index")))); return {-1, -1}; } if (index >= block.size) { - SetProgressStatusError( + SetProgressStatusIfError( issue_collector_.AddIssue(RuntimeIssue::CreateError( absl::InvalidArgumentError(absl::StrCat( "invalid @index greater than number of bindings: ", @@ -736,7 +736,7 @@ class FlatExprVisitor : public cel::AstVisitor { return {-1, -1}; } if (index >= block.current_index) { - SetProgressStatusError( + SetProgressStatusIfError( issue_collector_.AddIssue(RuntimeIssue::CreateError( absl::InvalidArgumentError(absl::StrCat( "@index references current or future binding: ", index, @@ -754,7 +754,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (record.iter_var_in_scope && record.comprehension->iter_var() == path) { if (record.is_optimizable_bind) { - SetProgressStatusError(issue_collector_.AddIssue( + SetProgressStatusIfError(issue_collector_.AddIssue( RuntimeIssue::CreateWarning(absl::InvalidArgumentError( "Unexpected iter_var access in trivial comprehension")))); return {-1, -1}; @@ -781,7 +781,7 @@ class FlatExprVisitor : public cel::AstVisitor { // If we see a CSE generated comprehension variable that was not // resolvable through the normal comprehension scope resolution, reject it // now rather than surfacing errors at activation time. - SetProgressStatusError( + SetProgressStatusIfError( issue_collector_.AddIssue(RuntimeIssue::CreateError( absl::InvalidArgumentError("out of scope reference to CSE " "generated comprehension variable")))); @@ -811,7 +811,7 @@ class FlatExprVisitor : public cel::AstVisitor { auto* subexpression = program_builder_.GetExtractedSubexpression(slot.subexpression); if (subexpression == nullptr) { - SetProgressStatusError( + SetProgressStatusIfError( absl::InternalError("bad subexpression reference")); return; } @@ -965,7 +965,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (auto depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != 1) { - SetProgressStatusError(absl::InternalError( + SetProgressStatusIfError(absl::InternalError( "unexpected number of dependencies for select operation.")); return; } @@ -1022,7 +1022,7 @@ class FlatExprVisitor : public cel::AstVisitor { // cel.@block if (block_.has_value()) { // There can only be one for now. - SetProgressStatusError( + SetProgressStatusIfError( absl::InvalidArgumentError("multiple cel.@block are not allowed")); return; } @@ -1030,17 +1030,17 @@ class FlatExprVisitor : public cel::AstVisitor { BlockInfo& block = *block_; block.in = true; if (call_expr.args().empty()) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( "malformed cel.@block: missing list of bound expressions")); return; } if (call_expr.args().size() != 2) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( "malformed cel.@block: missing bound expression")); return; } if (!call_expr.args()[0].has_list_expr()) { - SetProgressStatusError( + SetProgressStatusIfError( absl::InvalidArgumentError("malformed cel.@block: first argument " "is not a list of bound expressions")); return; @@ -1051,7 +1051,7 @@ class FlatExprVisitor : public cel::AstVisitor { block.bindings_set.reserve(block.size); for (const auto& list_expr_element : list_expr.elements()) { if (list_expr_element.optional()) { - SetProgressStatusError( + SetProgressStatusIfError( absl::InvalidArgumentError("malformed cel.@block: list of bound " "expressions contains an optional")); return; @@ -1093,7 +1093,7 @@ class FlatExprVisitor : public cel::AstVisitor { void MakeTernaryRecursive(const cel::Expr* expr) { if (expr->call_expr().args().size() != 3) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( "unexpected number of args for builtin ternary")); return; } @@ -1109,7 +1109,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (condition_plan == nullptr || !condition_plan->IsRecursive() || left_plan == nullptr || !left_plan->IsRecursive() || right_plan == nullptr || !right_plan->IsRecursive()) { - SetProgressStatusError(FailedRecursivePlanning()); + SetProgressStatusIfError(FailedRecursivePlanning()); return; } @@ -1126,45 +1126,52 @@ class FlatExprVisitor : public cel::AstVisitor { } void MakeShortcircuitRecursive(const cel::Expr* expr, bool is_or) { - if (expr->call_expr().args().size() != 2) { - SetProgressStatusError(absl::InvalidArgumentError( + int args_size = expr->call_expr().args().size(); + if (args_size < 2) { + SetProgressStatusIfError(absl::InvalidArgumentError( "unexpected number of args for builtin boolean operator &&/||")); return; } - const cel::Expr* left_expr = &expr->call_expr().args()[0]; - const cel::Expr* right_expr = &expr->call_expr().args()[1]; - auto* left_plan = program_builder_.GetSubexpression(left_expr); - auto* right_plan = program_builder_.GetSubexpression(right_expr); - - if (left_plan == nullptr || !left_plan->IsRecursive() || - right_plan == nullptr || !right_plan->IsRecursive()) { - SetProgressStatusError(FailedRecursivePlanning()); + auto* current_plan = + program_builder_.GetSubexpression(&expr->call_expr().args()[0]); + if (current_plan == nullptr || !current_plan->IsRecursive()) { + SetProgressStatusIfError(FailedRecursivePlanning()); return; } + int current_depth = current_plan->recursive_program().depth; + std::unique_ptr current_step = + current_plan->ExtractRecursiveProgram().step; - int max_depth = std::max({0, left_plan->recursive_program().depth, - right_plan->recursive_program().depth}); - - if (is_or) { - SetRecursiveStep( - CreateDirectOrStep(left_plan->ExtractRecursiveProgram().step, - right_plan->ExtractRecursiveProgram().step, - expr->id(), options_.short_circuiting), - max_depth + 1); - } else { - SetRecursiveStep( - CreateDirectAndStep(left_plan->ExtractRecursiveProgram().step, - right_plan->ExtractRecursiveProgram().step, - expr->id(), options_.short_circuiting), - max_depth + 1); + for (int i = 1; i < args_size; ++i) { + auto* next_plan = + program_builder_.GetSubexpression(&expr->call_expr().args()[i]); + if (next_plan == nullptr || !next_plan->IsRecursive()) { + SetProgressStatusIfError(FailedRecursivePlanning()); + return; + } + current_depth = + std::max(current_depth, next_plan->recursive_program().depth); + std::unique_ptr next_step = + next_plan->ExtractRecursiveProgram().step; + if (is_or) { + current_step = + CreateDirectOrStep(std::move(current_step), std::move(next_step), + expr->id(), options_.short_circuiting); + } else { + current_step = + CreateDirectAndStep(std::move(current_step), std::move(next_step), + expr->id(), options_.short_circuiting); + } + current_depth++; } + SetRecursiveStep(std::move(current_step), current_depth); } void MakeOptionalShortcircuit(const cel::Expr* expr, bool is_or_value) { if (!expr->call_expr().has_target() || expr->call_expr().args().size() != 1) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( "unexpected number of args for optional.or{Value}")); return; } @@ -1176,7 +1183,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (left_plan == nullptr || !left_plan->IsRecursive() || right_plan == nullptr || !right_plan->IsRecursive()) { - SetProgressStatusError(FailedRecursivePlanning()); + SetProgressStatusIfError(FailedRecursivePlanning()); return; } int max_depth = std::max({0, left_plan->recursive_program().depth, @@ -1200,7 +1207,7 @@ class FlatExprVisitor : public cel::AstVisitor { program_builder_.GetSubexpression(&comprehension->result()); if (result_plan == nullptr || !result_plan->IsRecursive()) { - SetProgressStatusError(FailedRecursivePlanning()); + SetProgressStatusIfError(FailedRecursivePlanning()); return; } @@ -1234,7 +1241,7 @@ class FlatExprVisitor : public cel::AstVisitor { loop_plan == nullptr || !loop_plan->IsRecursive() || condition_plan == nullptr || !condition_plan->IsRecursive() || result_plan == nullptr || !result_plan->IsRecursive()) { - SetProgressStatusError(FailedRecursivePlanning()); + SetProgressStatusIfError(FailedRecursivePlanning()); return; } @@ -1462,7 +1469,7 @@ class FlatExprVisitor : public cel::AstVisitor { return; } - SetProgressStatusError(comprehension_stack_.back().visitor->PostVisitArg( + SetProgressStatusIfError(comprehension_stack_.back().visitor->PostVisitArg( comprehension_arg, comprehension_stack_.back().expr)); } @@ -1524,7 +1531,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (std::optional depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != list_expr.elements().size()) { - SetProgressStatusError(absl::InternalError( + SetProgressStatusIfError(absl::InternalError( "Unexpected number of plan elements for CreateList expr")); return; } @@ -1547,7 +1554,7 @@ class FlatExprVisitor : public cel::AstVisitor { auto status_or_resolved_fields = ResolveCreateStructFields(struct_expr, expr.id()); if (!status_or_resolved_fields.ok()) { - SetProgressStatusError(status_or_resolved_fields.status()); + SetProgressStatusIfError(status_or_resolved_fields.status()); return; } std::string resolved_name = @@ -1558,7 +1565,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (auto depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != struct_expr.fields().size()) { - SetProgressStatusError(absl::InternalError( + SetProgressStatusIfError(absl::InternalError( "Unexpected number of plan elements for CreateStruct expr")); return; } @@ -1599,7 +1606,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (auto depth = RecursionEligible(); depth.has_value()) { auto deps = ExtractRecursiveDependencies(); if (deps.size() != 2 * map_expr.entries().size()) { - SetProgressStatusError(absl::InternalError( + SetProgressStatusIfError(absl::InternalError( "Unexpected number of plan elements for CreateStruct expr")); return; } @@ -1661,7 +1668,7 @@ class FlatExprVisitor : public cel::AstVisitor { "No overloads provided for FunctionStep creation"), RuntimeIssue::ErrorCode::kNoMatchingOverload)); if (!status.ok()) { - SetProgressStatusError(status); + SetProgressStatusIfError(status); return; } } @@ -1692,7 +1699,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (step.ok()) { return AddStep(*std::move(step)); } else { - SetProgressStatusError(step.status()); + SetProgressStatusIfError(step.status()); } return nullptr; } @@ -1711,19 +1718,19 @@ class FlatExprVisitor : public cel::AstVisitor { return; } if (program_builder_.current() == nullptr) { - SetProgressStatusError(absl::InternalError( + SetProgressStatusIfError(absl::InternalError( "CEL AST traversal out of order in flat_expr_builder.")); return; } program_builder_.current()->set_recursive_program(std::move(step), depth); if (depth > max_recursion_depth_) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( absl::StrCat("Maximum recursion depth of ", options_.max_recursion_depth, " exceeded"))); } } - void SetProgressStatusError(const absl::Status& status) { + void SetProgressStatusIfError(const absl::Status& status) { if (progress_status_.ok() && !status.ok()) { progress_status_ = status; } @@ -1765,7 +1772,7 @@ class FlatExprVisitor : public cel::AstVisitor { if (valid_expression) { return true; } - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( absl::StrCat(error_message, message_parts...))); return false; } @@ -1947,7 +1954,7 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleIndex( if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 2) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( "unexpected number of args for builtin index operator")); return CallHandlerResult::kIntercepted; } @@ -1974,7 +1981,7 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleNot( if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 1) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( "unexpected number of args for builtin not operator")); return CallHandlerResult::kIntercepted; } @@ -1997,7 +2004,7 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleNotStrictlyFalse( if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 1) { - SetProgressStatusError( + SetProgressStatusIfError( absl::InvalidArgumentError("unexpected number of args for builtin " "@not_strictly_false operator")); return CallHandlerResult::kIntercepted; @@ -2016,7 +2023,7 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleBlock( ABSL_DCHECK(call_expr.function() == kBlock); if (!block_.has_value() || block_->expr != &expr || call_expr.args().size() != 2 || call_expr.has_target()) { - SetProgressStatusError( + SetProgressStatusIfError( absl::InvalidArgumentError("unexpected call to internal cel.@block")); return CallHandlerResult::kIntercepted; } @@ -2101,7 +2108,7 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleHeterogeneousEquality( if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 2) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( "unexpected number of args for builtin equality operator")); return CallHandlerResult::kIntercepted; } @@ -2126,7 +2133,7 @@ FlatExprVisitor::HandleHeterogeneousEqualityIn(const cel::Expr& expr, if (auto depth = RecursionEligible(); depth.has_value()) { auto args = ExtractRecursiveDependencies(); if (args.size() != 2) { - SetProgressStatusError(absl::InvalidArgumentError( + SetProgressStatusIfError(absl::InvalidArgumentError( "unexpected number of args for builtin 'in' operator")); return CallHandlerResult::kIntercepted; } @@ -2164,13 +2171,14 @@ void BinaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { if (visitor_->PlanRecursiveProgram()) { return; } - if (short_circuiting_ && arg_num == 0 && + const int last_arg_index = expr->call_expr().args().size() - 1; + if (short_circuiting_ && arg_num < last_arg_index && (cond_ == BinaryCond::kAnd || cond_ == BinaryCond::kOr)) { // If first branch evaluation result is enough to determine output, // jump over the second branch and provide result of the first argument as // final output. - // Retain a pointer to the jump step so we can update the target after - // planning the second argument. + // Retain pointers to the jump steps so we can update the target after + // planning the next arguments. std::unique_ptr jump_step; switch (cond_) { case BinaryCond::kAnd: @@ -2185,7 +2193,7 @@ void BinaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { ProgramStepIndex index = visitor_->GetCurrentIndex(); if (JumpStepBase* jump_step_ptr = visitor_->AddStep(std::move(jump_step)); jump_step_ptr) { - jump_step_ = Jump(index, jump_step_ptr); + jump_steps_.push_back(Jump(index, jump_step_ptr)); } } } @@ -2215,7 +2223,7 @@ void BinaryCondVisitor::PostVisitTarget(const cel::Expr* expr) { ProgramStepIndex index = visitor_->GetCurrentIndex(); if (JumpStepBase* jump_step_ptr = visitor_->AddStep(std::move(jump_step)); jump_step_ptr) { - jump_step_ = Jump(index, jump_step_ptr); + jump_steps_.push_back(Jump(index, jump_step_ptr)); } } } @@ -2243,28 +2251,36 @@ void BinaryCondVisitor::PostVisit(const cel::Expr* expr) { return; } - switch (cond_) { - case BinaryCond::kAnd: - visitor_->AddStep(CreateAndStep(expr->id())); - break; - case BinaryCond::kOr: - visitor_->AddStep(CreateOrStep(expr->id())); - break; - case BinaryCond::kOptionalOr: - visitor_->AddStep( - CreateOptionalOrStep(/*is_or_value=*/false, expr->id())); - break; - case BinaryCond::kOptionalOrValue: - visitor_->AddStep(CreateOptionalOrStep(/*is_or_value=*/true, expr->id())); - break; - default: - ABSL_UNREACHABLE(); + int args_count = (cond_ == BinaryCond::kAnd || cond_ == BinaryCond::kOr) + ? expr->call_expr().args().size() + : 2; + for (int i = 0; i < args_count - 1; ++i) { + switch (cond_) { + case BinaryCond::kAnd: + visitor_->AddStep(CreateAndStep(expr->id())); + break; + case BinaryCond::kOr: + visitor_->AddStep(CreateOrStep(expr->id())); + break; + case BinaryCond::kOptionalOr: + visitor_->AddStep( + CreateOptionalOrStep(/*is_or_value=*/false, expr->id())); + break; + case BinaryCond::kOptionalOrValue: + visitor_->AddStep( + CreateOptionalOrStep(/*is_or_value=*/true, expr->id())); + break; + default: + ABSL_UNREACHABLE(); + } } if (short_circuiting_) { // If short-circuiting is enabled, point the conditional jump past the // boolean operator step. - visitor_->SetProgressStatusError( - jump_step_.set_target(visitor_->GetCurrentIndex())); + for (auto& jump : jump_steps_) { + visitor_->SetProgressStatusIfError( + jump.set_target(visitor_->GetCurrentIndex())); + } } } @@ -2321,7 +2337,7 @@ void TernaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { if (visitor_->ValidateOrError( jump_to_second_.exists(), "Error configuring ternary operator: jump_to_second_ is null")) { - visitor_->SetProgressStatusError( + visitor_->SetProgressStatusIfError( jump_to_second_.set_target(visitor_->GetCurrentIndex())); } } @@ -2339,13 +2355,13 @@ void TernaryCondVisitor::PostVisit(const cel::Expr* expr) { if (visitor_->ValidateOrError( error_jump_.exists(), "Error configuring ternary operator: error_jump_ is null")) { - visitor_->SetProgressStatusError( + visitor_->SetProgressStatusIfError( error_jump_.set_target(visitor_->GetCurrentIndex())); } if (visitor_->ValidateOrError( jump_after_first_.exists(), "Error configuring ternary operator: jump_after_first_ is null")) { - visitor_->SetProgressStatusError( + visitor_->SetProgressStatusIfError( jump_after_first_.set_target(visitor_->GetCurrentIndex())); } } @@ -2403,7 +2419,8 @@ absl::Status ComprehensionVisitor::PostVisitArgDefault( break; } Jump jump_helper(index, jump_to_next); - visitor_->SetProgressStatusError(jump_helper.set_target(next_step_pos_)); + visitor_->SetProgressStatusIfError( + jump_helper.set_target(next_step_pos_)); // Set offsets jumping to the result step. if (cond_step_) { diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index d84007485..e2581e3fd 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -469,7 +469,7 @@ TEST(FlatExprBuilderTest, IdentExprUnsetName) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"(ident_expr {})", &expr); + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"(ident_expr {})", &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); @@ -482,10 +482,10 @@ TEST(FlatExprBuilderTest, SelectExprUnsetField) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"(select_expr{ + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"(select_expr{ operand{ ident_expr {name: 'var'} } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); @@ -498,11 +498,11 @@ TEST(FlatExprBuilderTest, SelectExprUnsetOperand) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"(select_expr{ + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"(select_expr{ field: 'field' operand { id: 1 } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); @@ -515,7 +515,8 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetAccuVar) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"(comprehension_expr{})", &expr); + ASSERT_TRUE( + google::protobuf::TextFormat::ParseFromString(R"(comprehension_expr{})", &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), @@ -527,10 +528,10 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetIterVar) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr{accu_var: "a"} )", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), @@ -542,12 +543,12 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetAccuInit) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr{ accu_var: "a" iter_var: "b"} )", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), @@ -559,7 +560,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetLoopCondition) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr{ accu_var: 'a' iter_var: 'b' @@ -567,7 +568,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetLoopCondition) { const_expr {bool_value: true} }} )", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), @@ -579,7 +580,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetLoopStep) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr{ accu_var: 'a' iter_var: 'b' @@ -590,7 +591,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetLoopStep) { const_expr {bool_value: true} }} )", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), @@ -602,7 +603,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetResult) { Expr expr; SourceInfo source_info; // An empty ident without the name set should error. - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr{ accu_var: 'a' iter_var: 'b' @@ -616,7 +617,7 @@ TEST(FlatExprBuilderTest, ComprehensionExprUnsetResult) { const_expr {bool_value: false} }} )", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); EXPECT_THAT(builder.CreateExpression(&expr, &source_info).status(), @@ -628,7 +629,7 @@ TEST(FlatExprBuilderTest, MapComprehension) { Expr expr; SourceInfo source_info; // {1: "", 2: ""}.all(x, x > 0) - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr { iter_var: "k" accu_var: "accu" @@ -665,7 +666,7 @@ TEST(FlatExprBuilderTest, MapComprehension) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); @@ -683,7 +684,7 @@ TEST(FlatExprBuilderTest, InvalidContainer) { Expr expr; SourceInfo source_info; // foo && bar - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( call_expr { function: "_&&_" args { @@ -697,7 +698,7 @@ TEST(FlatExprBuilderTest, InvalidContainer) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); @@ -909,7 +910,7 @@ TEST(FlatExprBuilderTest, ParsedNamespacedFunctionSupportDisabled) { TEST(FlatExprBuilderTest, BasicCheckedExprSupport) { CheckedExpr expr; // foo && bar - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( expr { id: 1 call_expr { @@ -928,7 +929,7 @@ TEST(FlatExprBuilderTest, BasicCheckedExprSupport) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); @@ -946,7 +947,7 @@ TEST(FlatExprBuilderTest, BasicCheckedExprSupport) { TEST(FlatExprBuilderTest, CheckedExprWithReferenceMap) { CheckedExpr expr; // `foo.var1` && `bar.var2` - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( reference_map { key: 2 value { @@ -988,7 +989,7 @@ TEST(FlatExprBuilderTest, CheckedExprWithReferenceMap) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( @@ -1008,7 +1009,7 @@ TEST(FlatExprBuilderTest, CheckedExprWithReferenceMap) { TEST(FlatExprBuilderTest, CheckedExprWithReferenceMapFunction) { CheckedExpr expr; // ext.and(var1, bar.var2) - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( reference_map { key: 1 value { @@ -1057,7 +1058,7 @@ TEST(FlatExprBuilderTest, CheckedExprWithReferenceMapFunction) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( @@ -1082,7 +1083,7 @@ TEST(FlatExprBuilderTest, CheckedExprWithReferenceMapFunction) { TEST(FlatExprBuilderTest, CheckedExprActivationMissesReferences) { CheckedExpr expr; // && . - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( reference_map { key: 2 value { @@ -1125,7 +1126,7 @@ TEST(FlatExprBuilderTest, CheckedExprActivationMissesReferences) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( @@ -1160,7 +1161,7 @@ TEST(FlatExprBuilderTest, CheckedExprActivationMissesReferences) { TEST(FlatExprBuilderTest, CheckedExprWithReferenceMapAndConstantFolding) { CheckedExpr expr; // {`var1`: 'hello'} - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( reference_map { key: 3 value { @@ -1190,7 +1191,7 @@ TEST(FlatExprBuilderTest, CheckedExprWithReferenceMapAndConstantFolding) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); builder.flat_expr_builder().AddAstTransform( @@ -1213,7 +1214,7 @@ TEST(FlatExprBuilderTest, ComprehensionWorksForError) { Expr expr; SourceInfo source_info; // {}[0].all(x, x) should evaluate OK but return an error value - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( id: 4 comprehension_expr { iter_var: "x" @@ -1278,7 +1279,7 @@ TEST(FlatExprBuilderTest, ComprehensionWorksForError) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); @@ -1295,7 +1296,7 @@ TEST(FlatExprBuilderTest, ComprehensionWorksForNonContainer) { Expr expr; SourceInfo source_info; // 0.all(x, x) should evaluate OK but return an error value. - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( id: 4 comprehension_expr { iter_var: "x" @@ -1349,7 +1350,7 @@ TEST(FlatExprBuilderTest, ComprehensionWorksForNonContainer) { } } })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_THAT(RegisterBuiltinFunctions(builder.GetRegistry()), IsOk()); @@ -1721,7 +1722,7 @@ TEST(FlatExprBuilderTest, NameCollisionWithComprehensionVarLeadingDot) { TEST(FlatExprBuilderTest, MapFieldPresence) { Expr expr; SourceInfo source_info; - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( id: 1, select_expr{ operand { @@ -1731,7 +1732,7 @@ TEST(FlatExprBuilderTest, MapFieldPresence) { field: "string_int32_map" test_only: true })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK_AND_ASSIGN(auto cel_expr, @@ -1765,7 +1766,7 @@ TEST(FlatExprBuilderTest, MapFieldPresence) { TEST(FlatExprBuilderTest, RepeatedFieldPresence) { Expr expr; SourceInfo source_info; - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( id: 1, select_expr{ operand { @@ -1775,7 +1776,7 @@ TEST(FlatExprBuilderTest, RepeatedFieldPresence) { field: "int32_list" test_only: true })", - &expr); + &expr)); CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv()); ASSERT_OK_AND_ASSIGN(auto cel_expr, @@ -2900,6 +2901,248 @@ TEST(FlatExprBuilderTest, BlockNested) { HasSubstr("multiple cel.@block are not allowed"))); } +struct VariadicLogicalEvalTestCase { + std::string label; + std::string expr; + std::string a_val; + std::string b_val; + std::string c_val; + std::string expected_type; // "bool", "error", "unknown" + bool expected_bool = false; +}; + +class FlatExprBuilderVariadicLogicalTest + : public testing::TestWithParam {}; + +TEST_P(FlatExprBuilderVariadicLogicalTest, Evaluate) { + const auto& test_case = GetParam(); + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(test_case.expr)); + + cel::RuntimeOptions options; + options.unknown_processing = + cel::UnknownProcessingOptions::kAttributeAndFunction; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + ASSERT_OK_AND_ASSIGN(auto cel_expr, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + Activation activation; + google::protobuf::Arena arena; + std::vector unknown_patterns; + + // Set up variables: + auto insert_value = [&](absl::string_view name, const std::string& val) { + if (val == "true") { + activation.InsertValue(name, CelValue::CreateBool(true)); + } else if (val == "false") { + activation.InsertValue(name, CelValue::CreateBool(false)); + } else if (val == "error") { + activation.InsertValue(name, CreateErrorValue(&arena, "test error")); + } else if (val == "unknown1" || val == "unknown2") { + activation.InsertValue(name, CelValue::CreateBool(true)); + unknown_patterns.push_back(CreateCelAttributePattern(name, {})); + } + }; + + insert_value("a", test_case.a_val); + insert_value("b", test_case.b_val); + insert_value("c", test_case.c_val); + + if (!unknown_patterns.empty()) { + activation.set_unknown_attribute_patterns(std::move(unknown_patterns)); + } + + ASSERT_OK_AND_ASSIGN(CelValue result, cel_expr->Evaluate(activation, &arena)); + + if (test_case.expected_type == "bool") { + ASSERT_TRUE(result.IsBool()) << result.DebugString(); + EXPECT_EQ(result.BoolOrDie(), test_case.expected_bool); + } else if (test_case.expected_type == "error") { + EXPECT_TRUE(result.IsError()) << result.DebugString(); + } else if (test_case.expected_type == "unknown") { + EXPECT_TRUE(result.IsUnknownSet()) << result.DebugString(); + } +} + +INSTANTIATE_TEST_SUITE_P( + FlatExprBuilderVariadicLogicalTest, FlatExprBuilderVariadicLogicalTest, + testing::Values( + VariadicLogicalEvalTestCase{"AND_AllTrue", "a && b && c", "true", + "true", "true", "bool", true}, + VariadicLogicalEvalTestCase{"AND_ShortCircuitFalse", "a && b && c", + "true", "false", "unset", "bool", false}, + VariadicLogicalEvalTestCase{"AND_ShortCircuitFirstFalse", "a && b && c", + "false", "unset", "unset", "bool", false}, + VariadicLogicalEvalTestCase{"OR_AllFalse", "a || b || c", "false", + "false", "false", "bool", false}, + VariadicLogicalEvalTestCase{"OR_ShortCircuitTrue", "a || b || c", + "false", "true", "unset", "bool", true}, + VariadicLogicalEvalTestCase{"OR_ShortCircuitFirstTrue", "a || b || c", + "true", "unset", "unset", "bool", true}, + VariadicLogicalEvalTestCase{"AND_Error", "a && b && c", "true", "error", + "true", "error"}, + VariadicLogicalEvalTestCase{"AND_ShortCircuitBeforeError", + "a && b && c", "false", "error", "unset", + "bool", false}, + VariadicLogicalEvalTestCase{"OR_Error", "a || b || c", "false", "error", + "false", "error"}, + VariadicLogicalEvalTestCase{"OR_ShortCircuitBeforeError", "a || b || c", + "true", "error", "unset", "bool", true}, + VariadicLogicalEvalTestCase{"AND_Unknown", "a && b && c", "true", + "unknown1", "true", "unknown"}, + VariadicLogicalEvalTestCase{"AND_ShortCircuitBeforeUnknown", + "a && b && c", "false", "unknown1", "unset", + "bool", false}, + VariadicLogicalEvalTestCase{"OR_Unknown", "a || b || c", "false", + "unknown1", "false", "unknown"}, + VariadicLogicalEvalTestCase{"OR_ShortCircuitBeforeUnknown", + "a || b || c", "true", "unknown1", "unset", + "bool", true}, + VariadicLogicalEvalTestCase{"AND_UnknownAggregation", "a && b && c", + "unknown1", "unknown2", "true", "unknown"}, + VariadicLogicalEvalTestCase{"OR_UnknownAggregation", "a || b || c", + "unknown1", "unknown2", "false", "unknown"}, + VariadicLogicalEvalTestCase{"Exists_True", "[a, b, c].exists(x, x)", + "false", "false", "true", "bool", true}, + VariadicLogicalEvalTestCase{"Exists_Unknown", "[a, b, c].exists(x, x)", + "false", "unknown1", "false", "unknown"}, + VariadicLogicalEvalTestCase{"All_False", "[a, b, c].all(x, x)", "true", + "true", "false", "bool", false}, + VariadicLogicalEvalTestCase{"All_Unknown", "[a, b, c].all(x, x)", + "true", "unknown1", "true", "unknown"})); + +struct RecursionDepthTestCase { + std::string label; + std::string expr; + int max_recursion_depth; + absl::StatusCode expected_status_code; + std::string expected_error_msg; +}; + +class FlatExprBuilderRecursionDepthTest + : public testing::TestWithParam {}; + +TEST_P(FlatExprBuilderRecursionDepthTest, CheckRecursionLimit) { + const auto& test_case = GetParam(); + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(test_case.expr)); + + cel::RuntimeOptions options; + options.max_recursion_depth = test_case.max_recursion_depth; + options.fail_on_warnings = false; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + + auto result = + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()); + if (test_case.expected_status_code == absl::StatusCode::kOk) { + EXPECT_THAT(result, IsOk()); + } else { + EXPECT_THAT(result, StatusIs(test_case.expected_status_code, + HasSubstr(test_case.expected_error_msg))); + } +} + +INSTANTIATE_TEST_SUITE_P( + FlatExprBuilderRecursionDepthTest, FlatExprBuilderRecursionDepthTest, + testing::Values( + RecursionDepthTestCase{"AndChildLimitExceeded", "(1 + 1) && true", 1, + absl::StatusCode::kInvalidArgument, + "Maximum recursion depth of 1 exceeded"}, + RecursionDepthTestCase{"AndParentLimitExceeded", "(1 + 1) && true", 2, + absl::StatusCode::kInvalidArgument, + "Maximum recursion depth of 2 exceeded"}, + RecursionDepthTestCase{"AndLimitSuccess", "(1 + 1) && true", 3, + absl::StatusCode::kOk, ""}, + RecursionDepthTestCase{"AndLimitSuccessGenerous", "(1 + 1) && true", 10, + absl::StatusCode::kOk, ""}, + RecursionDepthTestCase{"AndLimitSuccessUnlimited", "(1 + 1) && true", + -1, absl::StatusCode::kOk, ""}, + RecursionDepthTestCase{"OrChildLimitExceeded", "(1 + 1) || true", 1, + absl::StatusCode::kInvalidArgument, + "Maximum recursion depth of 1 exceeded"}, + RecursionDepthTestCase{"OrParentLimitExceeded", "(1 + 1) || true", 2, + absl::StatusCode::kInvalidArgument, + "Maximum recursion depth of 2 exceeded"}, + RecursionDepthTestCase{"OrLimitSuccess", "(1 + 1) || true", 3, + absl::StatusCode::kOk, ""}, + RecursionDepthTestCase{"OrLimitSuccessGenerous", + "(1 + 1) || false || false || false || false || " + "(true && true && true && true && false)", + 10, absl::StatusCode::kOk, ""}, + RecursionDepthTestCase{"OrLimitSuccessUnlimited", "(1 + 1) || true", -1, + absl::StatusCode::kOk, ""}, + RecursionDepthTestCase{"AndDepthUpdateFromSubsequentArg", + "true && (1 + 1 + 1 + 1)", 4, + absl::StatusCode::kInvalidArgument, + "Maximum recursion depth of 4 exceeded"}, + RecursionDepthTestCase{"OrDepthUpdateFromSubsequentArg", + "true || (1 + 1 + 1 + 1)", 4, + absl::StatusCode::kInvalidArgument, + "Maximum recursion depth of 4 exceeded"})); + +TEST(FlatExprBuilderTest, NonRecursiveChildBlockAndError) { + ParsedExpr parsed_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + expr: { + call_expr: { + function: "_&&_" + args { const_expr: { bool_value: true } } + args { + call_expr: { + function: "cel.@block" + args { + list_expr { elements { const_expr: { int64_value: 1 } } } + } + args { ident_expr: { name: "@index0" } } + } + } + } + } + )pb", + &parsed_expr)); + + cel::RuntimeOptions options; + options.max_recursion_depth = 2; + options.fail_on_warnings = false; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + EXPECT_THAT( + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), + StatusIs(absl::StatusCode::kInternal, + HasSubstr("failed to build recursive program"))); +} + +TEST(FlatExprBuilderTest, NonRecursiveChildBlockOrError) { + ParsedExpr parsed_expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString( + R"pb( + expr: { + call_expr: { + function: "_||_" + args { const_expr: { bool_value: true } } + args { + call_expr: { + function: "cel.@block" + args { + list_expr { elements { const_expr: { int64_value: 1 } } } + } + args { ident_expr: { name: "@index0" } } + } + } + } + } + )pb", + &parsed_expr)); + + cel::RuntimeOptions options; + options.max_recursion_depth = 2; + options.fail_on_warnings = false; + CelExpressionBuilderFlatImpl builder(NewTestingRuntimeEnv(), options); + EXPECT_THAT( + builder.CreateExpression(&parsed_expr.expr(), &parsed_expr.source_info()), + StatusIs(absl::StatusCode::kInternal, + HasSubstr("failed to build recursive program"))); +} + } // namespace } // namespace google::api::expr::runtime diff --git a/parser/options.h b/parser/options.h index 916a941f0..719bed454 100644 --- a/parser/options.h +++ b/parser/options.h @@ -62,6 +62,10 @@ struct ParserOptions final { // Limited to field specifiers in select and message creation, // enabled by default bool enable_quoted_identifiers = true; + + // Enables parsing logical AND & OR operators as a single flat variadic call + // instead of a balanced/nested binary AST structure. + bool enable_variadic_logical_operators = false; }; } // namespace cel diff --git a/parser/parser.cc b/parser/parser.cc index 709e2fd41..6c6434319 100644 --- a/parser/parser.cc +++ b/parser/parser.cc @@ -552,7 +552,7 @@ class ExpressionBalancer final { // balance creates a balanced tree from the sub-terms and returns the final // Expr value. - Expr Balance(); + Expr Balance(bool enable_variadic = false); private: // balancedTree recursively balances the terms provided to a commutative @@ -577,10 +577,13 @@ void ExpressionBalancer::AddTerm(int64_t op, Expr term) { ops_.push_back(op); } -Expr ExpressionBalancer::Balance() { +Expr ExpressionBalancer::Balance(bool enable_variadic) { if (terms_.size() == 1) { return std::move(terms_[0]); } + if (enable_variadic) { + return factory_.NewCall(ops_[0], function_, std::move(terms_)); + } return BalancedTree(0, ops_.size() - 1); } @@ -620,7 +623,8 @@ class ParserVisitor final : public CelBaseVisitor, const cel::MacroRegistry& macro_registry, bool add_macro_calls = false, bool enable_optional_syntax = false, - bool enable_quoted_identifiers = false) + bool enable_quoted_identifiers = false, + bool enable_variadic_logical_operators = false) : source_(source), factory_(source_), macro_registry_(macro_registry), @@ -628,7 +632,8 @@ class ParserVisitor final : public CelBaseVisitor, max_recursion_depth_(max_recursion_depth), add_macro_calls_(add_macro_calls), enable_optional_syntax_(enable_optional_syntax), - enable_quoted_identifiers_(enable_quoted_identifiers) {} + enable_quoted_identifiers_(enable_quoted_identifiers), + enable_variadic_logical_operators_(enable_variadic_logical_operators) {} ~ParserVisitor() override = default; @@ -719,6 +724,7 @@ class ParserVisitor final : public CelBaseVisitor, const bool add_macro_calls_; const bool enable_optional_syntax_; const bool enable_quoted_identifiers_; + const bool enable_variadic_logical_operators_; }; template ParseImpl( ExprRecursionListener listener(options.max_recursion_depth); ParserVisitor visitor( source, options.max_recursion_depth, registry, options.add_macro_calls, - options.enable_optional_syntax, options.enable_quoted_identifiers); + options.enable_optional_syntax, options.enable_quoted_identifiers, + options.enable_variadic_logical_operators); lexer.removeErrorListeners(); parser.removeErrorListeners(); diff --git a/parser/parser_test.cc b/parser/parser_test.cc index 587b63a30..33c52b1d2 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -1782,6 +1782,59 @@ TEST(NewParserBuilderTest, ToBuilderPreservesStdlibAndOptionalFromOptions) { EXPECT_FALSE(ast->IsChecked()); } +struct VariadicLogicalOperatorsTestCase { + std::string input; + std::string expected_adorned_string; +}; + +class VariadicLogicalOperatorsTest + : public testing::TestWithParam {}; + +TEST_P(VariadicLogicalOperatorsTest, Parse) { + const auto& test_case = GetParam(); + auto builder = cel::NewParserBuilder(); + builder->GetOptions().enable_variadic_logical_operators = true; + ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); + ASSERT_OK_AND_ASSIGN(auto source, cel::NewSource(test_case.input)); + ASSERT_OK_AND_ASSIGN(auto ast, parser->Parse(*source)); + + KindAndIdAdorner kind_and_id_adorner; + ExprPrinter w(kind_and_id_adorner); + std::string adorned_string = w.Print(ast->root_expr()); + EXPECT_EQ(adorned_string, test_case.expected_adorned_string); +} + +INSTANTIATE_TEST_SUITE_P( + VariadicLogicalOperators, VariadicLogicalOperatorsTest, + testing::Values( + VariadicLogicalOperatorsTestCase{ + .input = "a && b && c && d", + .expected_adorned_string = "_&&_(\n" + " a^#1:Expr.Ident#,\n" + " b^#2:Expr.Ident#,\n" + " c^#4:Expr.Ident#,\n" + " d^#6:Expr.Ident#\n" + ")^#3:Expr.Call#"}, + VariadicLogicalOperatorsTestCase{ + .input = "a || b || c || d", + .expected_adorned_string = "_||_(\n" + " a^#1:Expr.Ident#,\n" + " b^#2:Expr.Ident#,\n" + " c^#4:Expr.Ident#,\n" + " d^#6:Expr.Ident#\n" + ")^#3:Expr.Call#"}, + VariadicLogicalOperatorsTestCase{ + .input = "a && b && (c || d || e)", + .expected_adorned_string = "_&&_(\n" + " a^#1:Expr.Ident#,\n" + " b^#2:Expr.Ident#,\n" + " _||_(\n" + " c^#4:Expr.Ident#,\n" + " d^#5:Expr.Ident#,\n" + " e^#7:Expr.Ident#\n" + " )^#6:Expr.Call#\n" + ")^#3:Expr.Call#"})); + TEST(ParserTest, ParseFailurePopulatesIssues) { auto builder = cel::NewParserBuilder(); ASSERT_OK_AND_ASSIGN(auto parser, std::move(*builder).Build()); diff --git a/tools/cel_unparser.cc b/tools/cel_unparser.cc index 28a1187bb..741d91208 100644 --- a/tools/cel_unparser.cc +++ b/tools/cel_unparser.cc @@ -150,6 +150,8 @@ class Unparser { // - a ternary conditional operator bool IsBinaryOrTernaryOperator(const Expr& expr); + bool IsLogicalOperator(absl::string_view op); + template void Print(Ts&&... args) { absl::StrAppend(&output_, std::forward(args)...); @@ -436,6 +438,24 @@ absl::Status Unparser::VisitUnary(const Expr::Call& expr, absl::Status Unparser::VisitBinary(const Expr::Call& expr, const std::string& op) { + if (expr.args_size() < 2) { + return absl::InvalidArgumentError( + absl::StrCat("Unexpected binary: ", expr.ShortDebugString())); + } + + const auto& fun = expr.function(); + if (IsLogicalOperator(fun)) { + for (int i = 0; i < expr.args_size(); ++i) { + if (i > 0) { + Print(kSpace, op, kSpace); + } + const auto& arg = expr.args(i); + bool arg_paren = IsComplexOperatorWithRespectTo(arg, fun); + CEL_RETURN_IF_ERROR(VisitMaybeNested(arg, arg_paren)); + } + return absl::OkStatus(); + } + if (expr.args_size() != 2) { return absl::InvalidArgumentError( absl::StrCat("Unexpected binary: ", expr.ShortDebugString())); @@ -443,7 +463,6 @@ absl::Status Unparser::VisitBinary(const Expr::Call& expr, const auto& lhs = expr.args(0); const auto& rhs = expr.args(1); - const auto& fun = expr.function(); // add parens if the current operator is lower precedence than the lhs expr // operator. @@ -549,6 +568,10 @@ bool Unparser::IsBinaryOrTernaryOperator(const Expr& expr) { IsOperatorSamePrecedence(CelOperator::CONDITIONAL, expr); } +bool Unparser::IsLogicalOperator(absl::string_view op) { + return op == CelOperator::LOGICAL_AND || op == CelOperator::LOGICAL_OR; +} + } // namespace absl::StatusOr Unparse(const Expr& expr, diff --git a/tools/cel_unparser_test.cc b/tools/cel_unparser_test.cc index 4cba4ce4d..aca6e91fd 100644 --- a/tools/cel_unparser_test.cc +++ b/tools/cel_unparser_test.cc @@ -67,6 +67,22 @@ INSTANTIATE_TEST_SUITE_P( {// Empty Expr error {"", absl::InvalidArgumentError("Unsupported Expr")}, + // Logical operators with too few arguments (single argument) + { + R"pb( + call_expr { + function: "_&&_" + args { const_expr { bool_value: true } } + })pb", + absl::InvalidArgumentError("Unexpected binary")}, + { + R"pb( + call_expr { + function: "_||_" + args { const_expr { bool_value: true } } + })pb", + absl::InvalidArgumentError("Unexpected binary")}, + // Constants {"const_expr{}", absl::InvalidArgumentError("Unsupported Constant")}, {"const_expr{bool_value: true}", "true"}, @@ -619,6 +635,7 @@ TEST_P(UnparserTestTextExpr, Test) { options.add_macro_calls = true; options.enable_optional_syntax = true; options.enable_quoted_identifiers = true; + options.enable_variadic_logical_operators = true; ASSERT_OK_AND_ASSIGN(ParsedExpr result, Parse(GetParam().expr, "unparser", options)); @@ -779,6 +796,8 @@ INSTANTIATE_TEST_SUITE_P( {"has(a.`b.c`)", ""}, {"a.`b/c`", ""}, {"a.?`b/c`", ""}, + {"a && b && c && d", ""}, + {"a || b || c || d", ""}, })); } // namespace From 09dba1a9187b58e9bd7a8c74f8b6e006d16cc99f Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Mon, 8 Jun 2026 14:25:30 -0700 Subject: [PATCH 52/87] Refactor type signature generation to use TypeSpec. Refactors the internal signature generation logic to operate on `cel::TypeSpec` instead of `cel::Type`. Adds a utility to convert `cel::Type` to `cel::TypeSpec` PiperOrigin-RevId: 928770349 --- common/BUILD | 1 + common/ast/metadata.h | 4 + common/internal/BUILD | 3 +- common/internal/signature.cc | 275 ++++++++++++---------- common/internal/signature.h | 23 +- common/internal/signature_test.cc | 367 ++++++++++++++++++------------ common/type_spec_resolver.cc | 143 +++++++++++- common/type_spec_resolver.h | 3 + common/type_spec_resolver_test.cc | 27 +++ 9 files changed, 571 insertions(+), 275 deletions(-) diff --git a/common/BUILD b/common/BUILD index a016d2cb5..01710329b 100644 --- a/common/BUILD +++ b/common/BUILD @@ -53,6 +53,7 @@ cc_library( deps = [ ":ast", ":type", + ":type_kind", "//internal:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/common/ast/metadata.h b/common/ast/metadata.h index 197790ff3..1a69b5b50 100644 --- a/common/ast/metadata.h +++ b/common/ast/metadata.h @@ -573,6 +573,10 @@ class TypeSpec { TypeSpecKind& mutable_type_kind() { return type_kind_; } + bool is_specified() const { + return !absl::holds_alternative(type_kind_); + } + bool has_dyn() const { return absl::holds_alternative(type_kind_); } diff --git a/common/internal/BUILD b/common/internal/BUILD index 48a8dfe8b..73cbf37e9 100644 --- a/common/internal/BUILD +++ b/common/internal/BUILD @@ -145,13 +145,11 @@ cc_library( deps = [ "//common:ast", "//common:type", - "//common:type_kind", "//common:type_spec_resolver", "//internal:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", ], @@ -165,6 +163,7 @@ cc_test( "//common:ast", "//common:type", "//common:type_kind", + "//common:type_spec_resolver", "//internal:testing", "//internal:testing_descriptor_pool", "@com_google_absl//absl/base:no_destructor", diff --git a/common/internal/signature.cc b/common/internal/signature.cc index 5c75225f9..fe315bb04 100644 --- a/common/internal/signature.cc +++ b/common/internal/signature.cc @@ -26,11 +26,9 @@ #include "absl/status/statusor.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" #include "absl/types/optional.h" #include "common/ast.h" #include "common/type.h" -#include "common/type_kind.h" #include "common/type_spec_resolver.h" #include "internal/status_macros.h" #include "google/protobuf/arena.h" @@ -64,125 +62,145 @@ void AppendEscaped(std::string* result, std::string_view str, bool escape_dot) { } } -absl::Status AppendTypeParameters(std::string* result, const Type& type); +absl::Status AppendTypeDesc(std::string* result, const TypeSpec& type_spec); -// Recursively appends a string representation of the given `type` to `result`. -// Type parameters are enclosed in angle brackets and separated by commas. -// -// Grammar: -// TypeDesc = NamespaceIdentifier [ "<" TypeList ">" ] ; -// NamespaceIdentifier = [ "." ] Identifier { "." Identifier } ; -// TypeList = TypeElem { "," TypeElem } ; -// TypeElem = TypeDesc | TypeParam -// TypeParam = "~" Alpha ; -// Identifier = ( Alpha | "_" ) { AlphaNumeric | "_" } ; -// (* Terminals *) -// Alpha = "a"..."z" | "A"..."Z" ; -// Digit = "0"..."9" ; -// AlphaNumeric = Alpha | Digit ; -// -// For compatibility, the implementation allows unexpected characters in -// type names and parameters and escapes them with a backslash. -absl::Status AppendTypeDesc(std::string* result, const Type& type) { - switch (type.kind()) { - case TypeKind::kNull: - absl::StrAppend(result, "null"); - break; - case TypeKind::kBool: - absl::StrAppend(result, "bool"); - break; - case TypeKind::kInt: - absl::StrAppend(result, "int"); - break; - case TypeKind::kUint: - absl::StrAppend(result, "uint"); - break; - case TypeKind::kDouble: - absl::StrAppend(result, "double"); - break; - case TypeKind::kString: - absl::StrAppend(result, "string"); - break; - case TypeKind::kBytes: - absl::StrAppend(result, "bytes"); - break; - case TypeKind::kDuration: - absl::StrAppend(result, "duration"); - break; - case TypeKind::kTimestamp: - absl::StrAppend(result, "timestamp"); - break; - case TypeKind::kAny: - absl::StrAppend(result, "any"); - break; - case TypeKind::kDyn: - absl::StrAppend(result, "dyn"); - break; - case TypeKind::kBoolWrapper: - absl::StrAppend(result, "bool_wrapper"); - break; - case TypeKind::kIntWrapper: - absl::StrAppend(result, "int_wrapper"); - break; - case TypeKind::kUintWrapper: - absl::StrAppend(result, "uint_wrapper"); - break; - case TypeKind::kDoubleWrapper: - absl::StrAppend(result, "double_wrapper"); - break; - case TypeKind::kStringWrapper: - absl::StrAppend(result, "string_wrapper"); - break; - case TypeKind::kBytesWrapper: - absl::StrAppend(result, "bytes_wrapper"); - break; - case TypeKind::kList: - absl::StrAppend(result, "list"); - CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); - break; - case TypeKind::kMap: - absl::StrAppend(result, "map"); - CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); - break; - case TypeKind::kFunction: - absl::StrAppend(result, "function"); - CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); - break; - case TypeKind::kType: - absl::StrAppend(result, "type"); - CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); - break; - case TypeKind::kTypeParam: - absl::StrAppend(result, "~"); - AppendEscaped(result, type.GetTypeParam().name(), /*escape_dot=*/true); - break; - case TypeKind::kOpaque: - AppendEscaped(result, type.name(), /*escape_dot=*/false); - CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); - break; - case TypeKind::kStruct: - AppendEscaped(result, type.name(), /*escape_dot=*/false); - CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); - break; - default: - return absl::InvalidArgumentError( - absl::StrFormat("Type kind: %s is not supported in CEL declarations", - type.DebugString())); +absl::Status AppendTypeSpecList(std::string* result, + const std::vector& params) { + if (!params.empty()) { + result->push_back('<'); + for (size_t i = 0; i < params.size(); ++i) { + CEL_RETURN_IF_ERROR(AppendTypeDesc(result, params[i])); + if (i < params.size() - 1) { + result->push_back(','); + } + } + result->push_back('>'); } return absl::OkStatus(); } -absl::Status AppendTypeParameters(std::string* result, const Type& type) { - const auto& parameters = type.GetParameters(); - if (!parameters.empty()) { - result->push_back('<'); - for (size_t i = 0; i < parameters.size(); ++i) { - CEL_RETURN_IF_ERROR(AppendTypeDesc(result, parameters[i])); - if (i < parameters.size() - 1) { - result->push_back(','); - } +absl::Status AppendTypeDesc(std::string* result, const TypeSpec& type_spec) { + if (type_spec.has_null()) { + absl::StrAppend(result, "null"); + } else if (type_spec.has_dyn()) { + absl::StrAppend(result, "dyn"); + } else if (type_spec.has_primitive()) { + switch (type_spec.primitive()) { + case PrimitiveType::kBool: + absl::StrAppend(result, "bool"); + break; + case PrimitiveType::kInt64: + absl::StrAppend(result, "int"); + break; + case PrimitiveType::kUint64: + absl::StrAppend(result, "uint"); + break; + case PrimitiveType::kDouble: + absl::StrAppend(result, "double"); + break; + case PrimitiveType::kString: + absl::StrAppend(result, "string"); + break; + case PrimitiveType::kBytes: + absl::StrAppend(result, "bytes"); + break; + default: + return absl::InvalidArgumentError("Unsupported primitive type"); + } + } else if (type_spec.has_well_known()) { + switch (type_spec.well_known()) { + case WellKnownTypeSpec::kAny: + absl::StrAppend(result, "any"); + break; + case WellKnownTypeSpec::kTimestamp: + absl::StrAppend(result, "timestamp"); + break; + case WellKnownTypeSpec::kDuration: + absl::StrAppend(result, "duration"); + break; + default: + return absl::InvalidArgumentError("Unsupported well-known type"); } + } else if (type_spec.has_wrapper()) { + switch (type_spec.wrapper()) { + case PrimitiveType::kBool: + absl::StrAppend(result, "bool_wrapper"); + break; + case PrimitiveType::kInt64: + absl::StrAppend(result, "int_wrapper"); + break; + case PrimitiveType::kUint64: + absl::StrAppend(result, "uint_wrapper"); + break; + case PrimitiveType::kDouble: + absl::StrAppend(result, "double_wrapper"); + break; + case PrimitiveType::kString: + absl::StrAppend(result, "string_wrapper"); + break; + case PrimitiveType::kBytes: + absl::StrAppend(result, "bytes_wrapper"); + break; + default: + return absl::InvalidArgumentError("Unsupported wrapper type"); + } + } else if (type_spec.has_list_type()) { + absl::StrAppend(result, "list<"); + if (type_spec.list_type().elem_type().is_specified()) { + CEL_RETURN_IF_ERROR( + AppendTypeDesc(result, type_spec.list_type().elem_type())); + } else { + absl::StrAppend(result, "dyn"); + } + result->push_back('>'); + } else if (type_spec.has_map_type()) { + absl::StrAppend(result, "map<"); + if (type_spec.map_type().key_type().is_specified()) { + CEL_RETURN_IF_ERROR( + AppendTypeDesc(result, type_spec.map_type().key_type())); + } else { + absl::StrAppend(result, "dyn"); + } + result->push_back(','); + if (type_spec.map_type().value_type().is_specified()) { + CEL_RETURN_IF_ERROR( + AppendTypeDesc(result, type_spec.map_type().value_type())); + } else { + absl::StrAppend(result, "dyn"); + } + result->push_back('>'); + } else if (type_spec.has_function()) { + absl::StrAppend(result, "function<"); + if (type_spec.function().result_type().is_specified()) { + CEL_RETURN_IF_ERROR( + AppendTypeDesc(result, type_spec.function().result_type())); + } else { + absl::StrAppend(result, "dyn"); + } + for (const auto& arg : type_spec.function().arg_types()) { + result->push_back(','); + CEL_RETURN_IF_ERROR(AppendTypeDesc(result, arg)); + } + result->push_back('>'); + } else if (type_spec.has_type()) { + absl::StrAppend(result, "type"); + result->push_back('<'); + CEL_RETURN_IF_ERROR(AppendTypeDesc(result, type_spec.type())); result->push_back('>'); + } else if (type_spec.has_type_param()) { + absl::StrAppend(result, "~"); + AppendEscaped(result, type_spec.type_param().type(), /*escape_dot=*/true); + } else if (type_spec.has_abstract_type()) { + AppendEscaped(result, type_spec.abstract_type().name(), + /*escape_dot=*/false); + CEL_RETURN_IF_ERROR(AppendTypeSpecList( + result, type_spec.abstract_type().parameter_types())); + } else if (type_spec.has_message_type()) { + AppendEscaped(result, type_spec.message_type().type(), + /*escape_dot=*/false); + } else { + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported type in signature: ", FormatTypeSpec(type_spec))); } return absl::OkStatus(); } @@ -190,13 +208,32 @@ absl::Status AppendTypeParameters(std::string* result, const Type& type) { absl::StatusOr MakeTypeSignature(const Type& type) { std::string result; - CEL_RETURN_IF_ERROR(AppendTypeDesc(&result, type)); + CEL_ASSIGN_OR_RETURN(TypeSpec type_spec, ConvertTypeToTypeSpec(type)); + CEL_RETURN_IF_ERROR(AppendTypeDesc(&result, type_spec)); + return result; +} + +absl::StatusOr MakeTypeSpecSignature(const TypeSpec& type_spec) { + std::string result; + CEL_RETURN_IF_ERROR(AppendTypeDesc(&result, type_spec)); return result; } absl::StatusOr MakeOverloadSignature( std::string_view function_name, const std::vector& args, bool is_member) { + std::vector arg_type_specs; + arg_type_specs.reserve(args.size()); + for (const auto& arg : args) { + CEL_ASSIGN_OR_RETURN(TypeSpec type_spec, ConvertTypeToTypeSpec(arg)); + arg_type_specs.push_back(type_spec); + } + return MakeOverloadSignature(function_name, arg_type_specs, is_member); +} + +absl::StatusOr MakeOverloadSignature( + std::string_view function_name, const std::vector& args, + bool is_member) { std::string result; if (is_member) { if (!args.empty()) { @@ -589,10 +626,14 @@ absl::StatusOr ParseFunctionSignature( return out; } +absl::StatusOr ParseTypeSpec(std::string_view signature) { + std::string stripped_sig = StripUnescapedWhitespace(signature); + return ParseTypeSignature(stripped_sig); +} + absl::StatusOr ParseType(std::string_view signature, google::protobuf::Arena* arena, const google::protobuf::DescriptorPool& pool) { - std::string stripped_sig = StripUnescapedWhitespace(signature); - CEL_ASSIGN_OR_RETURN(auto type_spec, ParseTypeSignature(stripped_sig)); + CEL_ASSIGN_OR_RETURN(auto type_spec, ParseTypeSpec(signature)); return cel::ConvertTypeSpecToType(type_spec, arena, pool); } diff --git a/common/internal/signature.h b/common/internal/signature.h index 3fdba4b2e..8a44fbd5c 100644 --- a/common/internal/signature.h +++ b/common/internal/signature.h @@ -27,7 +27,7 @@ namespace cel::common_internal { -// Generates an signature for a `cel::Type`, which is a string representation of +// Generates a signature for a `cel::Type`, which is a string representation of // the type. // // Examples: @@ -37,7 +37,17 @@ namespace cel::common_internal { // - `list>` absl::StatusOr MakeTypeSignature(const Type& type); -// Generates an identifier for a function overload based on the function name +// Generates a signature for a `cel::TypeSpec`, which is a string +// representation of the type. +// +// Examples: +// +// - `int` +// - `list` +// - `list>` +absl::StatusOr MakeTypeSpecSignature(const TypeSpec& type_spec); + +// Generates a signature for a function overload based on the function name // and the types of the arguments. If `is_member` is true, the first argument // type is used as the receiver and is prepended to the function name, followed // by a dollar sign. @@ -59,6 +69,15 @@ absl::StatusOr MakeOverloadSignature( std::string_view function_name, const std::vector& args, bool is_member); +// Generates a signature for a function overload based on the function name +// and the type specs of the arguments. See above for more details. +absl::StatusOr MakeOverloadSignature( + std::string_view function_name, const std::vector& args, + bool is_member); + +// Parses a string type signature directly into a `cel::TypeSpec`. +absl::StatusOr ParseTypeSpec(std::string_view signature); + // Parses a string type signature directly into a `cel::Type`. absl::StatusOr ParseType(std::string_view signature, google::protobuf::Arena* arena, const google::protobuf::DescriptorPool& pool); diff --git a/common/internal/signature_test.cc b/common/internal/signature_test.cc index 765055f75..17b628d88 100644 --- a/common/internal/signature_test.cc +++ b/common/internal/signature_test.cc @@ -14,6 +14,7 @@ // limitations under the License. #include +#include #include #include @@ -24,6 +25,7 @@ #include "common/ast.h" #include "common/type.h" #include "common/type_kind.h" +#include "common/type_spec_resolver.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "google/protobuf/arena.h" @@ -42,82 +44,9 @@ google::protobuf::Arena* GetTestArena() { return &*arena; } -void VerifyParsedMatchesType(const TypeSpec& parsed, const Type& original) { - switch (original.kind()) { - case TypeKind::kDyn: - EXPECT_TRUE(parsed.has_dyn()); - break; - case TypeKind::kNull: - EXPECT_TRUE(parsed.has_null()); - break; - case TypeKind::kBool: - EXPECT_EQ(parsed.primitive(), PrimitiveType::kBool); - break; - case TypeKind::kInt: - EXPECT_EQ(parsed.primitive(), PrimitiveType::kInt64); - break; - case TypeKind::kUint: - EXPECT_EQ(parsed.primitive(), PrimitiveType::kUint64); - break; - case TypeKind::kDouble: - EXPECT_EQ(parsed.primitive(), PrimitiveType::kDouble); - break; - case TypeKind::kString: - EXPECT_EQ(parsed.primitive(), PrimitiveType::kString); - break; - case TypeKind::kBytes: - EXPECT_EQ(parsed.primitive(), PrimitiveType::kBytes); - break; - case TypeKind::kAny: - EXPECT_EQ(parsed.well_known(), WellKnownTypeSpec::kAny); - break; - case TypeKind::kTimestamp: - EXPECT_EQ(parsed.well_known(), WellKnownTypeSpec::kTimestamp); - break; - case TypeKind::kDuration: - EXPECT_EQ(parsed.well_known(), WellKnownTypeSpec::kDuration); - break; - case TypeKind::kList: - EXPECT_TRUE(parsed.has_list_type()); - if (!original.GetParameters().empty()) { - VerifyParsedMatchesType(parsed.list_type().elem_type(), - original.GetParameters()[0]); - } - break; - case TypeKind::kMap: - EXPECT_TRUE(parsed.has_map_type()); - if (!original.GetParameters().empty()) { - VerifyParsedMatchesType(parsed.map_type().key_type(), - original.GetParameters()[0]); - } - if (original.GetParameters().size() > 1) { - VerifyParsedMatchesType(parsed.map_type().value_type(), - original.GetParameters()[1]); - } - break; - case TypeKind::kBoolWrapper: - case TypeKind::kIntWrapper: - case TypeKind::kUintWrapper: - case TypeKind::kDoubleWrapper: - case TypeKind::kStringWrapper: - case TypeKind::kBytesWrapper: - EXPECT_TRUE(parsed.has_wrapper()); - break; - case TypeKind::kType: - EXPECT_TRUE(parsed.has_type()); - if (!original.GetParameters().empty()) { - VerifyParsedMatchesType(parsed.type(), original.GetParameters()[0]); - } - break; - case TypeKind::kTypeParam: - EXPECT_TRUE(parsed.has_type_param()); - break; - default: - EXPECT_TRUE(parsed.has_abstract_type()); - break; - } +void VerifyParsedMatchesType(const TypeSpec& parsed, const TypeSpec& expected) { + EXPECT_EQ(parsed, expected); } - void VerifyTypesEqual(const Type& lhs, const Type& rhs) { EXPECT_EQ(lhs.kind(), rhs.kind()); if (lhs.kind() != rhs.kind()) return; @@ -138,7 +67,7 @@ void VerifyTypesEqual(const Type& lhs, const Type& rhs) { } struct TypeSignatureTestCase { - Type type; + TypeSpec type; std::string expected_signature; std::string expected_error; }; @@ -149,104 +78,208 @@ TEST_P(TypeSignatureTest, TypeSignature) { const auto& param = GetParam(); absl::StatusOr signature = - common_internal::MakeTypeSignature(param.type); + common_internal::MakeTypeSpecSignature(param.type); if (!param.expected_error.empty()) { EXPECT_THAT(signature, StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr(param.expected_error))); } else { EXPECT_THAT(signature, IsOkAndHolds(param.expected_signature)); + + absl::StatusOr type = ConvertTypeSpecToType( + param.type, GetTestArena(), *GetTestingDescriptorPool()); + ASSERT_THAT(type, ::absl_testing::IsOk()); + EXPECT_THAT(MakeTypeSignature(*type), + IsOkAndHolds(param.expected_signature)); } } std::vector GetTypeSignatureTestCases() { return { { - .type = StringType{}, - .expected_signature = "string", + .type = TypeSpec(NullTypeSpec{}), + .expected_signature = "null", + }, + { + .type = TypeSpec(PrimitiveType::kBool), + .expected_signature = "bool", }, { - .type = IntType{}, + .type = TypeSpec(PrimitiveType::kInt64), .expected_signature = "int", }, { - .type = ListType(GetTestArena(), StringType{}), - .expected_signature = "list", + .type = TypeSpec(PrimitiveType::kUint64), + .expected_signature = "uint", }, { - .type = TypeType(GetTestArena(), IntType{}), - .expected_signature = "type", + .type = TypeSpec(PrimitiveType::kDouble), + .expected_signature = "double", + }, + { + .type = TypeSpec(PrimitiveType::kString), + .expected_signature = "string", }, { - .type = ListType(GetTestArena(), TypeParamType("A")), + .type = TypeSpec(PrimitiveType::kBytes), + .expected_signature = "bytes", + }, + { + .type = TypeSpec( + MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes")), + .expected_signature = "cel.expr.conformance.proto3.TestAllTypes", + }, + { + .type = TypeSpec( + AbstractType("cel.expr.conformance.proto3.TestAllTypes", {})), + .expected_signature = "cel.expr.conformance.proto3.TestAllTypes", + }, + { + .type = TypeSpec(WellKnownTypeSpec::kDuration), + .expected_signature = "duration", + }, + { + .type = TypeSpec(WellKnownTypeSpec::kTimestamp), + .expected_signature = "timestamp", + }, + { + .type = TypeSpec( + ListTypeSpec(std::make_unique(PrimitiveType::kString))), + .expected_signature = "list", + }, + { + .type = TypeSpec( + ListTypeSpec(std::make_unique(ParamTypeSpec("A")))), .expected_signature = "list<~A>", }, { - .type = ListType(GetTestArena(), TypeParamType("A(ParamTypeSpec("A(ParamTypeSpec(R"(a,b..(d)\e)")))), + .expected_signature = R"(list<~a\,b\.\\.\(d\)\\e>)", + }, + { + .type = TypeSpec( + MapTypeSpec(std::make_unique(PrimitiveType::kInt64), + std::make_unique(DynTypeSpec()))), .expected_signature = "map", }, { - .type = - MapType(GetTestArena(), TypeParamType("B"), TypeParamType("C")), + .type = TypeSpec( + MapTypeSpec(std::make_unique(ParamTypeSpec("B")), + std::make_unique(ParamTypeSpec("C")))), .expected_signature = "map<~B,~C>", }, { - .type = OpaqueType(GetTestArena(), "bar", - {FunctionType(GetTestArena(), TypeParamType("D"), - {StringType{}, BoolType{}})}), - .expected_signature = "bar>", + .type = TypeSpec(MapTypeSpec( + std::make_unique(PrimitiveType::kInt64), nullptr)), + .expected_signature = "map", + }, + { + .type = TypeSpec(MapTypeSpec(nullptr, nullptr)), + .expected_signature = "map", + }, + { + .type = TypeSpec(std::make_unique(PrimitiveType::kInt64)), + .expected_signature = "type", }, { - .type = AnyType{}, + .type = TypeSpec(WellKnownTypeSpec::kAny), .expected_signature = "any", }, { - .type = DurationType{}, - .expected_signature = "duration", + .type = TypeSpec(DynTypeSpec{}), + .expected_signature = "dyn", }, { - .type = TimestampType{}, - .expected_signature = "timestamp", + .type = TypeSpec(AbstractType( + "bar", {TypeSpec(FunctionTypeSpec( + std::make_unique(ParamTypeSpec("D")), + {TypeSpec(PrimitiveType::kString), + TypeSpec(PrimitiveType::kBool)}))})), + .expected_signature = "bar>", + }, + { + .type = + TypeSpec(AbstractType("bar", {TypeSpec(PrimitiveType::kInt64), + TypeSpec(PrimitiveType::kString)})), + .expected_signature = "bar", }, { - .type = BoolWrapperType{}, + .type = TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool)), .expected_signature = "bool_wrapper", }, { - .type = IntWrapperType{}, + .type = TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64)), .expected_signature = "int_wrapper", }, { - .type = UintWrapperType{}, + .type = TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64)), .expected_signature = "uint_wrapper", }, { - .type = MessageType(GetTestingDescriptorPool()->FindMessageTypeByName( - "cel.expr.conformance.proto3.TestAllTypes")), - .expected_signature = "cel.expr.conformance.proto3.TestAllTypes", + .type = TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)), + .expected_signature = "double_wrapper", }, { - .type = ListType(GetTestArena(), TypeParamType(R"(a,b..(d)\e)")), - .expected_signature = R"(list<~a\,b\.\\.\(d\)\\e>)", + .type = TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kString)), + .expected_signature = "string_wrapper", + }, + { + .type = TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBytes)), + .expected_signature = "bytes_wrapper", + }, + { + .type = TypeSpec( + FunctionTypeSpec(nullptr, {TypeSpec(PrimitiveType::kInt64)})), + .expected_signature = "function", + }, + { + .type = TypeSpec(FunctionTypeSpec( + std::make_unique(PrimitiveType::kInt64), {})), + .expected_signature = "function", + }, + { + .type = TypeSpec(FunctionTypeSpec(nullptr, {})), + .expected_signature = "function", }, }; } +INSTANTIATE_TEST_SUITE_P(TypeSignatureTest, TypeSignatureTest, + ValuesIn(GetTypeSignatureTestCases())); + TEST(TypeSignatureTest, UnsupportedTypes) { EXPECT_THAT(common_internal::MakeTypeSignature(UnknownType{}), StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Type kind: *unknown* is not supported"))); + HasSubstr("Unsupported Type kind: *unknown*"))); EXPECT_THAT(common_internal::MakeTypeSignature(ErrorType{}), StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Type kind: *error* is not supported"))); -} + HasSubstr("Unsupported type in signature: *error*"))); -INSTANTIATE_TEST_SUITE_P(TypeIdTest, TypeSignatureTest, - ValuesIn(GetTypeSignatureTestCases())); + EXPECT_THAT(common_internal::MakeTypeSpecSignature( + TypeSpec(static_cast(999))), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Unsupported primitive type"))); + + EXPECT_THAT(common_internal::MakeTypeSpecSignature( + TypeSpec(static_cast(999))), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Unsupported well-known type"))); + + EXPECT_THAT(common_internal::MakeTypeSpecSignature(TypeSpec( + PrimitiveTypeWrapper(static_cast(999)))), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Unsupported wrapper type"))); +} TEST_P(TypeSignatureTest, ParseTypeCheck) { const auto& param = GetParam(); @@ -254,13 +287,16 @@ TEST_P(TypeSignatureTest, ParseTypeCheck) { auto parsed = ParseType(param.expected_signature, GetTestArena(), *GetTestingDescriptorPool()); ASSERT_THAT(parsed, ::absl_testing::IsOk()); - VerifyTypesEqual(*parsed, param.type); + ASSERT_OK_AND_ASSIGN(auto expected_type, + ConvertTypeSpecToType(param.type, GetTestArena(), + *GetTestingDescriptorPool())); + VerifyTypesEqual(*parsed, expected_type); } } struct OverloadSignatureTestCase { std::string function_name = "hello"; - std::vector args; + std::vector args; bool is_member = false; std::string expected_signature; std::string expected_error; @@ -285,98 +321,110 @@ TEST_P(OverloadSignatureTest, OverloadSignature) { std::vector GetOverloadSignatureTestCases() { return { { - .args = {StringType{}}, + .args = {TypeSpec(PrimitiveType::kString)}, .expected_signature = "hello(string)", }, { - .args = {IntType{}, UintType{}}, + .args = {TypeSpec(PrimitiveType::kInt64), + TypeSpec(PrimitiveType::kUint64)}, .expected_signature = "hello(int,uint)", }, { - .args = {ListType(GetTestArena(), StringType{})}, + .args = {TypeSpec(ListTypeSpec( + std::make_unique(PrimitiveType::kString)))}, .expected_signature = "hello(list)", }, { - .args = {ListType(GetTestArena(), TypeParamType("A"))}, + .args = {TypeSpec( + ListTypeSpec(std::make_unique(ParamTypeSpec("A"))))}, .expected_signature = "hello(list<~A>)", }, { - .args = {MapType(GetTestArena(), IntType{}, DynType{})}, + .args = {TypeSpec( + MapTypeSpec(std::make_unique(PrimitiveType::kInt64), + std::make_unique(DynTypeSpec())))}, .expected_signature = "hello(map)", }, { - .args = {MapType(GetTestArena(), TypeParamType("B"), - TypeParamType("C"))}, + .args = {TypeSpec( + MapTypeSpec(std::make_unique(ParamTypeSpec("B")), + std::make_unique(ParamTypeSpec("C"))))}, .expected_signature = "hello(map<~B,~C>)", }, + { - .args = {OpaqueType( - GetTestArena(), "bar", - {FunctionType(GetTestArena(), TypeParamType("D"), {})})}, + .args = {TypeSpec(AbstractType( + "bar", + {TypeSpec(FunctionTypeSpec( + std::make_unique(ParamTypeSpec("D")), {}))}))}, .expected_signature = "hello(bar>)", }, { - .args = {AnyType{}}, + .args = {TypeSpec(WellKnownTypeSpec::kAny)}, .expected_signature = "hello(any)", }, { - .args = {DurationType{}}, + .args = {TypeSpec(WellKnownTypeSpec::kDuration)}, .expected_signature = "hello(duration)", }, { - .args = {TimestampType{}}, + .args = {TypeSpec(WellKnownTypeSpec::kTimestamp)}, .expected_signature = "hello(timestamp)", }, { - .args = {BoolWrapperType{}}, + .args = {TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool))}, .expected_signature = "hello(bool_wrapper)", }, { - .args = {IntWrapperType{}}, + .args = {TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64))}, .expected_signature = "hello(int_wrapper)", }, { - .args = {UintWrapperType{}}, + .args = {TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64))}, .expected_signature = "hello(uint_wrapper)", }, { - .args = {MessageType( - GetTestingDescriptorPool()->FindMessageTypeByName( - "cel.expr.conformance.proto3.TestAllTypes"))}, + .args = {TypeSpec( + AbstractType("cel.expr.conformance.proto3.TestAllTypes", {}))}, .expected_signature = "hello(cel.expr.conformance.proto3.TestAllTypes)", }, { - .args = {StringType{}}, + .args = {TypeSpec(PrimitiveType::kString)}, .is_member = true, .expected_signature = "string.hello()", }, { - .args = {StringType{}, ListType(GetTestArena(), BoolType{})}, + .args = {TypeSpec(PrimitiveType::kString), + TypeSpec(ListTypeSpec( + std::make_unique(PrimitiveType::kBool)))}, .is_member = true, .expected_signature = "string.hello(list)", }, { - .args = {StringType{}, BoolType{}, DynType{}}, + .args = {TypeSpec(PrimitiveType::kString), + TypeSpec(PrimitiveType::kBool), TypeSpec(DynTypeSpec())}, .is_member = true, .expected_signature = "string.hello(bool,dyn)", }, { .function_name = "hello", - .args = {OpaqueType(GetTestArena(), "bar", - {TypeParamType("dummy.type")})}, + .args = {TypeSpec( + AbstractType("bar", {TypeSpec(ParamTypeSpec("dummy.type"))}))}, .is_member = true, .expected_signature = R"(bar<~dummy\.type>.hello())", }, { .function_name = "inspect", - .args = {Type(TypeType(GetTestArena(), StringType{}))}, + .args = {TypeSpec( + std::make_unique(PrimitiveType::kString))}, .expected_signature = "inspect(type)", }, { .function_name = R"(h.(e),l\o)", - .args = {StringType{}, - ListType(GetTestArena(), TypeParamType(R"(a,b..(d)\e)"))}, + .args = {TypeSpec(PrimitiveType::kString), + TypeSpec(ListTypeSpec(std::make_unique( + ParamTypeSpec(R"(a,b..(d)\e)"))))}, .is_member = true, .expected_signature = R"(string.h\.\(e\)\,l\\\o(list<~a\,b\.\\.\(d\)\\e>))", @@ -385,7 +433,8 @@ std::vector GetOverloadSignatureTestCases() { } TEST(OverloadSignatureTest, MemberFunctionNoReceiverError) { - auto signature = common_internal::MakeOverloadSignature("hello", {}, true); + auto signature = common_internal::MakeOverloadSignature( + "hello", std::vector{}, true); EXPECT_THAT(signature, StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Member function with no receiver"))); @@ -564,8 +613,15 @@ TEST(ParseSignatureTest, ParsingErrors) { EXPECT_THAT(ParseFunctionSignature("foo"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid function signature"))); + EXPECT_THAT( + ParseType("list b < c>", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("mismatched brackets"))); - // Parameter count validations for list and map types. + // Parameter count validations for list, map and type types. EXPECT_THAT(ParseType("list", GetTestArena(), *GetTestingDescriptorPool()), StatusIs(absl::StatusCode::kInvalidArgument, @@ -578,6 +634,18 @@ TEST(ParseSignatureTest, ParsingErrors) { *GetTestingDescriptorPool()), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("map expects 0 or 2 parameters"))); + EXPECT_THAT(ParseType("type", GetTestArena(), + *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("type expects at most 1 parameter"))); + + // Invalid parameter name validations. + EXPECT_THAT(ParseType("~", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("invalid type parameter name"))); + EXPECT_THAT(ParseType("~A", GetTestArena(), *GetTestingDescriptorPool()), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("invalid type parameter name"))); // Enforcing valid function and identifier names. EXPECT_THAT(ParseFunctionSignature("()"), @@ -700,5 +768,20 @@ TEST(ParseSignatureTest, EmptyOrWhitespaceErrors) { HasSubstr("Empty type signature"))); } +TEST(OverloadSignatureTest, ArgumentTypeVector) { + std::vector args; + args.push_back(Type(IntType())); + args.push_back(Type(StringType())); + args.push_back(Type(ListType(GetTestArena(), IntType()))); + args.push_back( + Type(MessageType(GetTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes")))); + args.push_back(Type(OpaqueType(GetTestArena(), "Foo", {TypeParamType("T")}))); + ASSERT_OK_AND_ASSIGN(auto sig, MakeOverloadSignature("foo", args, false)); + EXPECT_EQ(sig, + "foo(int,string,list,cel.expr.conformance.proto3.TestAllTypes," + "Foo<~T>)"); +} + } // namespace } // namespace cel::common_internal diff --git a/common/type_spec_resolver.cc b/common/type_spec_resolver.cc index 97451f390..90c9930a8 100644 --- a/common/type_spec_resolver.cc +++ b/common/type_spec_resolver.cc @@ -14,6 +14,7 @@ #include "common/type_spec_resolver.h" +#include #include #include #include @@ -22,8 +23,12 @@ #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "common/ast.h" #include "common/type.h" +#include "common/type_kind.h" #include "internal/status_macros.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" namespace cel { @@ -85,28 +90,42 @@ absl::StatusOr ConvertTypeSpecToType(const TypeSpec& type_spec, } if (type_spec.has_list_type()) { - CEL_ASSIGN_OR_RETURN( - auto elem_type, - ConvertTypeSpecToType(type_spec.list_type().elem_type(), arena, pool)); + Type elem_type; + if (type_spec.list_type().elem_type().is_specified()) { + CEL_ASSIGN_OR_RETURN( + elem_type, ConvertTypeSpecToType(type_spec.list_type().elem_type(), + arena, pool)); + } return Type(ListType(arena, elem_type)); } if (type_spec.has_map_type()) { - CEL_ASSIGN_OR_RETURN( - auto key_type, - ConvertTypeSpecToType(type_spec.map_type().key_type(), arena, pool)); - CEL_ASSIGN_OR_RETURN( - auto value_type, - ConvertTypeSpecToType(type_spec.map_type().value_type(), arena, pool)); + Type key_type; + if (type_spec.map_type().key_type().is_specified()) { + CEL_ASSIGN_OR_RETURN( + key_type, + ConvertTypeSpecToType(type_spec.map_type().key_type(), arena, pool)); + } + + Type value_type; + if (type_spec.map_type().value_type().is_specified()) { + CEL_ASSIGN_OR_RETURN( + value_type, ConvertTypeSpecToType(type_spec.map_type().value_type(), + arena, pool)); + } return Type(MapType(arena, key_type, value_type)); } if (type_spec.has_function()) { const auto& func_spec = type_spec.function(); - CEL_ASSIGN_OR_RETURN( - auto result_type, - ConvertTypeSpecToType(func_spec.result_type(), arena, pool)); + Type result_type; + if (func_spec.result_type().is_specified()) { + CEL_ASSIGN_OR_RETURN( + result_type, + ConvertTypeSpecToType(func_spec.result_type(), arena, pool)); + } std::vector arg_types; + arg_types.reserve(func_spec.arg_types().size()); for (const auto& arg_spec : func_spec.arg_types()) { CEL_ASSIGN_OR_RETURN(auto arg_type, ConvertTypeSpecToType(arg_spec, arena, pool)); @@ -179,4 +198,104 @@ absl::StatusOr ConvertTypeSpecToType(const TypeSpec& type_spec, return absl::InvalidArgumentError("Unknown TypeSpec kind"); } +absl::StatusOr ConvertTypeToTypeSpec(const Type& type) { + switch (type.kind()) { + case TypeKind::kNull: + return TypeSpec(NullTypeSpec{}); + case TypeKind::kDyn: + return TypeSpec(DynTypeSpec{}); + case TypeKind::kBool: + return TypeSpec(PrimitiveType::kBool); + case TypeKind::kInt: + return TypeSpec(PrimitiveType::kInt64); + case TypeKind::kUint: + return TypeSpec(PrimitiveType::kUint64); + case TypeKind::kDouble: + return TypeSpec(PrimitiveType::kDouble); + case TypeKind::kString: + return TypeSpec(PrimitiveType::kString); + case TypeKind::kBytes: + return TypeSpec(PrimitiveType::kBytes); + case TypeKind::kAny: + return TypeSpec(WellKnownTypeSpec::kAny); + case TypeKind::kTimestamp: + return TypeSpec(WellKnownTypeSpec::kTimestamp); + case TypeKind::kDuration: + return TypeSpec(WellKnownTypeSpec::kDuration); + case TypeKind::kBoolWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool)); + case TypeKind::kIntWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64)); + case TypeKind::kUintWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64)); + case TypeKind::kDoubleWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)); + case TypeKind::kStringWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kString)); + case TypeKind::kBytesWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBytes)); + case TypeKind::kList: { + CEL_ASSIGN_OR_RETURN(auto elem_type, + ConvertTypeToTypeSpec(type.GetList().element())); + return TypeSpec( + ListTypeSpec(std::make_unique(std::move(elem_type)))); + } + case TypeKind::kMap: { + CEL_ASSIGN_OR_RETURN(auto key_type, + ConvertTypeToTypeSpec(type.GetMap().key())); + CEL_ASSIGN_OR_RETURN(auto value_type, + ConvertTypeToTypeSpec(type.GetMap().value())); + return TypeSpec( + MapTypeSpec(std::make_unique(std::move(key_type)), + std::make_unique(std::move(value_type)))); + } + case TypeKind::kFunction: { + auto func_type = type.GetFunction(); + CEL_ASSIGN_OR_RETURN(auto result_type, + ConvertTypeToTypeSpec(func_type.result())); + std::vector arg_types; + arg_types.reserve(func_type.args().size()); + for (const auto& arg : func_type.args()) { + CEL_ASSIGN_OR_RETURN(auto arg_type, ConvertTypeToTypeSpec(arg)); + arg_types.push_back(std::move(arg_type)); + } + return TypeSpec( + FunctionTypeSpec(std::make_unique(std::move(result_type)), + std::move(arg_types))); + } + case TypeKind::kTypeParam: + return TypeSpec(ParamTypeSpec(std::string(type.GetTypeParam().name()))); + case TypeKind::kStruct: { + if (type.IsMessage()) { + return TypeSpec(MessageTypeSpec(std::string(type.GetMessage().name()))); + } + return absl::InvalidArgumentError("Unsupported struct type"); + } + case TypeKind::kOpaque: { + auto opaque_type = type.GetOpaque(); + std::vector params; + params.reserve(opaque_type.GetParameters().size()); + for (const auto& param : opaque_type.GetParameters()) { + CEL_ASSIGN_OR_RETURN(auto param_type, ConvertTypeToTypeSpec(param)); + params.push_back(std::move(param_type)); + } + return TypeSpec( + AbstractType(std::string(opaque_type.name()), std::move(params))); + } + case TypeKind::kType: { + CEL_ASSIGN_OR_RETURN(auto nested_type, + ConvertTypeToTypeSpec(type.GetType().GetType())); + return TypeSpec(std::make_unique(std::move(nested_type))); + } + case TypeKind::kError: + return TypeSpec(ErrorTypeSpec::kValue); + case TypeKind::kEnum: + return TypeSpec( + AbstractType(std::string(type.GetEnum().name()), /*params=*/{})); + default: + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported Type kind: ", TypeKindToString(type.kind()))); + } +} + } // namespace cel diff --git a/common/type_spec_resolver.h b/common/type_spec_resolver.h index 44e1e088f..edbfa3bde 100644 --- a/common/type_spec_resolver.h +++ b/common/type_spec_resolver.h @@ -32,6 +32,9 @@ absl::StatusOr ConvertTypeSpecToType(const TypeSpec& type_spec, google::protobuf::Arena* arena, const google::protobuf::DescriptorPool& pool); +// Resolves a `cel::Type` to a `cel::TypeSpec`. +absl::StatusOr ConvertTypeToTypeSpec(const Type& type); + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPE_SPEC_RESOLVER_H_ diff --git a/common/type_spec_resolver_test.cc b/common/type_spec_resolver_test.cc index c7fbb2cf8..1cda7280f 100644 --- a/common/type_spec_resolver_test.cc +++ b/common/type_spec_resolver_test.cc @@ -23,6 +23,7 @@ #include "absl/base/no_destructor.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" +#include "common/ast.h" #include "common/type.h" #include "common/type_kind.h" #include "internal/testing.h" @@ -33,6 +34,7 @@ namespace cel { namespace { using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; using ::absl_testing::StatusIs; using ::cel::internal::GetTestingDescriptorPool; using ::testing::HasSubstr; @@ -67,6 +69,7 @@ TEST_P(ConversionTest, TestTypeSpecConversion) { auto t, ConvertTypeSpecToType(std::get<0>(GetParam()), GetTestArena(), *GetTestingDescriptorPool())); EXPECT_EQ(t.kind(), std::get<1>(GetParam())); + EXPECT_THAT(ConvertTypeToTypeSpec(t), IsOkAndHolds(std::get<0>(GetParam()))); } INSTANTIATE_TEST_SUITE_P( @@ -104,6 +107,8 @@ TEST(TypeSpecResolverTest, ListTypeConversion) { ASSERT_THAT(t, IsOk()); EXPECT_TRUE(t->IsList()); EXPECT_TRUE(t->GetList().element().IsInt()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, MapTypeConversion) { @@ -116,6 +121,8 @@ TEST(TypeSpecResolverTest, MapTypeConversion) { EXPECT_TRUE(t->IsMap()); EXPECT_TRUE(t->GetMap().key().IsString()); EXPECT_TRUE(t->GetMap().value().IsBytes()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, FunctionTypeConversion) { @@ -129,6 +136,8 @@ TEST(TypeSpecResolverTest, FunctionTypeConversion) { EXPECT_TRUE(t->IsFunction()); EXPECT_EQ(t->GetFunction().args().size(), 1); EXPECT_TRUE(t->GetFunction().result().IsBool()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, TypeParamConversion) { @@ -138,6 +147,8 @@ TEST(TypeSpecResolverTest, TypeParamConversion) { ASSERT_THAT(t, IsOk()); EXPECT_TRUE(t->IsTypeParam()); EXPECT_EQ(t->GetTypeParam().name(), "T"); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, MessageTypeConversion) { @@ -148,6 +159,10 @@ TEST(TypeSpecResolverTest, MessageTypeConversion) { ASSERT_THAT(t, IsOk()); EXPECT_TRUE(t->IsMessage()); EXPECT_EQ(t->name(), "cel.expr.conformance.proto3.TestAllTypes"); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ( + spec2, + TypeSpec(MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes"))); } TEST(TypeSpecResolverTest, MessageTypeWithParamsError) { @@ -172,6 +187,8 @@ TEST(TypeSpecResolverTest, UnresolvedAbstractTypeFallbackToOpaque) { EXPECT_EQ(t->name(), "my.custom.OpaqueType"); EXPECT_EQ(t->GetParameters().size(), 1); EXPECT_TRUE(t->GetParameters()[0].IsInt()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, OptionalType) { @@ -186,6 +203,8 @@ TEST(TypeSpecResolverTest, OptionalType) { EXPECT_EQ(t->GetParameters().size(), 1); EXPECT_TRUE(t->GetParameters()[0].IsInt()); EXPECT_TRUE(t->IsOptional()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, TypeTypeConversion) { @@ -196,6 +215,8 @@ TEST(TypeSpecResolverTest, TypeTypeConversion) { ASSERT_THAT(t, IsOk()); EXPECT_TRUE(t->IsType()); EXPECT_TRUE(t->GetType().GetType().IsInt()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, ErrorTypeConversion) { @@ -204,6 +225,8 @@ TEST(TypeSpecResolverTest, ErrorTypeConversion) { ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); ASSERT_THAT(t, IsOk()); EXPECT_TRUE(t->IsError()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, MessageTypeSpecConversion) { @@ -213,6 +236,8 @@ TEST(TypeSpecResolverTest, MessageTypeSpecConversion) { ASSERT_THAT(t, IsOk()); EXPECT_TRUE(t->IsMessage()); EXPECT_EQ(t->name(), "cel.expr.conformance.proto3.TestAllTypes"); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, MessageTypeSpecNotFoundError) { @@ -231,6 +256,8 @@ TEST(TypeSpecResolverTest, EnumTypeConversion) { ASSERT_THAT(t, IsOk()); EXPECT_TRUE(t->IsEnum()); EXPECT_EQ(t->name(), "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, EnumTypeWithParamsError) { From f5d0d5ff13e770bd35be79e5ce7e81fe35f63e91 Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Mon, 8 Jun 2026 15:38:29 -0700 Subject: [PATCH 53/87] Clean up unused dependencies in cel/cpp/common BUILD files. Remove unused dependencies from cel/cpp/common/BUILD and cel/cpp/common/internal/BUILD. PiperOrigin-RevId: 928808099 --- common/BUILD | 9 --------- common/internal/BUILD | 2 -- 2 files changed, 11 deletions(-) diff --git a/common/BUILD b/common/BUILD index 01710329b..f7c897e57 100644 --- a/common/BUILD +++ b/common/BUILD @@ -403,7 +403,6 @@ cc_library( ":allocator", ":arena", ":data", - ":native_type", ":reference_count", "//common/internal:metadata", "//common/internal:reference_count", @@ -425,13 +424,9 @@ cc_test( ":allocator", ":data", ":memory", - ":native_type", "//common/internal:reference_count", "//internal:testing", "@com_google_absl//absl/base:nullability", - "@com_google_absl//absl/debugging:leak_check", - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", "@com_google_protobuf//:struct_cc_proto", ], @@ -1024,9 +1019,6 @@ cc_library( deps = [ ":decl", ":decl_proto", - ":type", - ":type_proto", - "//internal:status_macros", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -1194,6 +1186,5 @@ cc_test( ":container", "//internal:testing", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:status_matchers", ], ) diff --git a/common/internal/BUILD b/common/internal/BUILD index 73cbf37e9..b07faf229 100644 --- a/common/internal/BUILD +++ b/common/internal/BUILD @@ -21,10 +21,8 @@ cc_library( name = "casting", hdrs = ["casting.h"], deps = [ - "//common:native_type", "//internal:casts", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/types:optional", ], From 87fea87272429b9c85655caeda7ab5e4804bf5f4 Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Mon, 8 Jun 2026 22:27:41 -0700 Subject: [PATCH 54/87] Add support for type signatures in CEL environment YAML configuration. The `env_yaml` parser now accepts type signatures for variable types and function overload signatures. The `type` field can be used instead of `type_name` for variables, allowing a more compact representation of types, including type parameters and parameterized types. The `signature` field can be used for function overloads, providing a single string to define the overload's target, arguments, and member status. The `return` type in function overloads can now also be specified as a type signature string. PiperOrigin-RevId: 928959415 --- env/BUILD | 7 +- env/env_yaml.cc | 311 +++++++++++++++++++++++++--------- env/env_yaml.h | 37 +++- env/env_yaml_test.cc | 381 +++++++++++++++++++++++++++++++++--------- env/type_info.cc | 226 +++++++++++++++++++++++++ env/type_info.h | 7 + env/type_info_test.cc | 169 +++++++++++++++++++ 7 files changed, 978 insertions(+), 160 deletions(-) diff --git a/env/BUILD b/env/BUILD index 41ffc1723..3035e11ac 100644 --- a/env/BUILD +++ b/env/BUILD @@ -28,6 +28,7 @@ cc_library( "type_info.h", ], deps = [ + "//common:ast", "//common:constant", "//common:type", "//common:type_kind", @@ -120,7 +121,9 @@ cc_library( features = ["-use_header_modules"], deps = [ ":config", + "//common:ast", "//common:constant", + "//common/internal:signature", "//internal:status_macros", "//internal:strings", "@com_google_absl//absl/algorithm:container", @@ -178,9 +181,11 @@ cc_test( ":config", "//common:type", "//common:type_proto", + "//common/ast:metadata", "//internal:proto_matchers", "//internal:testing", "//internal:testing_descriptor_pool", + "@com_google_absl//absl/status", "@com_google_protobuf//:protobuf", ], ) @@ -201,7 +206,6 @@ cc_test( "//common:type", "//common:value", "//compiler", - "//internal:proto_matchers", "//internal:status_macros", "//internal:testing", "//internal:testing_descriptor_pool", @@ -241,7 +245,6 @@ cc_test( "//common:value", "//compiler", "//extensions:math_ext", - "//internal:status_macros", "//internal:testing", "//internal:testing_descriptor_pool", "//runtime", diff --git a/env/env_yaml.cc b/env/env_yaml.cc index 159786598..8c635e65f 100644 --- a/env/env_yaml.cc +++ b/env/env_yaml.cc @@ -35,8 +35,11 @@ #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" +#include "common/ast.h" #include "common/constant.h" +#include "common/internal/signature.h" #include "env/config.h" +#include "env/type_info.h" #include "internal/status_macros.h" #include "internal/strings.h" #include "yaml-cpp/emitter.h" @@ -117,8 +120,8 @@ absl::StatusOr GetBinary(absl::string_view yaml, return binary; } else { return YamlError(yaml, node, - "Node '" + GetString(yaml, node) + - "' is not a valid Base64 encoded binary"); + absl::StrCat("Node '", GetString(yaml, node), + "' is not a valid Base64 encoded binary")); } } @@ -131,10 +134,22 @@ absl::StatusOr GetBool(absl::string_view yaml, absl::string_view key, return node.as(); } catch (YAML::Exception& e) { return YamlError(yaml, node, - "Node '" + std::string(key) + "' is not a boolean"); + absl::StrCat("Node '", key, "' is not a boolean")); } } +// Returns the key in the map `node` that has the given `value_node` as its +// value. If no such key exists, returns `value_node` itself. +YAML::Node GetContextNodeForKeyValue(const YAML::Node& node, + const YAML::Node& value_node) { + for (const auto& kv : node) { + if (kv.second.IsDefined() && kv.second.is(value_node)) { + return kv.first; + } + } + return value_node; +} + absl::Status ParseName(Config& config, absl::string_view yaml, const YAML::Node& root) { const YAML::Node name = root["name"]; @@ -407,7 +422,23 @@ absl::Status ParseStandardLibraryConfig(Config& config, absl::string_view yaml, absl::StatusOr ParseTypeInfo(const YAML::Node& node, absl::string_view yaml) { Config::TypeInfo type_config; + const YAML::Node type = node["type"]; const YAML::Node type_name = node["type_name"]; + if (type.IsDefined() && type_name.IsDefined()) { + return YamlError(yaml, GetContextNodeForKeyValue(node, type_name), + "Node 'type' and 'type_name' are mutually exclusive"); + } + + if (type.IsDefined()) { + if (!type.IsScalar()) { + return YamlError(yaml, type, "Node 'type' is not a string"); + } + CEL_ASSIGN_OR_RETURN(auto type_spec, + common_internal::ParseTypeSpec(GetString(yaml, type))); + CEL_ASSIGN_OR_RETURN(auto type_config, TypeSpecToTypeInfo(type_spec)); + return type_config; + } + if (!type_name.IsDefined()) { return type_config; } @@ -627,7 +658,8 @@ absl::Status ParseVariableConfigs(Config& config, absl::string_view yaml, } absl::StatusOr ParseFunctionOverloadConfig( - absl::string_view yaml, const YAML::Node& overload) { + absl::string_view yaml, const YAML::Node& overload, + absl::string_view function_name) { Config::FunctionOverloadConfig overload_config; if (!overload || !overload.IsMap()) { return YamlError(yaml, overload, "Function overload is not a map"); @@ -654,40 +686,89 @@ absl::StatusOr ParseFunctionOverloadConfig( } } + const YAML::Node signature_node = overload["signature"]; const YAML::Node target = overload["target"]; - if (target.IsDefined()) { - if (!target.IsMap()) { - return YamlError(yaml, target, "Function overload target is not a map"); + const YAML::Node args = overload["args"]; + if (signature_node.IsDefined()) { + if (!signature_node.IsScalar()) { + return YamlError(yaml, signature_node, + "Function overload signature is not a string"); } - CEL_ASSIGN_OR_RETURN(Config::TypeInfo type_info, - ParseTypeInfo(target, yaml)); - overload_config.is_member_function = true; - overload_config.parameters.push_back(type_info); - } - const YAML::Node args = overload["args"]; - if (args.IsDefined()) { - if (!args.IsSequence()) { - return YamlError(yaml, args, "Function overload args is not a sequence"); + if (target.IsDefined()) { + return YamlError(yaml, GetContextNodeForKeyValue(overload, target), + "Function overload signature and target are mutually " + "exclusive"); + } + if (args.IsDefined()) { + return YamlError(yaml, GetContextNodeForKeyValue(overload, args), + "Function overload signature and args are mutually " + "exclusive"); + } + + std::string signature = GetString(yaml, signature_node); + CEL_ASSIGN_OR_RETURN( + common_internal::ParsedFunctionOverload parsed_signature, + common_internal::ParseFunctionSignature(signature)); + if (parsed_signature.function_name != function_name) { + return YamlError(yaml, signature_node, + absl::StrCat("Function overload name \"", + parsed_signature.function_name, + "\" does not match function name \"", + function_name, "\"")); + } + overload_config.is_member_function = parsed_signature.is_member; + if (!parsed_signature.signature_type.has_function()) { + return absl::InternalError(absl::StrCat( + "Function overload signature has no function type: ", signature)); } - for (const YAML::Node& arg : args) { - if (!arg.IsMap()) { - return YamlError(yaml, arg, "Function overload arg is not a map"); + const FunctionTypeSpec& function_type_spec = + parsed_signature.signature_type.function(); + for (const auto& arg : function_type_spec.arg_types()) { + CEL_ASSIGN_OR_RETURN(auto type_info, TypeSpecToTypeInfo(arg)); + overload_config.parameters.push_back(std::move(type_info)); + } + } else { + if (target.IsDefined()) { + if (!target.IsMap()) { + return YamlError(yaml, target, "Function overload target is not a map"); } CEL_ASSIGN_OR_RETURN(Config::TypeInfo type_info, - ParseTypeInfo(arg, yaml)); + ParseTypeInfo(target, yaml)); + overload_config.is_member_function = true; overload_config.parameters.push_back(type_info); } - } + if (args.IsDefined()) { + if (!args.IsSequence()) { + return YamlError(yaml, args, + "Function overload args is not a sequence"); + } + for (const YAML::Node& arg : args) { + if (!arg.IsMap()) { + return YamlError(yaml, arg, "Function overload arg is not a map"); + } + CEL_ASSIGN_OR_RETURN(Config::TypeInfo type_info, + ParseTypeInfo(arg, yaml)); + overload_config.parameters.push_back(type_info); + } + } + } const YAML::Node return_type = overload["return"]; if (return_type.IsDefined()) { - if (!return_type.IsMap()) { - return YamlError(yaml, return_type, - "Function overload return type is not a map"); + if (return_type.IsScalar()) { + CEL_ASSIGN_OR_RETURN(auto type_spec, common_internal::ParseTypeSpec( + GetString(yaml, return_type))); + CEL_ASSIGN_OR_RETURN(overload_config.return_type, + TypeSpecToTypeInfo(type_spec)); + } else if (return_type.IsMap()) { + CEL_ASSIGN_OR_RETURN(overload_config.return_type, + ParseTypeInfo(return_type, yaml)); + } else { + return YamlError( + yaml, return_type, + "Function overload return type is neither a string nor a map"); } - CEL_ASSIGN_OR_RETURN(overload_config.return_type, - ParseTypeInfo(return_type, yaml)); } return overload_config; } @@ -728,8 +809,9 @@ absl::Status ParseFunctionConfigs(Config& config, absl::string_view yaml, } for (const YAML::Node& overload : overloads) { - CEL_ASSIGN_OR_RETURN(Config::FunctionOverloadConfig overload_config, - ParseFunctionOverloadConfig(yaml, overload)); + CEL_ASSIGN_OR_RETURN( + Config::FunctionOverloadConfig overload_config, + ParseFunctionOverloadConfig(yaml, overload, function_config.name)); function_config.overload_configs.push_back(std::move(overload_config)); } } @@ -893,26 +975,43 @@ void EmitStandardLibraryConfig(const Config& env_config, YAML::Emitter& out) { out << YAML::EndMap; } -void EmitTypeInfo(const Config::TypeInfo& type_info, YAML::Emitter& out) { +void EmitTypeInfo(const Config::TypeInfo& type_info, YAML::Emitter& out, + const EnvConfigToYamlOptions& options) { // Note: the map is already started when this is called, so we don't emit // BeginMap here or EndMap at the end. - out << YAML::Key << "type_name"; - out << YAML::Value << YAML::DoubleQuoted << type_info.name; - if (type_info.is_type_param) { - out << YAML::Key << "is_type_param" << YAML::Value << true; - } - if (!type_info.params.empty()) { - out << YAML::Key << "params" << YAML::Value << YAML::BeginSeq; - for (const Config::TypeInfo& param : type_info.params) { - out << YAML::BeginMap; - EmitTypeInfo(param, out); - out << YAML::EndMap; + bool signature_generated = false; + if (options.use_type_signatures) { + absl::StatusOr type_spec = TypeInfoToTypeSpec(type_info); + if (type_spec.ok()) { + absl::StatusOr signature = + common_internal::MakeTypeSpecSignature(*type_spec); + if (signature.ok()) { + out << YAML::Key << "type"; + out << YAML::Value << YAML::DoubleQuoted << *signature; + signature_generated = true; + } + } + } + if (!signature_generated) { + out << YAML::Key << "type_name"; + out << YAML::Value << YAML::DoubleQuoted << type_info.name; + if (type_info.is_type_param) { + out << YAML::Key << "is_type_param" << YAML::Value << true; + } + if (!type_info.params.empty()) { + out << YAML::Key << "params" << YAML::Value << YAML::BeginSeq; + for (const Config::TypeInfo& param : type_info.params) { + out << YAML::BeginMap; + EmitTypeInfo(param, out, options); + out << YAML::EndMap; + } + out << YAML::EndSeq; } - out << YAML::EndSeq; } } -void EmitVariableConfigs(const Config& env_config, YAML::Emitter& out) { +void EmitVariableConfigs(const Config& env_config, YAML::Emitter& out, + const EnvConfigToYamlOptions& options) { const auto& variable_configs = env_config.GetVariableConfigs(); if (variable_configs.empty()) { return; @@ -936,7 +1035,7 @@ void EmitVariableConfigs(const Config& env_config, YAML::Emitter& out) { out << YAML::Key << "description"; out << YAML::Value << YAML::DoubleQuoted << variable_config.description; } - EmitTypeInfo(variable_config.type_info, out); + EmitTypeInfo(variable_config.type_info, out, options); if (variable_config.value.has_value()) { const Constant& constant = variable_config.value; switch (constant.kind_case()) { @@ -991,51 +1090,97 @@ void EmitVariableConfigs(const Config& env_config, YAML::Emitter& out) { } void EmitFunctionOverloadConfig( - const Config::FunctionOverloadConfig& overload_config, YAML::Emitter& out) { + absl::string_view function_name, + const Config::FunctionOverloadConfig& overload_config, YAML::Emitter& out, + const EnvConfigToYamlOptions& options) { out << YAML::BeginMap; - out << YAML::Key << "id"; - out << YAML::Value << YAML::DoubleQuoted << overload_config.overload_id; - if (overload_config.is_member_function) { - out << YAML::Key << "target" << YAML::Value; - out << YAML::BeginMap; - if (overload_config.parameters.empty()) { - // This should never happen, but if it does, emit a dynamic type. - EmitTypeInfo({.name = "dyn"}, out); - } else { - EmitTypeInfo(overload_config.parameters[0], out); + if (!overload_config.overload_id.empty()) { + out << YAML::Key << "id"; + out << YAML::Value << YAML::DoubleQuoted << overload_config.overload_id; + } + bool signature_generated = false; + if (options.use_type_signatures) { + bool param_type_spec_generated = true; + std::vector params; + params.reserve(overload_config.parameters.size()); + for (const auto& parameter : overload_config.parameters) { + absl::StatusOr type_spec = TypeInfoToTypeSpec(parameter); + if (!type_spec.ok()) { + param_type_spec_generated = false; + break; + } + params.push_back(std::move(*type_spec)); } - out << YAML::EndMap; - if (overload_config.parameters.size() > 1) { - out << YAML::Key << "args"; - out << YAML::Value << YAML::BeginSeq; - for (size_t i = 1; i < overload_config.parameters.size(); ++i) { - out << YAML::BeginMap; - EmitTypeInfo(overload_config.parameters[i], out); - out << YAML::EndMap; + if (param_type_spec_generated) { + absl::StatusOr signature = + common_internal::MakeOverloadSignature( + function_name, params, overload_config.is_member_function); + if (signature.ok()) { + out << YAML::Key << "signature"; + out << YAML::Value << YAML::DoubleQuoted << *signature; + signature_generated = true; } - out << YAML::EndSeq; } - } else { - if (!overload_config.parameters.empty()) { - out << YAML::Key << "args"; - out << YAML::Value << YAML::BeginSeq; - for (const Config::TypeInfo& parameter : overload_config.parameters) { - out << YAML::BeginMap; - EmitTypeInfo(parameter, out); - out << YAML::EndMap; + } + if (!signature_generated) { + if (overload_config.is_member_function) { + out << YAML::Key << "target" << YAML::Value; + out << YAML::BeginMap; + if (overload_config.parameters.empty()) { + // This should never happen, but if it does, emit a dynamic type. + EmitTypeInfo({.name = "dyn"}, out, options); + } else { + EmitTypeInfo(overload_config.parameters[0], out, options); + } + out << YAML::EndMap; + if (overload_config.parameters.size() > 1) { + out << YAML::Key << "args"; + out << YAML::Value << YAML::BeginSeq; + for (size_t i = 1; i < overload_config.parameters.size(); ++i) { + out << YAML::BeginMap; + EmitTypeInfo(overload_config.parameters[i], out, options); + out << YAML::EndMap; + } + out << YAML::EndSeq; + } + } else { + if (!overload_config.parameters.empty()) { + out << YAML::Key << "args"; + out << YAML::Value << YAML::BeginSeq; + for (const Config::TypeInfo& parameter : overload_config.parameters) { + out << YAML::BeginMap; + EmitTypeInfo(parameter, out, options); + out << YAML::EndMap; + } + out << YAML::EndSeq; } - out << YAML::EndSeq; } } - out << YAML::Key << "return"; - out << YAML::Value << YAML::BeginMap; - EmitTypeInfo(overload_config.return_type, out); - out << YAML::EndMap; - + bool return_type_signature_generated = false; + if (options.use_type_signatures) { + absl::StatusOr type_spec = + TypeInfoToTypeSpec(overload_config.return_type); + if (type_spec.ok()) { + absl::StatusOr signature = + common_internal::MakeTypeSpecSignature(*type_spec); + if (signature.ok()) { + out << YAML::Key << "return"; + out << YAML::Value << YAML::DoubleQuoted << *signature; + return_type_signature_generated = true; + } + } + } + if (!return_type_signature_generated) { + out << YAML::Key << "return"; + out << YAML::Value << YAML::BeginMap; + EmitTypeInfo(overload_config.return_type, out, options); + out << YAML::EndMap; + } out << YAML::EndMap; } -void EmitFunctionConfigs(const Config& env_config, YAML::Emitter& out) { +void EmitFunctionConfigs(const Config& env_config, YAML::Emitter& out, + const EnvConfigToYamlOptions& options) { const std::vector& function_configs = env_config.GetFunctionConfigs(); if (function_configs.empty()) { @@ -1085,7 +1230,8 @@ void EmitFunctionConfigs(const Config& env_config, YAML::Emitter& out) { out << YAML::Key << "overloads" << YAML::Value << YAML::BeginSeq; for (const Config::FunctionOverloadConfig& overload_config : sorted_overloads) { - EmitFunctionOverloadConfig(overload_config, out); + EmitFunctionOverloadConfig(function_config.name, overload_config, out, + options); } out << YAML::EndSeq; } @@ -1116,7 +1262,8 @@ absl::StatusOr EnvConfigFromYaml(const std::string& yaml) { return config; } -void EnvConfigToYaml(const Config& env_config, std::ostream& os) { +void EnvConfigToYaml(const Config& env_config, std::ostream& os, + const EnvConfigToYamlOptions& options) { YAML::Emitter out(os); out.SetIndent(2); out << YAML::BeginMap; @@ -1127,8 +1274,8 @@ void EnvConfigToYaml(const Config& env_config, std::ostream& os) { EmitContainerConfig(env_config, out); EmitExtensionConfigs(env_config, out); EmitStandardLibraryConfig(env_config, out); - EmitVariableConfigs(env_config, out); - EmitFunctionConfigs(env_config, out); + EmitVariableConfigs(env_config, out, options); + EmitFunctionConfigs(env_config, out, options); out << YAML::EndMap; } diff --git a/env/env_yaml.h b/env/env_yaml.h index c96b45933..7bf7bf6b4 100644 --- a/env/env_yaml.h +++ b/env/env_yaml.h @@ -31,8 +31,43 @@ namespace cel { // expensive expressions. absl::StatusOr EnvConfigFromYaml(const std::string& yaml); +struct EnvConfigToYamlOptions { + // Whether to use type and overload signatures instead of arg/return types in + // the output YAML. + // Example of type signature: "map>" vs + // type_name: "map" + // params: + // - type_name: "int" + // - type_name: "A" + // params: + // - type_name: "B" + // is_type_param: true + // + // Example of overload signature config: + // name: "foo" + // overloads: + // - signature: "timestamp.foo(A<~B>)" + // return: "int" + // vs + // name: "foo" + // overloads: + // - id: "foo_id" + // target: + // type_name: "timestamp" + // args: + // - type_name: "A" + // params: + // - type_name: "B" + // is_type_param: true + // return: + // type_name: "int" + // TODO(uncreated-issue/91): default to true after all dependencies are updated + bool use_type_signatures = false; +}; + // EnvConfigToYaml serializes an environment configuration as a YAML string. -void EnvConfigToYaml(const Config& env_config, std::ostream& os); +void EnvConfigToYaml(const Config& env_config, std::ostream& os, + const EnvConfigToYamlOptions& options = {}); } // namespace cel diff --git a/env/env_yaml_test.cc b/env/env_yaml_test.cc index d19c0dbfb..f6bde59c9 100644 --- a/env/env_yaml_test.cc +++ b/env/env_yaml_test.cc @@ -195,6 +195,28 @@ TEST(EnvYamlTest, ParseVariableConfigs) { } TEST(EnvYamlTest, ParseVariableConfigWithTypeParams) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + variables: + - name: "dict" + type: "map" + )yaml")); + + const Config::VariableConfig& variable_config = + config.GetVariableConfigs()[0]; + EXPECT_EQ(variable_config.name, "dict"); + const auto& type_info = variable_config.type_info; + EXPECT_EQ(type_info.name, "map"); + EXPECT_FALSE(type_info.is_type_param); + EXPECT_THAT(type_info.params, SizeIs(2)); + EXPECT_EQ(type_info.params[0].name, "string"); + EXPECT_FALSE(type_info.params[0].is_type_param); + EXPECT_THAT(type_info.params[0].params, IsEmpty()); + EXPECT_EQ(type_info.params[1].name, "A"); + EXPECT_TRUE(type_info.params[1].is_type_param); + EXPECT_THAT(type_info.params[1].params, IsEmpty()); +} + +TEST(EnvYamlTest, ParseVariableConfigWithTypeParamsLegacySyntax) { ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( variables: - name: "dict" @@ -221,7 +243,7 @@ TEST(EnvYamlTest, ParseVariableConfigWithTypeParams) { } struct ParseConstantTestCase { - std::string type_name; + std::string type; std::string value; std::string expected_error; // Empty if no error. Constant expected_constant; @@ -236,10 +258,10 @@ TEST_P(EnvYamlParseConstantTest, EnvYamlParseConstant) { R"yaml( variables: - name: "const" - type_name: "%s" + type: "%s" value: %s )yaml", - param.type_name, param.value); + param.type, param.value); absl::StatusOr status_or_config = EnvConfigFromYaml(yaml); if (!param.expected_error.empty()) { EXPECT_THAT(status_or_config, StatusIs(absl::StatusCode::kInvalidArgument, @@ -251,8 +273,7 @@ TEST_P(EnvYamlParseConstantTest, EnvYamlParseConstant) { const Config::VariableConfig& variable_config = config.GetVariableConfigs()[0]; EXPECT_EQ(variable_config.name, "const"); - EXPECT_EQ(variable_config.type_info.name, param.type_name) - << " yaml: " << yaml; + EXPECT_EQ(variable_config.type_info.name, param.type) << " yaml: " << yaml; EXPECT_EQ(variable_config.value, param.expected_constant) << " yaml: " << yaml; } @@ -260,119 +281,119 @@ TEST_P(EnvYamlParseConstantTest, EnvYamlParseConstant) { std::vector GetParseConstantTestCases() { return { ParseConstantTestCase{ - .type_name = "null", + .type = "null", .value = "\"\"", .expected_constant = Constant(nullptr), }, ParseConstantTestCase{ - .type_name = "null", + .type = "null", .value = "anything", .expected_error = "Failed to parse null constant", }, ParseConstantTestCase{ - .type_name = "bool", + .type = "bool", .value = "TRUE", .expected_constant = Constant(true), }, ParseConstantTestCase{ - .type_name = "bool", + .type = "bool", .value = "false", .expected_constant = Constant(false), }, ParseConstantTestCase{ - .type_name = "bool", + .type = "bool", .value = "yes", .expected_error = "Failed to parse bool constant", }, ParseConstantTestCase{ - .type_name = "int", + .type = "int", .value = "42", .expected_constant = Constant(int64_t{42}), }, ParseConstantTestCase{ - .type_name = "int", + .type = "int", .value = "41.999", .expected_error = "Failed to parse int constant", }, ParseConstantTestCase{ - .type_name = "uint", + .type = "uint", .value = "42", .expected_constant = Constant(uint64_t{42}), }, ParseConstantTestCase{ - .type_name = "uint", + .type = "uint", .value = "42u", .expected_constant = Constant(uint64_t{42}), }, ParseConstantTestCase{ - .type_name = "uint", + .type = "uint", .value = "-1", .expected_error = "Failed to parse uint constant", }, ParseConstantTestCase{ - .type_name = "double", + .type = "double", .value = "42.42", .expected_constant = Constant(42.42), }, ParseConstantTestCase{ - .type_name = "double", + .type = "double", .value = "abc", .expected_error = "Failed to parse double constant", }, ParseConstantTestCase{ - .type_name = "bytes", + .type = "bytes", .value = "abc", .expected_constant = Constant(BytesConstant("abc")), }, ParseConstantTestCase{ - .type_name = "bytes", + .type = "bytes", .value = "b\"\\xFF\\x00\\x01\"", .expected_constant = Constant(BytesConstant(absl::string_view("\xff\x00\x01", 3))), }, ParseConstantTestCase{ - .type_name = "bytes", + .type = "bytes", .value = "!!binary /wAB", .expected_constant = Constant(BytesConstant(absl::string_view("\xff\x00\x01", 3))), }, ParseConstantTestCase{ - .type_name = "bytes", + .type = "bytes", .value = "!!binary YWJj=", .expected_error = "Node 'YWJj=' is not a valid Base64 encoded binary", }, ParseConstantTestCase{ - .type_name = "bytes", + .type = "bytes", .value = "abc", .expected_constant = Constant(BytesConstant("abc")), }, ParseConstantTestCase{ - .type_name = "string", + .type = "string", .value = "abc", .expected_constant = Constant(StringConstant("abc")), }, ParseConstantTestCase{ - .type_name = "string", + .type = "string", .value = "\"\\\"abc\\\"\"", .expected_constant = Constant(StringConstant("\"abc\"")), }, ParseConstantTestCase{ - .type_name = "duration", + .type = "duration", .value = "1s", .expected_constant = Constant(absl::Seconds(1)), }, ParseConstantTestCase{ - .type_name = "duration", + .type = "duration", .value = "abc", .expected_error = "Failed to parse duration constant", }, ParseConstantTestCase{ - .type_name = "timestamp", + .type = "timestamp", .value = "2023-01-01T00:00:00Z", .expected_constant = Constant(absl::FromUnixSeconds(1672531200)), }, ParseConstantTestCase{ - .type_name = "timestamp", + .type = "timestamp", .value = "abc", .expected_error = "Failed to parse timestamp constant", }, @@ -439,6 +460,50 @@ TEST_P(EnvYamlParseFunctionTest, EnvYamlParseFunction) { std::vector GetParseFunctionTestCases() { return { + ParseFunctionTestCase{ + .yaml = R"yaml( + functions: + - name: "isEmpty" + description: |- + determines whether a list is empty, + or a string has no characters + overloads: + - signature: "google.protobuf.StringValue.isEmpty()" + examples: + - "''.isEmpty() // true" + return: "bool" + - signature: "list<~T>.isEmpty()" + examples: + - "[].isEmpty() // true" + - "[1].isEmpty() // false" + return: "bool" + )yaml", + .expected_function_config = + { + .name = "isEmpty", + .description = "determines whether a list is empty,\nor a " + "string has no characters", + .overload_configs = + { + Config::FunctionOverloadConfig{ + .examples = {"''.isEmpty() // true"}, + .is_member_function = true, + .parameters = {{.name = "string_wrapper"}}, + .return_type = {.name = "bool"}, + }, + Config::FunctionOverloadConfig{ + .examples = {"[].isEmpty() // true", + "[1].isEmpty() // false"}, + .is_member_function = true, + .parameters = {{.name = "list", + .params = {{.name = "T", + .is_type_param = + true}}}}, + .return_type = {.name = "bool"}, + }, + }, + }, + }, ParseFunctionTestCase{ .yaml = R"yaml( functions: @@ -495,6 +560,34 @@ std::vector GetParseFunctionTestCases() { }, }, }, + ParseFunctionTestCase{ + .yaml = R"yaml( + functions: + - name: "contains" + overloads: + - signature: "contains(list<~T>, ~T)" + examples: + - "contains([1, 2, 3], 2) // true" + return: "bool" + )yaml", + .expected_function_config = + { + .name = "contains", + .overload_configs = + { + Config::FunctionOverloadConfig{ + .examples = {"contains([1, 2, 3], 2) // true"}, + .is_member_function = false, + .parameters = + {{.name = "list", + .params = {{.name = "T", + .is_type_param = true}}}, + {.name = "T", .is_type_param = true}}, + .return_type = {.name = "bool"}, + }, + }, + }, + }, ParseFunctionTestCase{ .yaml = R"yaml( functions: @@ -865,6 +958,18 @@ INSTANTIATE_TEST_SUITE_P( "| is_type_param: maybe\n" "| ^", }, + ParseTestCase{ + .yaml = R"yaml( + variables: + - name: "foo" + type_name: "opaque" + type: "opaque" + )yaml", + .expected_error = "4:19: Node 'type' and 'type_name'" + " are mutually exclusive\n" + "| type_name: \"opaque\"\n" + "| ^", + }, ParseTestCase{ .yaml = R"yaml( variables: @@ -965,12 +1070,65 @@ INSTANTIATE_TEST_SUITE_P( - name: "foo" overloads: - id: "foo_int64" - return: "to sender" + return: [1] )yaml", .expected_error = "6:31: Function overload return type" - " is not a map\n" - "| return: \"to sender\"\n" + " is neither a string nor a map\n" + "| return: [1]\n" "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_int64" + signature: "bar()" + )yaml", + .expected_error = "6:34: Function overload name \"bar\" " + "does not match function name \"foo\"\n" + "| signature: \"bar()\"\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - signature: [ "foo()" ] + )yaml", + .expected_error = + "5:34: Function overload signature is not a string\n" + "| - signature: [ \"foo()\" ]\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - signature: "foo()" + target: + type_name: "int" + )yaml", + .expected_error = "6:23: Function overload signature and target " + "are mutually exclusive\n" + "| target:\n" + "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - signature: "foo()" + args: + - type_name: "int" + )yaml", + .expected_error = "6:23: Function overload signature and args are " + "mutually exclusive\n" + "| args:\n" + "| ^", })); std::string Unindent(std::string_view yaml) { @@ -999,6 +1157,7 @@ std::string Unindent(std::string_view yaml) { struct ExportTestCase { absl::StatusOr config; std::string expected_yaml; + std::string expected_alt_yaml; }; class EnvYamlExportTest : public testing::TestWithParam {}; @@ -1007,10 +1166,18 @@ TEST_P(EnvYamlExportTest, EnvYamlExport) { const ExportTestCase& param = GetParam(); ASSERT_OK_AND_ASSIGN(Config config, param.config); std::stringstream ss; - EnvConfigToYaml(config, ss); + EnvConfigToYaml(config, ss, {.use_type_signatures = true}); std::string yaml_output = Unindent(ss.str()); std::string expected_yaml = Unindent(param.expected_yaml); EXPECT_EQ(yaml_output, expected_yaml); + + if (!param.expected_alt_yaml.empty()) { + std::stringstream alt_ss; + EnvConfigToYaml(config, alt_ss, {.use_type_signatures = false}); + std::string alt_yaml_output = Unindent(alt_ss.str()); + std::string expected_alt_yaml = Unindent(param.expected_alt_yaml); + EXPECT_EQ(alt_yaml_output, expected_alt_yaml); + } } std::vector GetExportTestCases() { @@ -1211,7 +1378,7 @@ std::vector GetExportTestCases() { .expected_yaml = R"yaml( variables: - name: "foo" - type_name: "null" + type: "null" )yaml", }, ExportTestCase{ @@ -1224,6 +1391,12 @@ std::vector GetExportTestCases() { return config; }(), .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "bool" + value: true + )yaml", + .expected_alt_yaml = R"yaml( variables: - name: "foo" type_name: "bool" @@ -1240,6 +1413,12 @@ std::vector GetExportTestCases() { return config; }(), .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "int" + value: 42 + )yaml", + .expected_alt_yaml = R"yaml( variables: - name: "foo" type_name: "int" @@ -1258,7 +1437,7 @@ std::vector GetExportTestCases() { .expected_yaml = R"yaml( variables: - name: "foo" - type_name: "uint" + type: "uint" value: 777 )yaml", }, @@ -1274,7 +1453,7 @@ std::vector GetExportTestCases() { .expected_yaml = R"yaml( variables: - name: "foo" - type_name: "double" + type: "double" value: 0.75 )yaml", }, @@ -1291,7 +1470,7 @@ std::vector GetExportTestCases() { .expected_yaml = R"yaml( variables: - name: "foo" - type_name: "bytes" + type: "bytes" value: b"\xff\x00\x01" )yaml", }, @@ -1309,7 +1488,7 @@ std::vector GetExportTestCases() { .expected_yaml = R"yaml( variables: - name: "foo" - type_name: "string" + type: "string" value: "'single' \"double\"" )yaml", }, @@ -1324,6 +1503,12 @@ std::vector GetExportTestCases() { return config; }(), .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "duration" + value: 1h2m3s + )yaml", + .expected_alt_yaml = R"yaml( variables: - name: "foo" type_name: "duration" @@ -1340,6 +1525,12 @@ std::vector GetExportTestCases() { return config; }(), .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "timestamp" + value: 2026-01-02T03:04:05Z + )yaml", + .expected_alt_yaml = R"yaml( variables: - name: "foo" type_name: "timestamp" @@ -1358,7 +1549,7 @@ std::vector GetExportTestCases() { .expected_yaml = R"yaml( variables: - name: "foo" - type_name: "google.expr.proto3.test.TestAllTypes" + type: "google.expr.proto3.test.TestAllTypes" )yaml", }, ExportTestCase{ @@ -1373,6 +1564,11 @@ std::vector GetExportTestCases() { return config; }(), .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "A" + )yaml", + .expected_alt_yaml = R"yaml( variables: - name: "foo" type_name: "A" @@ -1402,12 +1598,22 @@ std::vector GetExportTestCases() { {.overload_id = "foo_overload_id", .is_member_function = true, .parameters = {{.name = "timestamp"}, - {.name = "A", .params = {{.name = "B"}}}}, + {.name = "A", + .params = {{.name = "B", + .is_type_param = true}}}}, .return_type = {.name = "int"}}, }})); return config; }(), .expected_yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_overload_id" + signature: "timestamp.foo(A<~B>)" + return: "int" + )yaml", + .expected_alt_yaml = R"yaml( functions: - name: "foo" overloads: @@ -1418,6 +1624,7 @@ std::vector GetExportTestCases() { - type_name: "A" params: - type_name: "B" + is_type_param: true return: type_name: "int" )yaml", @@ -1427,6 +1634,7 @@ std::vector GetExportTestCases() { Config config; CEL_RETURN_IF_ERROR(config.AddFunctionConfig( {.name = "foo", + .description = "my desc", .overload_configs = { {.overload_id = "foo_overload_a", .parameters = {{.name = "timestamp"}}, @@ -1442,6 +1650,19 @@ std::vector GetExportTestCases() { .expected_yaml = R"yaml( functions: - name: "foo" + description: "my desc" + overloads: + - id: "foo_overload_b" + signature: "foo(double,A)" + return: "string" + - id: "foo_overload_a" + signature: "foo(timestamp)" + return: "list" + )yaml", + .expected_alt_yaml = R"yaml( + functions: + - name: "foo" + description: "my desc" overloads: - id: "foo_overload_b" args: @@ -1466,9 +1687,10 @@ std::vector GetExportTestCases() { INSTANTIATE_TEST_SUITE_P(EnvYamlExportTest, EnvYamlExportTest, ::testing::ValuesIn(GetExportTestCases())); -class EnvYamlRoundTripTest : public testing::TestWithParam {}; +class EnvYamlStructuredRoundTripTest + : public testing::TestWithParam {}; -TEST_P(EnvYamlRoundTripTest, EnvYamlRoundTrip) { +TEST_P(EnvYamlStructuredRoundTripTest, EnvYamlRoundTrip) { const std::string& yaml = Unindent(GetParam()); ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(yaml)); @@ -1477,7 +1699,7 @@ TEST_P(EnvYamlRoundTripTest, EnvYamlRoundTrip) { EXPECT_EQ(ss.str(), yaml); } -std::vector GetRoundTripTestCases() { +std::vector GetStructuredRoundTripTestCases() { return { R"yaml( stdlib: @@ -1536,74 +1758,83 @@ std::vector GetRoundTripTestCases() { overloads: - id: "string_to_timestamp" )yaml", + R"yaml( + functions: + - name: "bar" + - name: "foo" + )yaml", + }; +} + +INSTANTIATE_TEST_SUITE_P( + EnvYamlStructuredRoundTripTest, EnvYamlStructuredRoundTripTest, + ::testing::ValuesIn(GetStructuredRoundTripTestCases())); + +class EnvYamlSignatureRoundTripTest + : public testing::TestWithParam {}; + +TEST_P(EnvYamlSignatureRoundTripTest, EnvYamlRoundTrip) { + const std::string& yaml = Unindent(GetParam()); + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(yaml)); + + std::stringstream ss; + EnvConfigToYaml(config, ss, {.use_type_signatures = true}); + EXPECT_EQ(ss.str(), yaml); +} + +std::vector GetSignatureRoundTripTestCases() { + return { R"yaml( variables: - name: "a" - type_name: "null" + type: "null" - name: "b" - type_name: "bool" + type: "bool" value: true - name: "c" - type_name: "int" + type: "int" value: 42 - name: "d" - type_name: "uint" + type: "uint" value: 777 - name: "e" - type_name: "double" + type: "double" value: 0.75 - name: "f" - type_name: "bytes" + type: "bytes" value: b"\xff\x00\x01" - name: "g" - type_name: "string" + type: "string" value: "plain 'single' \"double\"" - name: "h" - type_name: "duration" + type: "duration" value: 1h2m3s - name: "i" - type_name: "timestamp" + type: "timestamp" value: 2026-01-02T03:04:05Z )yaml", - R"yaml( - functions: - - name: "bar" - - name: "foo" - )yaml", R"yaml( functions: - name: "foo" overloads: - id: "foo_overload_id" - target: - type_name: "timestamp" - args: - - type_name: "A" - params: - - type_name: "B" - return: - type_name: "int" + signature: "timestamp.foo(A<~B>)" + return: "int" )yaml", R"yaml( functions: - name: "foo" overloads: - id: "foo_overload_id" - args: - - type_name: "timestamp" - - type_name: "A" - params: - - type_name: "B" - return: - type_name: "list" - params: - - type_name: "int" + signature: "foo(timestamp,A<~B>)" + return: "list" )yaml", }; } -INSTANTIATE_TEST_SUITE_P(EnvYamlRoundTripTest, EnvYamlRoundTripTest, - ::testing::ValuesIn(GetRoundTripTestCases())); +INSTANTIATE_TEST_SUITE_P(EnvYamlSignatureRoundTripTest, + EnvYamlSignatureRoundTripTest, + ::testing::ValuesIn(GetSignatureRoundTripTestCases())); } // namespace } // namespace cel diff --git a/env/type_info.cc b/env/type_info.cc index a5b47b6f1..f49fab9f4 100644 --- a/env/type_info.cc +++ b/env/type_info.cc @@ -14,13 +14,17 @@ #include "env/type_info.h" +#include #include +#include #include #include "absl/base/no_destructor.h" #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "common/ast.h" #include "common/type.h" #include "common/type_kind.h" #include "env/config.h" @@ -180,5 +184,227 @@ absl::StatusOr TypeInfoToType( return DynType(); } } +absl::StatusOr TypeInfoToTypeSpec(const Config::TypeInfo& type_info) { + if (type_info.is_type_param) { + return TypeSpec(ParamTypeSpec(type_info.name)); + } + + std::optional type_kind = TypeNameToTypeKind(type_info.name); + if (!type_kind.has_value()) { + if (type_info.params.empty()) { + return TypeSpec(MessageTypeSpec(type_info.name)); + } else { + std::vector param_specs; + param_specs.reserve(type_info.params.size()); + for (const Config::TypeInfo& param : type_info.params) { + CEL_ASSIGN_OR_RETURN(TypeSpec param_spec, TypeInfoToTypeSpec(param)); + param_specs.push_back(std::move(param_spec)); + } + return TypeSpec(AbstractType(type_info.name, std::move(param_specs))); + } + } + + switch (*type_kind) { + case TypeKind::kNull: + return TypeSpec(NullTypeSpec()); + case TypeKind::kBool: + return TypeSpec(PrimitiveType::kBool); + case TypeKind::kInt: + return TypeSpec(PrimitiveType::kInt64); + case TypeKind::kUint: + return TypeSpec(PrimitiveType::kUint64); + case TypeKind::kDouble: + return TypeSpec(PrimitiveType::kDouble); + case TypeKind::kString: + return TypeSpec(PrimitiveType::kString); + case TypeKind::kBytes: + return TypeSpec(PrimitiveType::kBytes); + case TypeKind::kTimestamp: + return TypeSpec(WellKnownTypeSpec::kTimestamp); + case TypeKind::kDuration: + return TypeSpec(WellKnownTypeSpec::kDuration); + case TypeKind::kList: { + if (!type_info.params.empty()) { + CEL_ASSIGN_OR_RETURN(TypeSpec elem_type, + TypeInfoToTypeSpec(type_info.params[0])); + return TypeSpec( + ListTypeSpec(std::make_unique(std::move(elem_type)))); + } else { + return TypeSpec(ListTypeSpec()); + } + } + case TypeKind::kMap: { + if (type_info.params.empty()) { + return TypeSpec(MapTypeSpec()); + } + CEL_ASSIGN_OR_RETURN(TypeSpec key_type, + TypeInfoToTypeSpec(type_info.params[0])); + if (type_info.params.size() > 1) { + CEL_ASSIGN_OR_RETURN(TypeSpec value_type, + TypeInfoToTypeSpec(type_info.params[1])); + return TypeSpec( + MapTypeSpec(std::make_unique(std::move(key_type)), + std::make_unique(std::move(value_type)))); + } + return TypeSpec(MapTypeSpec( + std::make_unique(std::move(key_type)), nullptr)); + } + case TypeKind::kDyn: + return TypeSpec(DynTypeSpec()); + case TypeKind::kAny: + return TypeSpec(WellKnownTypeSpec::kAny); + case TypeKind::kBoolWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool)); + case TypeKind::kIntWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64)); + case TypeKind::kUintWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64)); + case TypeKind::kDoubleWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)); + case TypeKind::kStringWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kString)); + case TypeKind::kBytesWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBytes)); + case TypeKind::kType: { + if (type_info.params.empty()) { + return TypeSpec(std::make_unique(DynTypeSpec())); + } + CEL_ASSIGN_OR_RETURN(TypeSpec type_param, + TypeInfoToTypeSpec(type_info.params[0])); + return TypeSpec(std::make_unique(std::move(type_param))); + } + default: + return TypeSpec(DynTypeSpec()); + } +} + +absl::StatusOr TypeSpecToTypeInfo(const TypeSpec& type_spec) { + Config::TypeInfo type_info; + + if (type_spec.has_dyn()) { + type_info.name = "dyn"; + } else if (type_spec.has_null()) { + type_info.name = "null"; + } else if (type_spec.has_primitive()) { + switch (type_spec.primitive()) { + case PrimitiveType::kBool: + type_info.name = "bool"; + break; + case PrimitiveType::kInt64: + type_info.name = "int"; + break; + case PrimitiveType::kUint64: + type_info.name = "uint"; + break; + case PrimitiveType::kDouble: + type_info.name = "double"; + break; + case PrimitiveType::kString: + type_info.name = "string"; + break; + case PrimitiveType::kBytes: + type_info.name = "bytes"; + break; + default: + return absl::InvalidArgumentError("Unspecified primitive type"); + } + } else if (type_spec.has_wrapper()) { + switch (type_spec.wrapper()) { + case PrimitiveType::kBool: + type_info.name = "bool_wrapper"; + break; + case PrimitiveType::kInt64: + type_info.name = "int_wrapper"; + break; + case PrimitiveType::kUint64: + type_info.name = "uint_wrapper"; + break; + case PrimitiveType::kDouble: + type_info.name = "double_wrapper"; + break; + case PrimitiveType::kString: + type_info.name = "string_wrapper"; + break; + case PrimitiveType::kBytes: + type_info.name = "bytes_wrapper"; + break; + default: + return absl::InvalidArgumentError("Unspecified wrapper type"); + } + } else if (type_spec.has_well_known()) { + switch (type_spec.well_known()) { + case WellKnownTypeSpec::kAny: + type_info.name = "any"; + break; + case WellKnownTypeSpec::kTimestamp: + type_info.name = "timestamp"; + break; + case WellKnownTypeSpec::kDuration: + type_info.name = "duration"; + break; + default: + return absl::InvalidArgumentError("Unspecified well known type"); + } + } else if (type_spec.has_list_type()) { + type_info.name = "list"; + const ListTypeSpec& list_type = type_spec.list_type(); + if (list_type.has_elem_type() && list_type.elem_type().is_specified()) { + CEL_ASSIGN_OR_RETURN(Config::TypeInfo param, + TypeSpecToTypeInfo(list_type.elem_type())); + type_info.params.push_back(std::move(param)); + } + } else if (type_spec.has_map_type()) { + type_info.name = "map"; + const MapTypeSpec& map_type = type_spec.map_type(); + bool has_key = + map_type.has_key_type() && map_type.key_type().is_specified(); + bool has_value = + map_type.has_value_type() && map_type.value_type().is_specified(); + if (has_key || has_value) { + if (has_key) { + CEL_ASSIGN_OR_RETURN(Config::TypeInfo param, + TypeSpecToTypeInfo(map_type.key_type())); + type_info.params.push_back(std::move(param)); + } else { + type_info.params.push_back(Config::TypeInfo{.name = "dyn"}); + } + if (has_value) { + CEL_ASSIGN_OR_RETURN(Config::TypeInfo param_value, + TypeSpecToTypeInfo(map_type.value_type())); + type_info.params.push_back(std::move(param_value)); + } else { + type_info.params.push_back(Config::TypeInfo{.name = "dyn"}); + } + } + } else if (type_spec.has_message_type()) { + type_info.name = type_spec.message_type().type(); + } else if (type_spec.has_type_param()) { + type_info.name = type_spec.type_param().type(); + type_info.is_type_param = true; + } else if (type_spec.has_type()) { + type_info.name = "type"; + CEL_ASSIGN_OR_RETURN(Config::TypeInfo param, + TypeSpecToTypeInfo(type_spec.type())); + type_info.params.push_back(std::move(param)); + } else if (type_spec.has_abstract_type()) { + type_info.name = type_spec.abstract_type().name(); + for (const TypeSpec& param_spec : + type_spec.abstract_type().parameter_types()) { + CEL_ASSIGN_OR_RETURN(Config::TypeInfo param, + TypeSpecToTypeInfo(param_spec)); + type_info.params.push_back(std::move(param)); + } + } else if (type_spec.has_error()) { + return absl::InvalidArgumentError( + "ErrorType cannot be converted to TypeInfo"); + } else if (type_spec.has_function()) { + return absl::InvalidArgumentError( + "FunctionType cannot be converted to TypeInfo"); + } else { + return absl::InvalidArgumentError("Unknown TypeSpec kind"); + } + + return type_info; +} } // namespace cel diff --git a/env/type_info.h b/env/type_info.h index bb3cfde43..3f802ce1a 100644 --- a/env/type_info.h +++ b/env/type_info.h @@ -16,6 +16,7 @@ #define THIRD_PARTY_CEL_CPP_ENV_TYPE_INFO_H_ #include "absl/status/statusor.h" +#include "common/ast.h" #include "common/type.h" #include "env/config.h" #include "google/protobuf/arena.h" @@ -30,6 +31,12 @@ absl::StatusOr TypeInfoToType( const Config::TypeInfo& type_info, const google::protobuf::DescriptorPool* descriptor_pool, google::protobuf::Arena* arena); +// Converts a Config::TypeInfo to a cel::TypeSpec. +absl::StatusOr TypeInfoToTypeSpec(const Config::TypeInfo& type_info); + +// Converts a cel::TypeSpec to a Config::TypeInfo. +absl::StatusOr TypeSpecToTypeInfo(const TypeSpec& type_spec); + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_ENV_TYPE_INFO_H_ diff --git a/env/type_info_test.cc b/env/type_info_test.cc index 015d8a928..f9d46f9a9 100644 --- a/env/type_info_test.cc +++ b/env/type_info_test.cc @@ -14,9 +14,14 @@ #include "env/type_info.h" +#include +#include +#include #include #include +#include "absl/status/status.h" +#include "common/ast/metadata.h" #include "common/type.h" #include "common/type_proto.h" #include "env/config.h" @@ -28,9 +33,27 @@ #include "google/protobuf/text_format.h" namespace cel { + +std::ostream& operator<<(std::ostream& os, const Config::TypeInfo& type_info) { + if (type_info.is_type_param) { + os << "?"; + } + os << type_info.name; + if (!type_info.params.empty()) { + os << "<"; + for (size_t i = 0; i < type_info.params.size(); ++i) { + if (i > 0) os << ", "; + os << type_info.params[i]; + } + os << ">"; + } + return os; +} + namespace { using absl_testing::IsOk; +using absl_testing::StatusIs; using testing::ValuesIn; struct TestCase { @@ -127,5 +150,151 @@ std::vector GetTestCases() { INSTANTIATE_TEST_SUITE_P(TypeInfoTest, TypeInfoTest, ValuesIn(GetTestCases())); +bool TypeInfoEqImpl(const Config::TypeInfo& actual, + const Config::TypeInfo& expected) { + if (actual.name != expected.name) return false; + if (actual.is_type_param != expected.is_type_param) return false; + if (actual.params.size() != expected.params.size()) return false; + for (size_t i = 0; i < actual.params.size(); ++i) { + if (!TypeInfoEqImpl(actual.params[i], expected.params[i])) return false; + } + return true; +} + +MATCHER_P(TypeInfoEq, expected, "") { return TypeInfoEqImpl(arg, expected); } + +struct TypeSpecTestCase { + TypeSpec type_spec; + Config::TypeInfo expected_type_info; +}; + +using TypeSpecToTypeInfoTest = testing::TestWithParam; + +TEST_P(TypeSpecToTypeInfoTest, Convert) { + const TypeSpecTestCase& param = GetParam(); + ASSERT_OK_AND_ASSIGN(Config::TypeInfo actual_type_info, + TypeSpecToTypeInfo(param.type_spec)); + EXPECT_THAT(actual_type_info, TypeInfoEq(param.expected_type_info)); +} + +std::vector GetTypeSpecTestCases() { + return { + TypeSpecTestCase{ + .type_spec = TypeSpec(PrimitiveType::kInt64), + .expected_type_info = {.name = "int"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec( + ListTypeSpec(std::make_unique(PrimitiveType::kInt64))), + .expected_type_info = {.name = "list", + .params = {Config::TypeInfo{.name = "int"}}}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(ListTypeSpec()), + .expected_type_info = {.name = "list"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec( + MapTypeSpec(std::make_unique(PrimitiveType::kString), + std::make_unique(PrimitiveType::kInt64))), + .expected_type_info = {.name = "map", + .params = {Config::TypeInfo{.name = "string"}, + Config::TypeInfo{.name = "int"}}}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(MapTypeSpec()), + .expected_type_info = {.name = "map"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec( + MessageTypeSpec("cel.expr.conformance.proto2.TestAllTypes")), + .expected_type_info = + {.name = "cel.expr.conformance.proto2.TestAllTypes"}, + }, + TypeSpecTestCase{ + .type_spec = + TypeSpec(AbstractType("A", {TypeSpec(ParamTypeSpec("B"))})), + .expected_type_info = {.name = "A", + .params = {Config::TypeInfo{ + .name = "B", .is_type_param = true}}}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(WellKnownTypeSpec::kAny), + .expected_type_info = {.name = "any"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(WellKnownTypeSpec::kTimestamp), + .expected_type_info = {.name = "timestamp"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)), + .expected_type_info = {.name = "double_wrapper"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec( + std::make_unique(WellKnownTypeSpec::kDuration)), + .expected_type_info = {.name = "type", + .params = {Config::TypeInfo{.name = + "duration"}}}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(std::make_unique(DynTypeSpec())), + .expected_type_info = {.name = "type", + .params = {Config::TypeInfo{.name = "dyn"}}}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(DynTypeSpec{}), + .expected_type_info = {.name = "dyn"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(NullTypeSpec{}), + .expected_type_info = {.name = "null"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec( + MapTypeSpec(std::make_unique(PrimitiveType::kString), + std::make_unique(DynTypeSpec()))), + .expected_type_info = {.name = "map", + .params = {Config::TypeInfo{.name = "string"}, + Config::TypeInfo{.name = "dyn"}}}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec( + MapTypeSpec(std::make_unique(DynTypeSpec()), + std::make_unique(PrimitiveType::kInt64))), + .expected_type_info = {.name = "map", + .params = {Config::TypeInfo{.name = "dyn"}, + Config::TypeInfo{.name = "int"}}}, + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(TypeSpecToTypeInfoTest, TypeSpecToTypeInfoTest, + ValuesIn(GetTypeSpecTestCases())); + +using TypeInfoToTypeSpecTest = testing::TestWithParam; + +TEST_P(TypeInfoToTypeSpecTest, Convert) { + const TypeSpecTestCase& param = GetParam(); + ASSERT_OK_AND_ASSIGN(TypeSpec actual_type_spec, + TypeInfoToTypeSpec(param.expected_type_info)); + EXPECT_EQ(actual_type_spec, param.type_spec); +} + +INSTANTIATE_TEST_SUITE_P(TypeInfoToTypeSpecTest, TypeInfoToTypeSpecTest, + ValuesIn(GetTypeSpecTestCases())); + +TEST(TypeSpecToTypeInfoTest, ErrorConversions) { + EXPECT_THAT(TypeSpecToTypeInfo(TypeSpec(ErrorTypeSpec::kValue)), + StatusIs(absl::StatusCode::kInvalidArgument, + "ErrorType cannot be converted to TypeInfo")); + EXPECT_THAT(TypeSpecToTypeInfo(TypeSpec(FunctionTypeSpec())), + StatusIs(absl::StatusCode::kInvalidArgument, + "FunctionType cannot be converted to TypeInfo")); + EXPECT_THAT( + TypeSpecToTypeInfo(TypeSpec(UnsetTypeSpec())), + StatusIs(absl::StatusCode::kInvalidArgument, "Unknown TypeSpec kind")); +} + } // namespace } // namespace cel From a8b3224f86712c8c564b9a54abf4eac7cc5ee39a Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Tue, 9 Jun 2026 11:04:51 -0700 Subject: [PATCH 55/87] Add the function AddContextDeclarationWithProtoTypeMask to the type checker. PiperOrigin-RevId: 929295083 --- checker/internal/BUILD | 6 + checker/internal/type_check_env.cc | 5 + checker/internal/type_check_env.h | 14 + checker/internal/type_checker_builder_impl.cc | 56 +++- checker/internal/type_checker_builder_impl.h | 9 + .../type_checker_builder_impl_test.cc | 311 ++++++++++++++++++ checker/type_checker_builder.h | 22 ++ checker/type_checker_builder_factory_test.cc | 47 +++ 8 files changed, 464 insertions(+), 6 deletions(-) diff --git a/checker/internal/BUILD b/checker/internal/BUILD index 26c7b543f..777457830 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -66,6 +66,8 @@ cc_library( hdrs = ["type_check_env.h"], deps = [ ":descriptor_pool_type_introspector", + ":proto_type_mask", + ":proto_type_mask_registry", "//common:constant", "//common:container", "//common:decl", @@ -76,6 +78,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", @@ -129,6 +132,7 @@ cc_library( deps = [ ":format_type_name", ":namespace_generator", + ":proto_type_mask", ":type_check_env", ":type_inference_context", "//checker:checker_options", @@ -154,6 +158,7 @@ cc_library( "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:absl_check", @@ -226,6 +231,7 @@ cc_test( "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:optional", "@com_google_protobuf//:protobuf", diff --git a/checker/internal/type_check_env.cc b/checker/internal/type_check_env.cc index 763d9ba46..47487220c 100644 --- a/checker/internal/type_check_env.cc +++ b/checker/internal/type_check_env.cc @@ -16,6 +16,7 @@ #include #include +#include #include "absl/base/nullability.h" #include "absl/status/statusor.h" @@ -96,6 +97,10 @@ absl::StatusOr> TypeCheckEnv::LookupTypeConstant( absl::StatusOr> TypeCheckEnv::LookupStructField( absl::string_view type_name, absl::string_view field_name) const { + if (proto_type_mask_registry_ != nullptr && + !proto_type_mask_registry_->FieldIsVisible(type_name, field_name)) { + return absl::nullopt; + } // Check the type providers in registration order. // Note: this doesn't allow for shadowing a type with a subset type of the // same name -- the later type provider will still be considered when diff --git a/checker/internal/type_check_env.h b/checker/internal/type_check_env.h index 15f8ecc4d..00fea0ba3 100644 --- a/checker/internal/type_check_env.h +++ b/checker/internal/type_check_env.h @@ -25,16 +25,20 @@ #include "absl/container/flat_hash_map.h" #include "absl/log/absl_check.h" #include "absl/memory/memory.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "checker/internal/descriptor_pool_type_introspector.h" +#include "checker/internal/proto_type_mask.h" +#include "checker/internal/proto_type_mask_registry.h" #include "common/constant.h" #include "common/container.h" #include "common/decl.h" #include "common/type.h" #include "common/type_introspector.h" +#include "internal/status_macros.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" @@ -154,6 +158,14 @@ class TypeCheckEnv { variables_[decl.name()] = std::move(decl); } + absl::Status CreateProtoTypeMaskRegistry( + const std::vector& proto_type_masks) { + CEL_ASSIGN_OR_RETURN(proto_type_mask_registry_, + ProtoTypeMaskRegistry::Create(descriptor_pool_.get(), + proto_type_masks)); + return absl::OkStatus(); + } + const absl::flat_hash_map& functions() const { return functions_; } @@ -224,6 +236,8 @@ class TypeCheckEnv { absl::flat_hash_map variables_; absl::flat_hash_map functions_; + std::shared_ptr proto_type_mask_registry_; + // Type providers for custom types. std::vector> type_providers_; diff --git a/checker/internal/type_checker_builder_impl.cc b/checker/internal/type_checker_builder_impl.cc index 85b581e83..9b91fc926 100644 --- a/checker/internal/type_checker_builder_impl.cc +++ b/checker/internal/type_checker_builder_impl.cc @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -23,13 +24,16 @@ #include "absl/base/no_destructor.h" #include "absl/base/nullability.h" #include "absl/cleanup/cleanup.h" +#include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/log/absl_log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "checker/internal/proto_type_mask.h" #include "checker/internal/type_check_env.h" #include "checker/internal/type_checker_impl.h" #include "checker/type_checker.h" @@ -86,10 +90,19 @@ absl::Status CheckStdMacroOverlap(const FunctionDecl& decl) { } absl::Status AddWellKnownContextDeclarationVariables( - const google::protobuf::Descriptor* absl_nonnull descriptor, TypeCheckEnv& env, - bool use_json_name) { + const google::protobuf::Descriptor* absl_nonnull descriptor, + const absl::flat_hash_map>& + context_type_fields, + TypeCheckEnv& env, bool use_json_name) { for (int i = 0; i < descriptor->field_count(); ++i) { const google::protobuf::FieldDescriptor* field = descriptor->field(i); + // Skip fields that are hidden because of a proto type mask. + auto map_iterator = context_type_fields.find(descriptor->full_name()); + if (map_iterator != context_type_fields.end() && + !map_iterator->second.contains(field->name())) { + continue; + } Type type = MessageTypeField(field).GetType(); if (type.IsEnum()) { type = IntType(); @@ -109,11 +122,15 @@ absl::Status AddWellKnownContextDeclarationVariables( } absl::Status AddContextDeclarationVariables( - const google::protobuf::Descriptor* absl_nonnull descriptor, TypeCheckEnv& env) { + const google::protobuf::Descriptor* absl_nonnull descriptor, + const absl::flat_hash_map>& + context_type_fields, + TypeCheckEnv& env) { const bool use_json_name = env.proto_type_introspector().use_json_name(); if (IsWellKnownMessageType(descriptor)) { - return AddWellKnownContextDeclarationVariables(descriptor, env, - use_json_name); + return AddWellKnownContextDeclarationVariables( + descriptor, context_type_fields, env, use_json_name); } CEL_ASSIGN_OR_RETURN(auto fields, env.proto_type_introspector().ListFieldsForStructType( @@ -131,6 +148,13 @@ absl::Status AddContextDeclarationVariables( absl::string_view name = field_entry.name; + // Skip fields that are hidden because of a proto type mask. + auto map_iterator = context_type_fields.find(descriptor->full_name()); + if (map_iterator != context_type_fields.end() && + !map_iterator->second.contains(name)) { + continue; + } + if (!env.InsertVariableIfAbsent(MakeVariableDecl(name, type))) { return absl::AlreadyExistsError( absl::StrCat("variable '", name, @@ -317,7 +341,8 @@ absl::Status TypeCheckerBuilderImpl::ApplyConfig( } for (const google::protobuf::Descriptor* context_type : config.context_types) { - CEL_RETURN_IF_ERROR(AddContextDeclarationVariables(context_type, env)); + CEL_RETURN_IF_ERROR(AddContextDeclarationVariables( + context_type, config.context_type_fields, env)); } for (VariableDeclRecord& var : config.variables) { @@ -339,6 +364,8 @@ absl::Status TypeCheckerBuilderImpl::ApplyConfig( } } + CEL_RETURN_IF_ERROR(env.CreateProtoTypeMaskRegistry(config.proto_type_masks)); + return absl::OkStatus(); } @@ -462,6 +489,23 @@ absl::Status TypeCheckerBuilderImpl::AddContextDeclaration( return absl::OkStatus(); } +absl::Status TypeCheckerBuilderImpl::AddContextDeclarationWithProtoTypeMask( + absl::string_view type, std::vector field_paths) { + if (field_paths.empty()) { + return absl::InvalidArgumentError("field paths cannot be the empty set"); + } + + ProtoTypeMask proto_type_mask(std::string(type), field_paths); + target_config_->proto_type_masks.push_back(proto_type_mask); + + CEL_RETURN_IF_ERROR(AddContextDeclaration(type)); + CEL_ASSIGN_OR_RETURN( + absl::btree_set field_names, + proto_type_mask.GetFieldNames(template_env_.descriptor_pool())); + target_config_->context_type_fields.insert({type, std::move(field_names)}); + return absl::OkStatus(); +} + absl::Status TypeCheckerBuilderImpl::AddFunction(const FunctionDecl& decl) { CEL_RETURN_IF_ERROR( ValidateFunctionDecl(decl, options_.enable_type_parameter_name_validation, diff --git a/checker/internal/type_checker_builder_impl.h b/checker/internal/type_checker_builder_impl.h index 646a5d16f..9895a8aee 100644 --- a/checker/internal/type_checker_builder_impl.h +++ b/checker/internal/type_checker_builder_impl.h @@ -21,6 +21,7 @@ #include #include "absl/base/nullability.h" +#include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" @@ -28,6 +29,7 @@ #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "checker/checker_options.h" +#include "checker/internal/proto_type_mask.h" #include "checker/internal/type_check_env.h" #include "checker/type_checker.h" #include "checker/type_checker_builder.h" @@ -76,6 +78,8 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder { absl::Status AddVariable(const VariableDecl& decl) override; absl::Status AddOrReplaceVariable(const VariableDecl& decl) override; absl::Status AddContextDeclaration(absl::string_view type) override; + absl::Status AddContextDeclarationWithProtoTypeMask( + absl::string_view type, std::vector field_paths) override; absl::Status AddFunction(const FunctionDecl& decl) override; absl::Status MergeFunction(const FunctionDecl& decl) override; @@ -130,6 +134,11 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder { std::vector functions; std::vector> type_providers; std::vector context_types; + // Maps context type names to fields names to add as variables. + // Only includes context types that are defined with proto type masks. + absl::flat_hash_map> + context_type_fields; + std::vector proto_type_masks; }; absl::Status BuildLibraryConfig(const CheckerLibrary& library, diff --git a/checker/internal/type_checker_builder_impl_test.cc b/checker/internal/type_checker_builder_impl_test.cc index 494e7e440..913e704ee 100644 --- a/checker/internal/type_checker_builder_impl_test.cc +++ b/checker/internal/type_checker_builder_impl_test.cc @@ -15,12 +15,15 @@ #include "checker/internal/type_checker_builder_impl.h" #include +#include #include #include +#include #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "checker/checker_options.h" @@ -107,6 +110,168 @@ INSTANTIATE_TEST_SUITE_P( MapTypeSpec(std::make_unique(PrimitiveType::kString), std::make_unique(DynTypeSpec())))})); +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnEmptyFieldPaths) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {}), + StatusIs(absl::StatusCode::kInvalidArgument, + "field paths cannot be the empty set")); +} + +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnUnknownFieldPath) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT( + builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"unknown_field"}), + StatusIs(absl::StatusCode::kInvalidArgument, + "could not select field 'unknown_field' from type " + "'cel.expr.conformance.proto3.TestAllTypes'")); +} + +class ContextDeclsWithProtoTypeMaskFieldsDefinedTest + : public testing::TestWithParam {}; + +std::string LogFieldName(absl::string_view field_name, absl::string_view expr) { + return absl::StrCat("field_name: ", field_name, ", expr: ", expr); +} + +TEST_P(ContextDeclsWithProtoTypeMaskFieldsDefinedTest, + ContextDeclsWithProtoTypeMaskFieldsDefined) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT( + builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {GetParam().expr}), + IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder.Build()); + std::vector field_names = { + "single_int64", "single_uint32", "single_double", + "single_string", "single_any", "single_duration", + "single_bool_wrapper", "list_value", "standalone_message", + "standalone_enum", "repeated_bytes", "repeated_nested_message", + "map_int32_timestamp", "single_struct"}; + for (auto& field_name : field_names) { + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(field_name)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + if (field_name == GetParam().expr) { + // The field name that is part of the proto type mask is visible. + ASSERT_TRUE(result.IsValid()) + << LogFieldName(field_name, GetParam().expr); + EXPECT_EQ(result.GetAst()->GetReturnType(), GetParam().expected_type) + << LogFieldName(field_name, GetParam().expr); + } else { + // The field names that are not part of the proto type mask are not + // visible. + EXPECT_FALSE(result.IsValid()) + << LogFieldName(field_name, GetParam().expr); + } + } +} + +INSTANTIATE_TEST_SUITE_P( + TestAllTypes, ContextDeclsWithProtoTypeMaskFieldsDefinedTest, + testing::Values( + ContextDeclsTestCase{"single_int64", TypeSpec(PrimitiveType::kInt64)}, + ContextDeclsTestCase{"single_uint32", TypeSpec(PrimitiveType::kUint64)}, + ContextDeclsTestCase{"single_double", TypeSpec(PrimitiveType::kDouble)}, + ContextDeclsTestCase{"single_string", TypeSpec(PrimitiveType::kString)}, + ContextDeclsTestCase{"single_any", TypeSpec(WellKnownTypeSpec::kAny)}, + ContextDeclsTestCase{"single_duration", + TypeSpec(WellKnownTypeSpec::kDuration)}, + ContextDeclsTestCase{ + "single_bool_wrapper", + TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool))}, + ContextDeclsTestCase{ + "list_value", + TypeSpec(ListTypeSpec(std::make_unique(DynTypeSpec())))}, + ContextDeclsTestCase{ + "standalone_message", + TypeSpec(MessageTypeSpec( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage"))}, + ContextDeclsTestCase{"standalone_enum", + TypeSpec(PrimitiveType::kInt64)}, + ContextDeclsTestCase{"repeated_bytes", + TypeSpec(ListTypeSpec(std::make_unique( + PrimitiveType::kBytes)))}, + ContextDeclsTestCase{ + "repeated_nested_message", + TypeSpec(ListTypeSpec(std::make_unique(MessageTypeSpec( + "cel.expr.conformance.proto3.TestAllTypes.NestedMessage"))))}, + ContextDeclsTestCase{ + "map_int32_timestamp", + TypeSpec(MapTypeSpec( + std::make_unique(PrimitiveType::kInt64), + std::make_unique(WellKnownTypeSpec::kTimestamp)))}, + ContextDeclsTestCase{ + "single_struct", + TypeSpec( + MapTypeSpec(std::make_unique(PrimitiveType::kString), + std::make_unique(DynTypeSpec())))})); + +TEST(ContextDeclsWithProtoTypeMaskTest, FieldsInMaskAreVisibleFieldAccess) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.NestedTestAllTypes", + {"payload.standalone_message.bb", "payload.single_int32"}), + IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder.Build()); + // Visible field: standalone_message.bb + ASSERT_OK_AND_ASSIGN(auto ast, + MakeTestParsedAst("payload.standalone_message.bb")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()); + EXPECT_EQ(result.GetAst()->GetReturnType(), TypeSpec(PrimitiveType::kInt64)); + // Visible field: single_int32 + ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("payload.single_int32")); + ASSERT_OK_AND_ASSIGN(result, type_checker->Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()); + EXPECT_EQ(result.GetAst()->GetReturnType(), TypeSpec(PrimitiveType::kInt64)); + // Not Visible field: single_int64 + ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("payload.single_int64")); + ASSERT_OK_AND_ASSIGN(result, type_checker->Check(std::move(ast))); + ASSERT_FALSE(result.IsValid()); +} + +TEST(ContextDeclsWithProtoTypeMaskTest, FieldsInMaskAreVisibleFieldAssignment) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.NestedTestAllTypes", + {"payload.standalone_message.bb", "payload.single_int32"}), + IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder.Build()); + // Visible field: standalone_message.bb + ASSERT_OK_AND_ASSIGN( + auto ast, + MakeTestParsedAst( + R"(cel.expr.conformance.proto3.TestAllTypes.NestedMessage{bb: 12345})")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()); + // Visible field: single_int32 + ASSERT_OK_AND_ASSIGN( + ast, + MakeTestParsedAst( + R"(cel.expr.conformance.proto3.TestAllTypes{single_int32: 12345})")); + ASSERT_OK_AND_ASSIGN(result, type_checker->Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()); + // Not Visible field: single_int64 + ASSERT_OK_AND_ASSIGN( + ast, + MakeTestParsedAst( + R"(cel.expr.conformance.proto3.TestAllTypes{single_int64: 12345})")); + ASSERT_OK_AND_ASSIGN(result, type_checker->Check(std::move(ast))); + ASSERT_FALSE(result.IsValid()); +} + TEST(ContextDeclsTest, ErrorOnDuplicateContextDeclaration) { TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), {}); @@ -120,6 +285,20 @@ TEST(ContextDeclsTest, ErrorOnDuplicateContextDeclaration) { "already exists")); } +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnDuplicateContextDeclaration) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT( + builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"standalone_message"}), + IsOk()); + EXPECT_THAT( + builder.AddContextDeclaration("cel.expr.conformance.proto3.TestAllTypes"), + StatusIs(absl::StatusCode::kAlreadyExists, + "context declaration 'cel.expr.conformance.proto3.TestAllTypes' " + "already exists")); +} + TEST(ContextDeclsTest, ErrorOnContextDeclarationNotFound) { TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), {}); @@ -129,6 +308,16 @@ TEST(ContextDeclsTest, ErrorOnContextDeclarationNotFound) { "context declaration 'com.example.UnknownType' not found")); } +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnContextDeclarationNotFound) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + EXPECT_THAT( + builder.AddContextDeclarationWithProtoTypeMask("com.example.UnknownType", + {"any_field_name"}), + StatusIs(absl::StatusCode::kNotFound, + "context declaration 'com.example.UnknownType' not found")); +} + TEST(ContextDeclsTest, ErrorOnNonStructMessageType) { TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), {}); @@ -139,6 +328,17 @@ TEST(ContextDeclsTest, ErrorOnNonStructMessageType) { "context declaration 'google.protobuf.Timestamp' is not a struct")); } +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnNonStructMessageType) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + EXPECT_THAT( + builder.AddContextDeclarationWithProtoTypeMask( + "google.protobuf.Timestamp", {"any_field_name"}), + StatusIs( + absl::StatusCode::kInvalidArgument, + "context declaration 'google.protobuf.Timestamp' is not a struct")); +} + TEST(ContextDeclsTest, CustomStructNotSupported) { TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), {}); @@ -160,6 +360,28 @@ TEST(ContextDeclsTest, CustomStructNotSupported) { "context declaration 'com.example.MyStruct' not found")); } +TEST(ContextDeclsWithProtoTypeMaskTest, CustomStructNotSupported) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + class MyTypeProvider : public cel::TypeIntrospector { + public: + absl::StatusOr> FindTypeImpl( + absl::string_view name) const override { + if (name == "com.example.MyStruct") { + return common_internal::MakeBasicStructType("com.example.MyStruct"); + } + return absl::nullopt; + } + }; + + builder.AddTypeProvider(std::make_unique()); + + EXPECT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "com.example.MyStruct", {"any_field_name"}), + StatusIs(absl::StatusCode::kNotFound, + "context declaration 'com.example.MyStruct' not found")); +} + TEST(ContextDeclsTest, ErrorOnOverlappingContextDeclaration) { TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), {}); @@ -179,6 +401,69 @@ TEST(ContextDeclsTest, ErrorOnOverlappingContextDeclaration) { "declaration: 'cel.expr.conformance.proto2.TestAllTypes')")); } +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnOverlappingContextDeclaration) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT( + builder.AddContextDeclaration("cel.expr.conformance.proto3.TestAllTypes"), + IsOk()); + // We resolve the context declaration variables at the Build() call, so the + // error surfaces then. + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto2.TestAllTypes", {"single_int32"}), + IsOk()); + + EXPECT_THAT( + builder.Build(), + StatusIs(absl::StatusCode::kAlreadyExists, + "variable 'single_int32' declared multiple times (from context " + "declaration: 'cel.expr.conformance.proto2.TestAllTypes')")); +} + +TEST(ContextDeclsWithProtoTypeMaskTest, + ErrorOnOverlappingContextDeclarationBothProtoTypeMasks) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"single_int32"}), + IsOk()); + // We resolve the context declaration variables at the Build() call, so the + // error surfaces then. + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto2.TestAllTypes", {"single_int32"}), + IsOk()); + + EXPECT_THAT( + builder.Build(), + StatusIs(absl::StatusCode::kAlreadyExists, + "variable 'single_int32' declared multiple times (from context " + "declaration: 'cel.expr.conformance.proto2.TestAllTypes')")); +} + +TEST(ContextDeclsWithProtoTypeMaskTest, + NonOverlappingContextDeclarationBothProtoTypeMasks) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"single_int32"}), + IsOk()); + // We resolve the context declaration variables at the Build() call, so the + // error surfaces then. + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto2.NestedTestAllTypes", + {"payload.single_int64"}), + IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder.Build()); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("single_int32")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()); + ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("payload.single_int64")); + ASSERT_OK_AND_ASSIGN(result, type_checker->Check(std::move(ast))); +} + TEST(ContextDeclsTest, ErrorOnOverlappingVariableDeclaration) { TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), {}); @@ -193,6 +478,32 @@ TEST(ContextDeclsTest, ErrorOnOverlappingVariableDeclaration) { "variable 'single_int64' declared multiple times")); } +TEST(ContextDeclsWithProtoTypeMaskTest, ErrorOnOverlappingVariableDeclaration) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"single_int64"}), + IsOk()); + ASSERT_THAT(builder.AddVariable(MakeVariableDecl("single_int64", IntType())), + IsOk()); + + EXPECT_THAT(builder.Build(), + StatusIs(absl::StatusCode::kAlreadyExists, + "variable 'single_int64' declared multiple times")); +} + +TEST(ContextDeclsWithProtoTypeMaskTest, NonOverlappingVariableDeclaration) { + TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(), + {}); + ASSERT_THAT(builder.AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"single_int32"}), + IsOk()); + ASSERT_THAT(builder.AddVariable(MakeVariableDecl("single_int64", IntType())), + IsOk()); + + EXPECT_THAT(builder.Build(), IsOk()); +} + TEST(TypeCheckerBuilderImplTest, InvalidTypeParamNameVariableValidationDisabled) { CheckerOptions options; diff --git a/checker/type_checker_builder.h b/checker/type_checker_builder.h index 5dd1f5256..f145b8a98 100644 --- a/checker/type_checker_builder.h +++ b/checker/type_checker_builder.h @@ -17,6 +17,7 @@ #include #include +#include #include "absl/base/nullability.h" #include "absl/functional/any_invocable.h" @@ -102,6 +103,27 @@ class TypeCheckerBuilder { // Note: only protobuf backed struct types are supported at this time. virtual absl::Status AddContextDeclaration(absl::string_view type) = 0; + // Declares struct type by fully qualified name as a context declaration. + // + // This version accepts a mask in terms of field selections from the + // context type. The mask specifies which fields are visible on the + // struct and its members. The visible fields for a type accumulate + // across calls. This is a lightweight way to adjust the type checking + // behavior for a group of related types. + // + // Context declarations are a way to declare a group of variables based on the + // definition of a struct type. Each top level field of the struct that is + // also the first field name in a field path is declared as an individual + // variable of the field type. + // + // It is an error if the type contains a field that overlaps with another + // declared variable. It is an error if the input field paths is the empty + // set. + // + // Note: only protobuf backed struct types are supported at this time. + virtual absl::Status AddContextDeclarationWithProtoTypeMask( + absl::string_view type, std::vector field_paths) = 0; + // Adds a function declaration that may be referenced in expressions checked // with the resulting TypeChecker. virtual absl::Status AddFunction(const FunctionDecl& decl) = 0; diff --git a/checker/type_checker_builder_factory_test.cc b/checker/type_checker_builder_factory_test.cc index 38430de5f..9c4775e7f 100644 --- a/checker/type_checker_builder_factory_test.cc +++ b/checker/type_checker_builder_factory_test.cc @@ -396,6 +396,27 @@ TEST(TypeCheckerBuilderTest, AddContextDeclaration) { EXPECT_TRUE(result.IsValid()); } +TEST(TypeCheckerBuilderTest, AddContextDeclarationWithProtoTypeMask) { + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool())); + + ASSERT_OK_AND_ASSIGN( + auto fn_decl, + MakeFunctionDecl("increment", MakeOverloadDecl("increment_int", IntType(), + IntType()))); + + ASSERT_THAT(builder->AddContextDeclarationWithProtoTypeMask( + "cel.expr.conformance.proto3.TestAllTypes", {"single_int64"}), + IsOk()); + ASSERT_THAT(builder->AddFunction(fn_decl), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto checker, builder->Build()); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("increment(single_int64)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, checker->Check(std::move(ast))); + EXPECT_TRUE(result.IsValid()); +} + TEST(TypeCheckerBuilderTest, WellKnownTypeContextDeclarationError) { ASSERT_OK_AND_ASSIGN( std::unique_ptr builder, @@ -428,6 +449,32 @@ TEST(TypeCheckerBuilderTest, AllowWellKnownTypeContextDeclaration) { ASSERT_TRUE(result.IsValid()); } +TEST(TypeCheckerBuilderTest, + AllowWellKnownTypeContextDeclarationWithProtoTypeMask) { + CheckerOptions options; + options.allow_well_known_type_context_declarations = true; + ASSERT_OK_AND_ASSIGN( + std::unique_ptr builder, + CreateTypeCheckerBuilder(GetSharedTestingDescriptorPool(), options)); + + ASSERT_THAT(builder->AddContextDeclarationWithProtoTypeMask( + "google.protobuf.Any", {"value"}), + IsOk()); + ASSERT_THAT(builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr type_checker, + builder->Build()); + // Visible field: value + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("value")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, + type_checker->Check(std::move(ast))); + ASSERT_TRUE(result.IsValid()); + // Not visible field: type_url + ASSERT_OK_AND_ASSIGN(ast, MakeTestParsedAst("type_url")); + ASSERT_OK_AND_ASSIGN(result, type_checker->Check(std::move(ast))); + ASSERT_FALSE(result.IsValid()); +} + TEST(TypeCheckerBuilderTest, AllowWellKnownTypeContextDeclarationStruct) { CheckerOptions options; options.allow_well_known_type_context_declarations = true; From a721a9a8d287edd074cd7c4fd8157aef9fa3074a Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 9 Jun 2026 13:28:55 -0700 Subject: [PATCH 56/87] Add support for old style variable format in env yaml parser. PiperOrigin-RevId: 929372359 --- env/env_yaml.cc | 10 ++++++++-- env/env_yaml_test.cc | 18 ++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/env/env_yaml.cc b/env/env_yaml.cc index 8c635e65f..1bbfe6b36 100644 --- a/env/env_yaml.cc +++ b/env/env_yaml.cc @@ -620,8 +620,14 @@ absl::Status ParseVariableConfigs(Config& config, absl::string_view yaml, } variable_config.description = GetString(yaml, description); } - - CEL_ASSIGN_OR_RETURN(auto type_info, ParseTypeInfo(variable, yaml)); + const YAML::Node type = variable["type"]; + Config::TypeInfo type_info; + if (type.IsDefined() && !type.IsScalar()) { + // Old format, type spec is in 'type' instead of directly embedded. + CEL_ASSIGN_OR_RETURN(type_info, ParseTypeInfo(variable["type"], yaml)); + } else { + CEL_ASSIGN_OR_RETURN(type_info, ParseTypeInfo(variable, yaml)); + } ConstantKindCase constant_kind_case = GetConstantKindCase(type_info.name); std::string value_str; YAML::Node value = variable["value"]; diff --git a/env/env_yaml_test.cc b/env/env_yaml_test.cc index f6bde59c9..a60048617 100644 --- a/env/env_yaml_test.cc +++ b/env/env_yaml_test.cc @@ -242,6 +242,24 @@ TEST(EnvYamlTest, ParseVariableConfigWithTypeParamsLegacySyntax) { EXPECT_THAT(type_info.params[1].params, IsEmpty()); } +TEST(EnvYamlTest, ParseVariableConfigWithNestedRuleOldFormat) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + variables: + - name: "x" + type: + type_name: "int" + )yaml")); + + ASSERT_THAT(config.GetVariableConfigs(), SizeIs(1)); + const Config::VariableConfig& variable_config = + config.GetVariableConfigs()[0]; + EXPECT_EQ(variable_config.name, "x"); + const auto& type_info = variable_config.type_info; + EXPECT_EQ(type_info.name, "int"); + EXPECT_FALSE(type_info.is_type_param); + EXPECT_THAT(type_info.params, IsEmpty()); +} + struct ParseConstantTestCase { std::string type; std::string value; From d301931cafb674699daa54d922691ccddc6c55ba Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 9 Jun 2026 16:19:27 -0700 Subject: [PATCH 57/87] internal PiperOrigin-RevId: 929463453 --- env/env.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/env/env.cc b/env/env.cc index 42652ce59..6cd3a3cdc 100644 --- a/env/env.cc +++ b/env/env.cc @@ -122,6 +122,7 @@ absl::StatusOr FunctionConfigToFunctionDecl( Env::Env() { compiler_options_.parser_options.enable_quoted_identifiers = true; + compiler_options_.adapt_parser_errors = true; } absl::StatusOr> Env::NewCompilerBuilder() { From 167e797bdd017b074c1f4fe4a80598d2fa1e1613 Mon Sep 17 00:00:00 2001 From: Antoine Pietri Date: Thu, 11 Jun 2026 00:28:02 -0700 Subject: [PATCH 58/87] Fix source range offsets for lists, maps and messages. Update the parser to derive the source range for list, map, and struct/message creation expressions from the full parser rule context rather than just the opening token. This ensures the recorded offsets in `EnrichedSourceInfo` span the entire composite expression. PiperOrigin-RevId: 930338742 --- parser/parser.cc | 6 +++--- parser/parser_test.cc | 23 +++++++++++++++++++++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/parser/parser.cc b/parser/parser.cc index 6c6434319..a858337a4 100644 --- a/parser/parser.cc +++ b/parser/parser.cc @@ -1104,7 +1104,7 @@ std::any ParserVisitor::visitCreateMessage( } else { name = absl::StrJoin(parts, "."); } - int64_t obj_id = factory_.NextId(SourceRangeFromToken(ctx->op)); + int64_t obj_id = factory_.NextId(SourceRangeFromParserRuleContext(ctx)); std::vector fields; if (ctx->entries) { fields = visitFields(ctx->entries); @@ -1206,7 +1206,7 @@ std::any ParserVisitor::visitNested(CelParser::NestedContext* ctx) { } std::any ParserVisitor::visitCreateList(CelParser::CreateListContext* ctx) { - int64_t list_id = factory_.NextId(SourceRangeFromToken(ctx->op)); + int64_t list_id = factory_.NextId(SourceRangeFromParserRuleContext(ctx)); auto elems = visitList(ctx->elems); return ExprToAny(factory_.NewList(list_id, std::move(elems))); } @@ -1244,7 +1244,7 @@ std::vector ParserVisitor::visitList(CelParser::ExprListContext* ctx) { } std::any ParserVisitor::visitCreateMap(CelParser::CreateMapContext* ctx) { - int64_t struct_id = factory_.NextId(SourceRangeFromToken(ctx->op)); + int64_t struct_id = factory_.NextId(SourceRangeFromParserRuleContext(ctx)); std::vector entries; if (ctx->entries) { entries = visitEntries(ctx->entries); diff --git a/parser/parser_test.cc b/parser/parser_test.cc index 33c52b1d2..1add80f84 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -1520,6 +1520,29 @@ TEST_P(ExpressionTest, Parse) { } } +TEST(ExpressionTest, CompositeExpressionOffsets) { + ParserOptions options; + std::vector macros = Macro::AllMacros(); + + std::string list_expr = "[1, 2]"; + auto list_result = EnrichedParse(list_expr, macros, "", options); + ASSERT_THAT(list_result, IsOk()); + auto list_offsets = list_result->enriched_source_info().offsets(); + EXPECT_EQ(list_offsets.at(1), std::make_pair(0, 5)); + + std::string map_expr = "{'a': 1}"; + auto map_result = EnrichedParse(map_expr, macros, "", options); + ASSERT_THAT(map_result, IsOk()); + auto map_offsets = map_result->enriched_source_info().offsets(); + EXPECT_EQ(map_offsets.at(1), std::make_pair(0, 7)); + + std::string msg_expr = "Msg{f: 1}"; + auto msg_result = EnrichedParse(msg_expr, macros, "", options); + ASSERT_THAT(msg_result, IsOk()); + auto msg_offsets = msg_result->enriched_source_info().offsets(); + EXPECT_EQ(msg_offsets.at(1), std::make_pair(0, 8)); +} + TEST(ExpressionTest, TsanOom) { Parse( "[[a([[???[a[[??[a([[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[" From e51522fe9eaea8d5005ee58b4385b16ee642146f Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 11 Jun 2026 10:53:19 -0700 Subject: [PATCH 59/87] Update abbreviation / import validation. More closely match the validation in the go and java implementations. PiperOrigin-RevId: 930622025 --- common/container.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/common/container.cc b/common/container.cc index f69f0cc80..e1db8f86c 100644 --- a/common/container.cc +++ b/common/container.cc @@ -19,6 +19,7 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "internal/lexis.h" @@ -92,9 +93,11 @@ absl::Status ExpressionContainer::SetContainer(absl::string_view name) { } absl::Status ExpressionContainer::AddAbbreviation(absl::string_view abrev) { + abrev = absl::StripAsciiWhitespace(abrev); if (!IsValidQualifiedName(abrev)) { return absl::InvalidArgumentError( - absl::StrCat("invalid qualified name: ", abrev)); + absl::StrCat("invalid qualified name: ", abrev, + ", wanted name of the form 'qualified.name'")); } auto pos = abrev.rfind('.'); From 6975536ba104d314fcaab2cca1cfa179239554c9 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Thu, 11 Jun 2026 11:02:18 -0700 Subject: [PATCH 60/87] Fix variadic logical operator planning PiperOrigin-RevId: 930627361 --- conformance/BUILD | 23 +++++-- conformance/run.bzl | 17 ++++-- conformance/run.cc | 5 ++ conformance/service.cc | 46 +++++++++----- conformance/service.h | 1 + eval/compiler/BUILD | 1 + eval/compiler/flat_expr_builder.cc | 80 +++++++++++++------------ eval/compiler/flat_expr_builder_test.cc | 7 ++- 8 files changed, 113 insertions(+), 67 deletions(-) diff --git a/conformance/BUILD b/conformance/BUILD index a6f25e001..35d554c7b 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -32,7 +32,6 @@ cc_library( "//common:ast", "//common:ast_proto", "//common:decl_proto_v1alpha1", - "//common:expr", "//common:source", "//common:value", "//common/internal:value_conversion", @@ -57,8 +56,6 @@ cc_library( "//extensions/protobuf:enum_adapter", "//internal:status_macros", "//parser", - "//parser:macro", - "//parser:macro_expr_factory", "//parser:macro_registry", "//parser:options", "//parser:standard_macros", @@ -75,8 +72,6 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", - "@com_google_absl//absl/types:span", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", @@ -302,6 +297,24 @@ gen_conformance_tests( skip_tests = _TESTS_TO_SKIP_MODERN + _TESTS_TO_SKIP_CHECKED, ) +gen_conformance_tests( + name = "conformance_variadic", + checked = True, + data = _ALL_TESTS, + enable_variadic_logical_operators = True, + modern = True, + skip_tests = _TESTS_TO_SKIP_MODERN + _TESTS_TO_SKIP_CHECKED, +) + +gen_conformance_tests( + name = "conformance_legacy_variadic", + checked = True, + data = _ALL_TESTS, + enable_variadic_logical_operators = True, + modern = False, + skip_tests = _TESTS_TO_SKIP_LEGACY + _TESTS_TO_SKIP_CHECKED, +) + # Generates a bunch of `cc_test` whose names follow the pattern # `conformance_dashboard_..._{arena|refcount}_{optimized|unoptimized}_{recursive|iterative}`. gen_conformance_tests( diff --git a/conformance/run.bzl b/conformance/run.bzl index 2c0b51c0e..8faeb6c16 100644 --- a/conformance/run.bzl +++ b/conformance/run.bzl @@ -56,7 +56,7 @@ def _conformance_test_name(name, optimize, recursive): ], ) -def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard): +def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard, enable_variadic_logical_operators): args = [] if modern: args.append("--modern") @@ -72,12 +72,14 @@ def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, args.append("--noskip_check") if dashboard: args.append("--dashboard") + if enable_variadic_logical_operators: + args.append("--enable_variadic_logical_operators") return args -def _conformance_test(name, data, modern, optimize, recursive, select_opt, skip_check, skip_tests, tags, dashboard): +def _conformance_test(name, data, modern, optimize, recursive, select_opt, skip_check, skip_tests, tags, dashboard, enable_variadic_logical_operators): cc_test( name = _conformance_test_name(name, optimize, recursive), - args = _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard) + ["$(rlocationpath {})".format(test) for test in data], + args = _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard, enable_variadic_logical_operators) + ["$(rlocationpath {})".format(test) for test in data], env = select( { "@platforms//os:windows": {"CEL_SKIP_TESTS": ",".join(skip_tests + _TESTS_TO_SKIP_WINDOWS)}, @@ -89,18 +91,20 @@ def _conformance_test(name, data, modern, optimize, recursive, select_opt, skip_ tags = tags, ) -def gen_conformance_tests(name, data, modern = False, checked = False, select_opt = False, dashboard = False, skip_tests = [], tags = []): +def gen_conformance_tests(name, data, modern = False, checked = False, select_opt = False, dashboard = False, skip_tests = [], tags = [], enable_variadic_logical_operators = False): """Generates conformance tests. Args: name: prefix for all tests + data: textproto targets describing conformance tests modern: run using modern APIs checked: whether to apply type checking - data: textproto targets describing conformance tests + select_opt: enable select optimization + dashboard: enable dashboard mode skip_tests: tests to skip in the format of the cel-spec test runner. See documentation in github.com/google/cel-spec/tests/simple/simple_test.go tags: tags added to the generated targets - dashboard: enable dashboard mode + enable_variadic_logical_operators: enable variadic logical operators """ skip_check = not checked tests = [] @@ -119,6 +123,7 @@ def gen_conformance_tests(name, data, modern = False, checked = False, select_op skip_tests = _expand_tests_to_skip(skip_tests), tags = tags, dashboard = dashboard, + enable_variadic_logical_operators = enable_variadic_logical_operators, ) native.test_suite( name = name, diff --git a/conformance/run.cc b/conformance/run.cc index 4a0493494..1be16ba60 100644 --- a/conformance/run.cc +++ b/conformance/run.cc @@ -66,6 +66,9 @@ ABSL_FLAG(std::vector, skip_tests, {}, "Tests to skip"); ABSL_FLAG(bool, dashboard, false, "Dashboard mode, ignore test failures"); ABSL_FLAG(bool, skip_check, true, "Skip type checking the expressions"); ABSL_FLAG(bool, select_optimization, false, "Enable select optimization."); +ABSL_FLAG(bool, enable_variadic_logical_operators, false, + "Enable parsing logical AND & OR operators as a single flat variadic " + "call."); namespace { @@ -261,6 +264,8 @@ NewConformanceServiceFromFlags() { .modern = absl::GetFlag(FLAGS_modern), .recursive = absl::GetFlag(FLAGS_recursive), .select_optimization = absl::GetFlag(FLAGS_select_optimization), + .enable_variadic_logical_operators = + absl::GetFlag(FLAGS_enable_variadic_logical_operators), }); ABSL_CHECK_OK(status_or_service); return std::shared_ptr( diff --git a/conformance/service.cc b/conformance/service.cc index 7e3eded82..d81200cad 100644 --- a/conformance/service.cc +++ b/conformance/service.cc @@ -128,13 +128,15 @@ cel::expr::Expr ExtractExpr( absl::Status LegacyParse(const conformance::v1alpha1::ParseRequest& request, conformance::v1alpha1::ParseResponse& response, - bool enable_optional_syntax) { + bool enable_optional_syntax, + bool enable_variadic_logical_operators) { if (request.cel_source().empty()) { return absl::InvalidArgumentError("no source code"); } cel::ParserOptions options; options.enable_optional_syntax = enable_optional_syntax; options.enable_quoted_identifiers = true; + options.enable_variadic_logical_operators = enable_variadic_logical_operators; cel::MacroRegistry macros; CEL_RETURN_IF_ERROR(cel::RegisterStandardMacros(macros, options)); CEL_RETURN_IF_ERROR( @@ -236,7 +238,8 @@ absl::Status CheckImpl(google::protobuf::Arena* arena, class LegacyConformanceServiceImpl : public ConformanceServiceInterface { public: static absl::StatusOr> Create( - bool optimize, bool recursive, bool select_optimization) { + bool optimize, bool recursive, bool select_optimization, + bool enable_variadic_logical_operators) { static auto* constant_arena = new Arena(); google::protobuf::LinkMessageReflection< @@ -313,14 +316,15 @@ class LegacyConformanceServiceImpl : public ConformanceServiceInterface { CEL_RETURN_IF_ERROR(cel::extensions::RegisterMathExtensionFunctions( builder->GetRegistry(), options)); - return absl::WrapUnique( - new LegacyConformanceServiceImpl(std::move(builder))); + return absl::WrapUnique(new LegacyConformanceServiceImpl( + std::move(builder), enable_variadic_logical_operators)); } void Parse(const conformance::v1alpha1::ParseRequest& request, conformance::v1alpha1::ParseResponse& response) override { auto status = - LegacyParse(request, response, /*enable_optional_syntax=*/false); + LegacyParse(request, response, /*enable_optional_syntax=*/false, + enable_variadic_logical_operators_); if (!status.ok()) { auto* issue = response.add_issues(); issue->set_code(ToGrpcCode(status.code())); @@ -418,17 +422,20 @@ class LegacyConformanceServiceImpl : public ConformanceServiceInterface { } private: - explicit LegacyConformanceServiceImpl( - std::unique_ptr builder) - : builder_(std::move(builder)) {} + LegacyConformanceServiceImpl(std::unique_ptr builder, + bool enable_variadic_logical_operators) + : builder_(std::move(builder)), + enable_variadic_logical_operators_(enable_variadic_logical_operators) {} std::unique_ptr builder_; + bool enable_variadic_logical_operators_; }; class ModernConformanceServiceImpl : public ConformanceServiceInterface { public: static absl::StatusOr> Create( - bool optimize, bool recursive, bool select_optimization) { + bool optimize, bool recursive, bool select_optimization, + bool enable_variadic_logical_operators) { google::protobuf::LinkMessageReflection< cel::expr::conformance::proto3::TestAllTypes>(); google::protobuf::LinkMessageReflection< @@ -470,8 +477,9 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { options.max_recursion_depth = 48; } - return absl::WrapUnique(new ModernConformanceServiceImpl( - options, optimize, select_optimization)); + return absl::WrapUnique( + new ModernConformanceServiceImpl(options, optimize, select_optimization, + enable_variadic_logical_operators)); } absl::StatusOr> Setup( @@ -523,7 +531,8 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { void Parse(const conformance::v1alpha1::ParseRequest& request, conformance::v1alpha1::ParseResponse& response) override { auto status = - LegacyParse(request, response, /*enable_optional_syntax=*/true); + LegacyParse(request, response, /*enable_optional_syntax=*/true, + enable_variadic_logical_operators_); if (!status.ok()) { auto* issue = response.add_issues(); issue->set_code(ToGrpcCode(status.code())); @@ -614,10 +623,12 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { private: ModernConformanceServiceImpl(const RuntimeOptions& options, bool enable_optimizations, - bool enable_select_optimization) + bool enable_select_optimization, + bool enable_variadic_logical_operators) : options_(options), enable_optimizations_(enable_optimizations), - enable_select_optimization_(enable_select_optimization) {} + enable_select_optimization_(enable_select_optimization), + enable_variadic_logical_operators_(enable_variadic_logical_operators) {} static absl::StatusOr> Plan( const cel::Runtime& runtime, @@ -648,6 +659,7 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { RuntimeOptions options_; bool enable_optimizations_; bool enable_select_optimization_; + bool enable_variadic_logical_operators_; }; } // namespace @@ -660,10 +672,12 @@ absl::StatusOr> NewConformanceService(const ConformanceServiceOptions& options) { if (options.modern) { return google::api::expr::runtime::ModernConformanceServiceImpl::Create( - options.optimize, options.recursive, options.select_optimization); + options.optimize, options.recursive, options.select_optimization, + options.enable_variadic_logical_operators); } else { return google::api::expr::runtime::LegacyConformanceServiceImpl::Create( - options.optimize, options.recursive, options.select_optimization); + options.optimize, options.recursive, options.select_optimization, + options.enable_variadic_logical_operators); } } diff --git a/conformance/service.h b/conformance/service.h index 2dd2abf32..8eb97296e 100644 --- a/conformance/service.h +++ b/conformance/service.h @@ -46,6 +46,7 @@ struct ConformanceServiceOptions { bool arena; bool recursive; bool select_optimization; + bool enable_variadic_logical_operators = false; }; absl::StatusOr> diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index ed8e4d20c..f7300cb58 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -193,6 +193,7 @@ cc_test( "//internal:status_macros", "//internal:testing", "//parser", + "//parser:options", "//runtime:function", "//runtime:function_adapter", "//runtime:runtime_options", diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index d6ccdf040..fc6d87b16 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -2154,7 +2154,7 @@ void BinaryCondVisitor::PreVisit(const cel::Expr* expr) { case BinaryCond::kOr: visitor_->ValidateOrError( !expr->call_expr().has_target() && - expr->call_expr().args().size() == 2, + expr->call_expr().args().size() >= 2, "Invalid argument count for a binary function call."); break; case BinaryCond::kOptionalOr: @@ -2172,28 +2172,40 @@ void BinaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { return; } const int last_arg_index = expr->call_expr().args().size() - 1; - if (short_circuiting_ && arg_num < last_arg_index && - (cond_ == BinaryCond::kAnd || cond_ == BinaryCond::kOr)) { - // If first branch evaluation result is enough to determine output, - // jump over the second branch and provide result of the first argument as - // final output. - // Retain pointers to the jump steps so we can update the target after - // planning the next arguments. - std::unique_ptr jump_step; - switch (cond_) { - case BinaryCond::kAnd: - jump_step = CreateCondJumpStep(false, true, {}, expr->id()); - break; - case BinaryCond::kOr: - jump_step = CreateCondJumpStep(true, true, {}, expr->id()); - break; - default: - ABSL_UNREACHABLE(); + if (cond_ == BinaryCond::kAnd || cond_ == BinaryCond::kOr) { + if (arg_num > 0) { + switch (cond_) { + case BinaryCond::kAnd: + visitor_->AddStep(CreateAndStep(expr->id())); + break; + case BinaryCond::kOr: + visitor_->AddStep(CreateOrStep(expr->id())); + break; + default: + break; + } + if (short_circuiting_ && !jump_steps_.empty()) { + visitor_->SetProgressStatusIfError( + jump_steps_.back().set_target(visitor_->GetCurrentIndex())); + } } - ProgramStepIndex index = visitor_->GetCurrentIndex(); - if (JumpStepBase* jump_step_ptr = visitor_->AddStep(std::move(jump_step)); - jump_step_ptr) { - jump_steps_.push_back(Jump(index, jump_step_ptr)); + if (short_circuiting_ && arg_num < last_arg_index) { + std::unique_ptr jump_step; + switch (cond_) { + case BinaryCond::kAnd: + jump_step = CreateCondJumpStep(false, true, {}, expr->id()); + break; + case BinaryCond::kOr: + jump_step = CreateCondJumpStep(true, true, {}, expr->id()); + break; + default: + ABSL_UNREACHABLE(); + } + ProgramStepIndex index = visitor_->GetCurrentIndex(); + if (JumpStepBase* jump_step_ptr = visitor_->AddStep(std::move(jump_step)); + jump_step_ptr) { + jump_steps_.push_back(Jump(index, jump_step_ptr)); + } } } } @@ -2251,17 +2263,9 @@ void BinaryCondVisitor::PostVisit(const cel::Expr* expr) { return; } - int args_count = (cond_ == BinaryCond::kAnd || cond_ == BinaryCond::kOr) - ? expr->call_expr().args().size() - : 2; - for (int i = 0; i < args_count - 1; ++i) { + if (cond_ == BinaryCond::kOptionalOr || + cond_ == BinaryCond::kOptionalOrValue) { switch (cond_) { - case BinaryCond::kAnd: - visitor_->AddStep(CreateAndStep(expr->id())); - break; - case BinaryCond::kOr: - visitor_->AddStep(CreateOrStep(expr->id())); - break; case BinaryCond::kOptionalOr: visitor_->AddStep( CreateOptionalOrStep(/*is_or_value=*/false, expr->id())); @@ -2273,13 +2277,11 @@ void BinaryCondVisitor::PostVisit(const cel::Expr* expr) { default: ABSL_UNREACHABLE(); } - } - if (short_circuiting_) { - // If short-circuiting is enabled, point the conditional jump past the - // boolean operator step. - for (auto& jump : jump_steps_) { - visitor_->SetProgressStatusIfError( - jump.set_target(visitor_->GetCurrentIndex())); + if (short_circuiting_) { + for (auto& jump : jump_steps_) { + visitor_->SetProgressStatusIfError( + jump.set_target(visitor_->GetCurrentIndex())); + } } } } diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index e2581e3fd..105060282 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -64,6 +64,7 @@ #include "internal/proto_matchers.h" #include "internal/status_macros.h" #include "internal/testing.h" +#include "parser/options.h" #include "parser/parser.h" #include "runtime/function.h" #include "runtime/function_adapter.h" @@ -2916,7 +2917,11 @@ class FlatExprBuilderVariadicLogicalTest TEST_P(FlatExprBuilderVariadicLogicalTest, Evaluate) { const auto& test_case = GetParam(); - ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(test_case.expr)); + parser::ParserOptions parser_options; + parser_options.enable_variadic_logical_operators = true; + ASSERT_OK_AND_ASSIGN( + ParsedExpr parsed_expr, + parser::Parse(test_case.expr, test_case.label, parser_options)); cel::RuntimeOptions options; options.unknown_processing = From c0e43073ea7f4fbbcdc59b45f34abe8bd5b9364a Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 11 Jun 2026 11:22:17 -0700 Subject: [PATCH 61/87] refactor: Move type formatter to common. PiperOrigin-RevId: 930639404 --- checker/internal/BUILD | 15 ++---------- checker/internal/type_checker_impl.cc | 2 +- checker/internal/type_inference_context.cc | 18 +++++++-------- common/BUILD | 23 +++++++++++++++++++ .../internal => common}/format_type_name.cc | 6 ++--- .../internal => common}/format_type_name.h | 10 ++++---- .../format_type_name_test.cc | 6 ++--- 7 files changed, 46 insertions(+), 34 deletions(-) rename {checker/internal => common}/format_type_name.cc (97%) rename {checker/internal => common}/format_type_name.h (74%) rename {checker/internal => common}/format_type_name_test.cc (97%) diff --git a/checker/internal/BUILD b/checker/internal/BUILD index 777457830..20c476db2 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -130,7 +130,6 @@ cc_library( "type_checker_impl.h", ], deps = [ - ":format_type_name", ":namespace_generator", ":proto_type_mask", ":type_check_env", @@ -149,6 +148,7 @@ cc_library( "//common:container", "//common:decl", "//common:expr", + "//common:format_type_name", "//common:standard_definitions", "//common:type", "//common:type_kind", @@ -243,8 +243,8 @@ cc_library( srcs = ["type_inference_context.cc"], hdrs = ["type_inference_context.h"], deps = [ - ":format_type_name", "//common:decl", + "//common:format_type_name", "//common:standard_definitions", "//common:type", "//common:type_kind", @@ -275,17 +275,6 @@ cc_test( ], ) -cc_library( - name = "format_type_name", - srcs = ["format_type_name.cc"], - hdrs = ["format_type_name.h"], - deps = [ - "//common:type", - "//common:type_kind", - "@com_google_absl//absl/strings", - ], -) - cc_library( name = "descriptor_pool_type_introspector", srcs = ["descriptor_pool_type_introspector.cc"], diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 6b6b051b1..f3a06a28d 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -36,7 +36,6 @@ #include "absl/types/optional.h" #include "absl/types/span.h" #include "checker/checker_options.h" -#include "checker/internal/format_type_name.h" #include "checker/internal/namespace_generator.h" #include "checker/internal/type_check_env.h" #include "checker/internal/type_checker_builder_impl.h" @@ -52,6 +51,7 @@ #include "common/constant.h" #include "common/decl.h" #include "common/expr.h" +#include "common/format_type_name.h" #include "common/standard_definitions.h" #include "common/type.h" #include "common/type_kind.h" diff --git a/checker/internal/type_inference_context.cc b/checker/internal/type_inference_context.cc index 1a87d9e15..4681784af 100644 --- a/checker/internal/type_inference_context.cc +++ b/checker/internal/type_inference_context.cc @@ -28,8 +28,8 @@ #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" -#include "checker/internal/format_type_name.h" #include "common/decl.h" +#include "common/format_type_name.h" #include "common/standard_definitions.h" #include "common/type.h" #include "common/type_kind.h" @@ -657,14 +657,14 @@ bool TypeInferenceContext::AssignabilityContext::IsAssignable(const Type& from, std::string TypeInferenceContext::DebugString() const { return absl::StrCat( "type_parameter_bindings: ", - absl::StrJoin( - type_parameter_bindings_, "\n ", - [](std::string* out, const auto& binding) { - absl::StrAppend( - out, binding.first, " (", binding.second.name, ") -> ", - checker_internal::FormatTypeName( - binding.second.type.value_or(Type(TypeParamType("none"))))); - })); + absl::StrJoin(type_parameter_bindings_, "\n ", + [](std::string* out, const auto& binding) { + absl::StrAppend( + out, binding.first, " (", binding.second.name, + ") -> ", + cel::FormatTypeName(binding.second.type.value_or( + Type(TypeParamType("none"))))); + })); } void TypeInferenceContext::AssignabilityContext:: diff --git a/common/BUILD b/common/BUILD index f7c897e57..93410306f 100644 --- a/common/BUILD +++ b/common/BUILD @@ -601,6 +601,17 @@ cc_library( ], ) +cc_library( + name = "format_type_name", + srcs = ["format_type_name.cc"], + hdrs = ["format_type_name.h"], + deps = [ + ":type", + ":type_kind", + "@com_google_absl//absl/strings", + ], +) + cc_test( name = "type_test", srcs = glob([ @@ -623,6 +634,18 @@ cc_test( ], ) +cc_test( + name = "format_type_name_test", + srcs = ["format_type_name_test.cc"], + deps = [ + ":format_type_name", + ":type", + "//internal:testing", + "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + cc_library( name = "value", srcs = glob( diff --git a/checker/internal/format_type_name.cc b/common/format_type_name.cc similarity index 97% rename from checker/internal/format_type_name.cc rename to common/format_type_name.cc index 7cd17251f..4bd6c2e61 100644 --- a/checker/internal/format_type_name.cc +++ b/common/format_type_name.cc @@ -11,7 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -#include "checker/internal/format_type_name.h" +#include "common/format_type_name.h" #include #include @@ -20,7 +20,7 @@ #include "common/type.h" #include "common/type_kind.h" -namespace cel::checker_internal { +namespace cel { namespace { struct FormatImplRecord { @@ -177,4 +177,4 @@ std::string FormatTypeName(const Type& type) { return out; } -} // namespace cel::checker_internal +} // namespace cel diff --git a/checker/internal/format_type_name.h b/common/format_type_name.h similarity index 74% rename from checker/internal/format_type_name.h rename to common/format_type_name.h index c31e1c4d0..723ac20fd 100644 --- a/checker/internal/format_type_name.h +++ b/common/format_type_name.h @@ -12,19 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_FORMAT_TYPE_NAME_H_ -#define THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_FORMAT_TYPE_NAME_H_ +#ifndef THIRD_PARTY_CEL_CPP_COMMON_FORMAT_TYPE_NAME_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_FORMAT_TYPE_NAME_H_ #include #include "common/type.h" -namespace cel::checker_internal { +namespace cel { // Format the type name for presentation in error messages. Matches the // formatting used in github.com/cel-spec. std::string FormatTypeName(const Type& type); -} // namespace cel::checker_internal +} // namespace cel -#endif // THIRD_PARTY_CEL_CPP_CHECKER_INTERNAL_FORMAT_TYPE_NAME_H_ +#endif // THIRD_PARTY_CEL_CPP_COMMON_FORMAT_TYPE_NAME_H_ diff --git a/checker/internal/format_type_name_test.cc b/common/format_type_name_test.cc similarity index 97% rename from checker/internal/format_type_name_test.cc rename to common/format_type_name_test.cc index ff04e04d2..ca63f60b0 100644 --- a/checker/internal/format_type_name_test.cc +++ b/common/format_type_name_test.cc @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "checker/internal/format_type_name.h" +#include "common/format_type_name.h" #include "common/type.h" #include "internal/testing.h" #include "cel/expr/conformance/proto2/test_all_types.pb.h" #include "google/protobuf/arena.h" -namespace cel::checker_internal { +namespace cel { namespace { using ::cel::expr::conformance::proto2::GlobalEnum_descriptor; @@ -115,4 +115,4 @@ TEST(FormatTypeNameTest, ArbitraryNesting) { #endif } // namespace -} // namespace cel::checker_internal +} // namespace cel From d2a8a7831b76fb128de4d03254bc8f81dfcf5652 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Thu, 11 Jun 2026 12:12:47 -0700 Subject: [PATCH 62/87] Create a `proto_to_predicate` compiler for converting proto messages into CEL expressions PiperOrigin-RevId: 930668150 --- common/expr_factory.h | 5 + internal/json.h | 4 +- tools/BUILD | 50 +++ tools/proto_to_predicate.cc | 459 ++++++++++++++++++++++++ tools/proto_to_predicate.h | 48 +++ tools/proto_to_predicate_test.cc | 593 +++++++++++++++++++++++++++++++ tools/testdata/BUILD | 19 +- tools/testdata/test_policy.proto | 73 ++++ 8 files changed, 1245 insertions(+), 6 deletions(-) create mode 100644 tools/proto_to_predicate.cc create mode 100644 tools/proto_to_predicate.h create mode 100644 tools/proto_to_predicate_test.cc create mode 100644 tools/testdata/test_policy.proto diff --git a/common/expr_factory.h b/common/expr_factory.h index 5607d8deb..757318545 100644 --- a/common/expr_factory.h +++ b/common/expr_factory.h @@ -34,6 +34,10 @@ class MacroExprFactory; class ParserMacroExprFactory; class OptimizerExprFactory; +namespace tools { +class ProtoToPredicateBuilder; +} + class ExprFactory { protected: // `IsExprLike` determines whether `T` is some `Expr`. Currently that means @@ -380,6 +384,7 @@ class ExprFactory { friend class MacroExprFactory; friend class ParserMacroExprFactory; friend class OptimizerExprFactory; + friend class tools::ProtoToPredicateBuilder; ExprFactory() : accu_var_(kAccumulatorVariableName) {} diff --git a/internal/json.h b/internal/json.h index d32c42741..e35909d0e 100644 --- a/internal/json.h +++ b/internal/json.h @@ -26,7 +26,7 @@ namespace cel::internal { // Converts the given message to its `google.protobuf.Value` equivalent -// representation. This is similar to `proto2::json::MessageToJsonString()`, +// representation. This is similar to `google::protobuf::json::MessageToJsonString()`, // except that this results in structured serialization. absl::Status MessageToJson( const google::protobuf::Message& message, @@ -45,7 +45,7 @@ absl::Status MessageToJson( google::protobuf::Message* absl_nonnull result); // Converts the given message field to its `google.protobuf.Value` equivalent -// representation. This is similar to `proto2::json::MessageToJsonString()`, +// representation. This is similar to `google::protobuf::json::MessageToJsonString()`, // except that this results in structured serialization. absl::Status MessageFieldToJson( const google::protobuf::Message& message, diff --git a/tools/BUILD b/tools/BUILD index ceb2befc5..af006a67b 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -204,6 +204,56 @@ cc_library( ], ) +cc_library( + name = "proto_to_predicate", + srcs = ["proto_to_predicate.cc"], + hdrs = ["proto_to_predicate.h"], + deps = [ + "//common:ast", + "//common:expr", + "//common:expr_factory", + "//common:operators", + "//internal:status_macros", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "proto_to_predicate_test", + srcs = ["proto_to_predicate_test.cc"], + deps = [ + ":cel_unparser", + ":proto_to_predicate", + "//common:ast", + "//common:ast_proto", + "//common:value", + "//env:config", + "//env:env_runtime", + "//env:env_yaml", + "//env:runtime_std_extensions", + "//eval/testutil:test_message_cc_proto", + "//extensions/protobuf:value", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//parser", + "//runtime", + "//runtime:activation", + "//tools/testdata:test_policy_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + cc_test( name = "descriptor_pool_builder_test", srcs = ["descriptor_pool_builder_test.cc"], diff --git a/tools/proto_to_predicate.cc b/tools/proto_to_predicate.cc new file mode 100644 index 000000000..8c89ee2f0 --- /dev/null +++ b/tools/proto_to_predicate.cc @@ -0,0 +1,459 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tools/proto_to_predicate.h" + +#include +#include +#include +#include +#include +#include + +#include "google/protobuf/descriptor.pb.h" +#include "absl/log/absl_log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/ast.h" +#include "common/expr.h" +#include "common/expr_factory.h" +#include "common/operators.h" +#include "internal/status_macros.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "google/protobuf/reflection.h" + +namespace cel::tools { + +using ::google::api::expr::common::CelOperator; +using ::google::protobuf::FieldDescriptor; +using ::google::protobuf::Message; +using ::google::protobuf::Reflection; + +class ProtoToPredicateBuilder final : private ExprFactory { + public: + ProtoToPredicateBuilder() : id_(1) {} + + absl::StatusOr Build(absl::string_view input_name, + const Message& message) { + std::vector predicates; + Expr base_expr = NewIdent(NextId(), input_name); + + CEL_RETURN_IF_ERROR(Walk(message, base_expr, predicates)); + Expr root = LogicalAnd(predicates); + return Ast(std::move(root), std::move(source_info_)); + } + + absl::StatusOr Build(absl::string_view input_name, + absl::Span messages) { + if (messages.empty()) { + return Ast(NewBoolConst(NextId(), true), std::move(source_info_)); + } + + std::vector message_asts; + message_asts.reserve(messages.size()); + for (const auto* message : messages) { + std::vector predicates; + Expr base_expr = NewIdent(NextId(), input_name); + + CEL_RETURN_IF_ERROR(Walk(*message, base_expr, predicates)); + message_asts.push_back(LogicalAnd(predicates)); + } + + return Ast(LogicalOr(message_asts), std::move(source_info_)); + } + + private: + // Retrieves the "match_path" string option from the field options if + // defined, returning an empty string otherwise. + std::string GetMatchPath(const ::google::protobuf::FieldDescriptor* field) { + const ::google::protobuf::Message& options = field->options(); + const ::google::protobuf::Reflection* refl = options.GetReflection(); + std::vector fields; + refl->ListFields(options, &fields); + for (const auto* f : fields) { + if (f->name() == "match_path") { + return refl->GetString(options, f); + } + } + return ""; + } + + // Parses a dot-separated string representation of a path (e.g. "dest.region") + // and builds a corresponding select chain AST. + Expr ParseAndBuildPath(absl::string_view path_str) { + std::vector parts = absl::StrSplit(path_str, '.'); + Expr e = NewIdent(NextId(), parts[0]); + for (size_t i = 1; i < parts.size(); ++i) { + e = NewSelect(NextId(), std::move(e), parts[i]); + } + return e; + } + ExprId NextId() { return id_++; } + + // --------------------------------------------------------------------------- + // Field value extraction + // --------------------------------------------------------------------------- + + // Converts a singular field value to a CEL constant expression. + Expr PrimitiveToExpr(ExprId expr_id, const Message& message, + const Reflection* reflection, + const FieldDescriptor* field) { + switch (field->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: + return NewIntConst(expr_id, reflection->GetInt32(message, field)); + case FieldDescriptor::CPPTYPE_INT64: + return NewIntConst(expr_id, reflection->GetInt64(message, field)); + case FieldDescriptor::CPPTYPE_UINT32: + return NewUintConst(expr_id, reflection->GetUInt32(message, field)); + case FieldDescriptor::CPPTYPE_UINT64: + return NewUintConst(expr_id, reflection->GetUInt64(message, field)); + case FieldDescriptor::CPPTYPE_DOUBLE: + return NewDoubleConst(expr_id, reflection->GetDouble(message, field)); + case FieldDescriptor::CPPTYPE_FLOAT: + return NewDoubleConst(expr_id, reflection->GetFloat(message, field)); + case FieldDescriptor::CPPTYPE_BOOL: + return NewBoolConst(expr_id, reflection->GetBool(message, field)); + case FieldDescriptor::CPPTYPE_ENUM: + return NewIntConst(expr_id, reflection->GetEnumValue(message, field)); + case FieldDescriptor::CPPTYPE_STRING: { + std::string str_val = reflection->GetString(message, field); + if (field->type() == FieldDescriptor::TYPE_BYTES) { + return NewBytesConst(expr_id, std::move(str_val)); + } + return NewStringConst(expr_id, std::move(str_val)); + } + default: + // Log a warning as message should be handled by Walk. + ABSL_LOG(WARNING) << "PrimitiveToExpr: Unhandled field type: " + << FieldDescriptor::TypeName(field->type()); + break; + } + return NewNullConst(expr_id); + } + + Expr PrimitiveToExpr(const Message& message, const Reflection* reflection, + const FieldDescriptor* field) { + return PrimitiveToExpr(NextId(), message, reflection, field); + } + + // Converts a repeated field element to a CEL constant expression. + Expr RepeatedPrimitiveToExpr(const Message& message, + const Reflection* reflection, + const FieldDescriptor* field, int index) { + const ExprId id = NextId(); + switch (field->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: + return NewIntConst(id, + reflection->GetRepeatedInt32(message, field, index)); + case FieldDescriptor::CPPTYPE_INT64: + return NewIntConst(id, + reflection->GetRepeatedInt64(message, field, index)); + case FieldDescriptor::CPPTYPE_UINT32: + return NewUintConst( + id, reflection->GetRepeatedUInt32(message, field, index)); + case FieldDescriptor::CPPTYPE_UINT64: + return NewUintConst( + id, reflection->GetRepeatedUInt64(message, field, index)); + case FieldDescriptor::CPPTYPE_DOUBLE: + return NewDoubleConst( + id, reflection->GetRepeatedDouble(message, field, index)); + case FieldDescriptor::CPPTYPE_FLOAT: + return NewDoubleConst( + id, reflection->GetRepeatedFloat(message, field, index)); + case FieldDescriptor::CPPTYPE_BOOL: + return NewBoolConst(id, + reflection->GetRepeatedBool(message, field, index)); + case FieldDescriptor::CPPTYPE_ENUM: + return NewIntConst( + id, reflection->GetRepeatedEnumValue(message, field, index)); + case FieldDescriptor::CPPTYPE_STRING: { + std::string str_val = + reflection->GetRepeatedString(message, field, index); + if (field->type() == FieldDescriptor::TYPE_BYTES) { + return NewBytesConst(id, std::move(str_val)); + } + return NewStringConst(id, std::move(str_val)); + } + default: + break; + } + return NewNullConst(id); + } + + // --------------------------------------------------------------------------- + // Expression construction helpers + // --------------------------------------------------------------------------- + + // Creates a binary operator call: `lhs rhs`. + Expr ConstructBinaryOp(absl::string_view op, Expr lhs, Expr rhs) { + std::vector args = {std::move(lhs), std::move(rhs)}; + return NewCall(NextId(), op, std::move(args)); + } + + Expr ConstructEquality(Expr lhs, Expr rhs) { + return ConstructBinaryOp(CelOperator::EQUALS, std::move(lhs), + std::move(rhs)); + } + + Expr LogicalOr(std::vector& exprs) { + return LogicalOp(CelOperator::LOGICAL_OR, exprs); + } + + Expr LogicalAnd(std::vector& exprs) { + return LogicalOp(CelOperator::LOGICAL_AND, exprs); + } + + // Left-folds a vector of expressions with a binary operator. + // Requires: `exprs` is non-empty. + Expr LogicalOp(absl::string_view op, std::vector& exprs) { + if (exprs.empty()) { + return NewBoolConst(NextId(), true); + } + if (exprs.size() == 1) { + return std::move(exprs[0]); + } + return NewCall(NextId(), op, std::move(exprs)); + } + + // --------------------------------------------------------------------------- + // Map field predicate (extracted from Walk for readability) + // --------------------------------------------------------------------------- + + // Builds the predicate for a map field to assert that all key-value pairs + // specified in the policy are present in the input map field: + // "key" in input.map && input.map["key"] == value + absl::Status WalkMapField(const Reflection* reflection, + const Message& message, + const FieldDescriptor* field, const Expr& base_expr, + int size, std::vector& predicates) { + const FieldDescriptor* const key_field = + field->message_type()->FindFieldByName("key"); + const FieldDescriptor* const value_field = + field->message_type()->FindFieldByName("value"); + + Expr map_path = NewSelect(NextId(), base_expr, field->name()); + + struct MapEntry { + const Message* message; + }; + std::vector entries; + entries.reserve(size); + for (int i = 0; i < size; ++i) { + entries.push_back({&reflection->GetRepeatedMessage(message, field, i)}); + } + + if (!entries.empty()) { + const Reflection* const entry_ref = entries[0].message->GetReflection(); + std::sort(entries.begin(), entries.end(), + [entry_ref, key_field](const MapEntry& a, const MapEntry& b) { + switch (key_field->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: + return entry_ref->GetInt32(*a.message, key_field) < + entry_ref->GetInt32(*b.message, key_field); + case FieldDescriptor::CPPTYPE_INT64: + return entry_ref->GetInt64(*a.message, key_field) < + entry_ref->GetInt64(*b.message, key_field); + case FieldDescriptor::CPPTYPE_UINT32: + return entry_ref->GetUInt32(*a.message, key_field) < + entry_ref->GetUInt32(*b.message, key_field); + case FieldDescriptor::CPPTYPE_UINT64: + return entry_ref->GetUInt64(*a.message, key_field) < + entry_ref->GetUInt64(*b.message, key_field); + case FieldDescriptor::CPPTYPE_BOOL: + return !entry_ref->GetBool(*a.message, key_field) && + entry_ref->GetBool(*b.message, key_field); + case FieldDescriptor::CPPTYPE_STRING: + return entry_ref->GetString(*a.message, key_field) < + entry_ref->GetString(*b.message, key_field); + default: + return false; + } + }); + } + + std::vector map_checks; + map_checks.reserve(size); + for (const auto& entry : entries) { + const Message& entry_msg = *entry.message; + const Reflection* const entry_ref = entry_msg.GetReflection(); + + Expr key_expr = PrimitiveToExpr(entry_msg, entry_ref, key_field); + + // Represents `"key" in input.map` to assert the key exists. + Expr in_check = NewCall(NextId(), CelOperator::IN, + std::vector{key_expr, map_path}); + // Represents `input.map["key"]` to lookup the value. + Expr lookup_path = NewCall(NextId(), CelOperator::INDEX, + std::vector{map_path, key_expr}); + + if (value_field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + const Message& value_msg = + entry_ref->GetMessage(entry_msg, value_field); + std::vector val_predicates; + CEL_RETURN_IF_ERROR(Walk(value_msg, lookup_path, val_predicates)); + + if (!val_predicates.empty()) { + // Represents `"key" in input.map && (nested message fields check...)` + map_checks.push_back(std::move(in_check)); + map_checks.insert(map_checks.end(), + std::make_move_iterator(val_predicates.begin()), + std::make_move_iterator(val_predicates.end())); + } else { + // Represents `"key" in input.map` if nested message is empty. + map_checks.push_back(std::move(in_check)); + } + } else { + Expr value_expr = PrimitiveToExpr(entry_msg, entry_ref, value_field); + // Represents `input.map["key"] == value` + Expr eq_check = + ConstructEquality(std::move(lookup_path), std::move(value_expr)); + + // Represents `"key" in input.map && input.map["key"] == value` + map_checks.push_back(std::move(in_check)); + map_checks.push_back(std::move(eq_check)); + } + } + + predicates.push_back(LogicalAnd(map_checks)); + return absl::OkStatus(); + } + + // --------------------------------------------------------------------------- + // Repeated field predicate (extracted from Walk for readability) + // --------------------------------------------------------------------------- + + // Builds predicates for a repeated field: + // - Repeated Messages are mapped to a logical OR (||) of the generated + // predicates for each message. + // - Repeated Primitives are mapped either to: + // - `lhs in [values]` if a "match_path" option is specified. + // - `value in input.field` conjoined with && for each value otherwise. + absl::Status WalkRepeatedField(const Reflection* reflection, + const Message& message, + const FieldDescriptor* field, + const Expr& base_expr, int size, + std::vector& predicates) { + if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + std::vector message_asts; + message_asts.reserve(size); + for (int i = 0; i < size; ++i) { + const Message& sub_message = + reflection->GetRepeatedMessage(message, field, i); + std::vector sub_predicates; + Expr sub_base = NewSelect(NextId(), base_expr, field->name()); + CEL_RETURN_IF_ERROR(Walk(sub_message, sub_base, sub_predicates)); + message_asts.push_back(LogicalAnd(sub_predicates)); + } + // Represents alternate message predicates conjoined with OR: `msg_1 || + // msg_2 || ...` + predicates.push_back(LogicalOr(message_asts)); + return absl::OkStatus(); + } + + std::vector elements; + elements.reserve(size); + for (int i = 0; i < size; ++i) { + elements.push_back(NewListElement( + RepeatedPrimitiveToExpr(message, reflection, field, i))); + } + Expr literal_list = NewList(NextId(), std::move(elements)); + + std::string match_path_val = GetMatchPath(field); + if (!match_path_val.empty()) { + Expr lhs = ParseAndBuildPath(match_path_val); + // Represents `lhs in [values]` check (e.g. `dest.region in ["us-east", + // "us-west"]`). + predicates.push_back( + NewCall(NextId(), CelOperator::IN, + std::vector{std::move(lhs), std::move(literal_list)})); + return absl::OkStatus(); + } + + Expr map_path = NewSelect(NextId(), base_expr, field->name()); + std::vector element_checks; + element_checks.reserve(size); + for (int i = 0; i < size; ++i) { + Expr elem_expr = RepeatedPrimitiveToExpr(message, reflection, field, i); + // Represents `value in input.field` check. + Expr in_check = + NewCall(NextId(), CelOperator::IN, + std::vector{std::move(elem_expr), map_path}); + element_checks.push_back(std::move(in_check)); + } + // Represents `"val1" in input.list && "val2" in input.list && ...` + predicates.push_back(LogicalAnd(element_checks)); + + return absl::OkStatus(); + } + + // --------------------------------------------------------------------------- + // Recursive message walk + // --------------------------------------------------------------------------- + + absl::Status Walk(const Message& message, const Expr& base_expr, + std::vector& predicates) { + const Reflection* const reflection = message.GetReflection(); + std::vector fields; + reflection->ListFields(message, &fields); + + for (const auto* field : fields) { + if (field->is_map()) { + const int size = reflection->FieldSize(message, field); + if (size > 0) { + CEL_RETURN_IF_ERROR(WalkMapField(reflection, message, field, + base_expr, size, predicates)); + } + } else if (field->is_repeated()) { + const int size = reflection->FieldSize(message, field); + if (size > 0) { + CEL_RETURN_IF_ERROR(WalkRepeatedField(reflection, message, field, + base_expr, size, predicates)); + } + } else if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + const Message& sub_message = reflection->GetMessage(message, field); + Expr field_path = NewSelect(NextId(), base_expr, field->name()); + CEL_RETURN_IF_ERROR(Walk(sub_message, field_path, predicates)); + } else { + // Primitive field: base_expr.field == + Expr field_path = NewSelect(NextId(), base_expr, field->name()); + predicates.push_back( + ConstructEquality(std::move(field_path), + PrimitiveToExpr(message, reflection, field))); + } + } + return absl::OkStatus(); + } + + ExprId id_; + SourceInfo source_info_; +}; + +absl::StatusOr ProtoToPredicateAst(absl::string_view input_name, + const ::google::protobuf::Message& message) { + ProtoToPredicateBuilder builder; + return builder.Build(input_name, message); +} + +absl::StatusOr ProtoToPredicateAst( + absl::string_view input_name, + absl::Span messages) { + ProtoToPredicateBuilder builder; + return builder.Build(input_name, messages); +} + +} // namespace cel::tools diff --git a/tools/proto_to_predicate.h b/tools/proto_to_predicate.h new file mode 100644 index 000000000..ed01cb1e8 --- /dev/null +++ b/tools/proto_to_predicate.h @@ -0,0 +1,48 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_TOOLS_PROTO_TO_PREDICATE_H_ +#define THIRD_PARTY_CEL_CPP_TOOLS_PROTO_TO_PREDICATE_H_ + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/ast.h" +#include "google/protobuf/message.h" + +namespace cel::tools { + +// Translates a Protocol Buffer message into a CEL AST representing a predicate. +// +// NOTE: The protocol message schemas used for policy definition should use +// `proto2` or `editions` (and not `proto3` implicit presence) to ensure correct +// behavior, as this library relies on field presence (via reflection) to +// identify which fields are explicitly set by the policy. +absl::StatusOr ProtoToPredicateAst(absl::string_view input_name, + const ::google::protobuf::Message& message); + +// Translates a list of Protocol Buffer messages into a CEL AST representing a +// conjoined or alternate predicate. +// +// NOTE: The protocol message schemas used for policy definition should use +// `proto2` or `editions` (and not `proto3` implicit presence) to ensure correct +// behavior, as this library relies on field presence (via reflection) to +// identify which fields are explicitly set by the policy. +absl::StatusOr ProtoToPredicateAst( + absl::string_view input_name, + absl::Span messages); + +} // namespace cel::tools + +#endif // THIRD_PARTY_CEL_CPP_TOOLS_PROTO_TO_PREDICATE_H_ diff --git a/tools/proto_to_predicate_test.cc b/tools/proto_to_predicate_test.cc new file mode 100644 index 000000000..80ad140c7 --- /dev/null +++ b/tools/proto_to_predicate_test.cc @@ -0,0 +1,593 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tools/proto_to_predicate.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/ast.h" +#include "common/ast_proto.h" +#include "common/value.h" +#include "env/config.h" +#include "env/env_runtime.h" +#include "env/env_yaml.h" +#include "env/runtime_std_extensions.h" +#include "eval/testutil/test_message.pb.h" +#include "extensions/protobuf/value.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "parser/parser.h" +#include "runtime/activation.h" +#include "runtime/runtime.h" +#include "tools/cel_unparser.h" +#include "tools/testdata/test_policy.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/json/json.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" + +namespace cel::tools { +namespace { + +using ::absl_testing::IsOk; +using ::google::api::expr::runtime::TestMessage; + +constexpr absl::string_view kEnvYaml = R"( +name: "test" +extensions: + - name: "bindings" + - name: "optional" +variables: + - name: "input" + type: "google.api.expr.runtime.TestMessage" +)"; + +TestMessage ParseTestMessage(absl::string_view textproto) { + TestMessage msg; + google::protobuf::TextFormat::ParseFromString(textproto, &msg); + return msg; +} + +absl::StatusOr EvaluatePredicate(const cel::Ast& ast, + const TestMessage& input) { + auto descriptor_pool = cel::internal::GetSharedTestingDescriptorPool(); + + CEL_ASSIGN_OR_RETURN(cel::Config config, + cel::EnvConfigFromYaml(std::string(kEnvYaml))); + + cel::EnvRuntime env_runtime; + env_runtime.SetDescriptorPool(descriptor_pool); + cel::RegisterStandardExtensions(env_runtime); + env_runtime.SetConfig(config); + + CEL_ASSIGN_OR_RETURN(std::unique_ptr runtime, + env_runtime.NewRuntime()); + CEL_ASSIGN_OR_RETURN(std::unique_ptr program, + runtime->CreateProgram(std::make_unique(ast))); + + google::protobuf::Arena arena; + cel::Activation activation; + CEL_ASSIGN_OR_RETURN( + cel::Value val, cel::extensions::ProtoMessageToValue( + input, descriptor_pool.get(), + google::protobuf::MessageFactory::generated_factory(), &arena)); + activation.InsertOrAssignValue("input", val); + + CEL_ASSIGN_OR_RETURN(cel::Value result, + program->Evaluate(&arena, activation)); + if (!result.IsBool()) { + return absl::InvalidArgumentError( + "Predicate evaluate result must be a boolean value."); + } + return result.GetBool(); +} + +struct TestCase { + std::string name; + std::vector input_textprotos; + std::string expected_unparsed; + std::string eval_textproto; + bool expected_eval_result = true; + // If true, skip the eval step of the test. This is useful for tests where + // the expected expression does not share the same type structure as the + // input proto, such as empty messages. + bool skip_eval = false; +}; + +class ProtoToPredicateTest : public ::testing::TestWithParam {}; + +TEST_P(ProtoToPredicateTest, ConformanceTests) { + const TestCase& param = GetParam(); + + std::vector input_messages; + input_messages.reserve(param.input_textprotos.size()); + for (const auto& proto_str : param.input_textprotos) { + input_messages.push_back(ParseTestMessage(proto_str)); + } + + std::vector ptr_messages; + ptr_messages.reserve(input_messages.size()); + for (const auto& msg : input_messages) { + ptr_messages.push_back(&msg); + } + + absl::StatusOr ast_or; + if (input_messages.size() == 1) { + ast_or = ProtoToPredicateAst("input", input_messages[0]); + } else { + ast_or = ProtoToPredicateAst("input", absl::MakeSpan(ptr_messages)); + } + + ASSERT_THAT(ast_or, IsOk()); + cel::Ast ast = std::move(*ast_or); + + cel::expr::ParsedExpr parsed_expr; + ASSERT_THAT(cel::AstToParsedExpr(ast, &parsed_expr), IsOk()); + ASSERT_OK_AND_ASSIGN(auto unparsed, google::api::expr::Unparse(parsed_expr)); + + EXPECT_EQ(unparsed, param.expected_unparsed); + + if (!param.skip_eval) { + TestMessage eval_msg = ParseTestMessage(param.eval_textproto); + ASSERT_OK_AND_ASSIGN(bool eval_result, EvaluatePredicate(ast, eval_msg)); + EXPECT_EQ(eval_result, param.expected_eval_result); + } +} + +INSTANTIATE_TEST_SUITE_P( + ProtoToPredicateSubCases, ProtoToPredicateTest, + testing::Values( + TestCase{ + .name = "EmptyMessageTest", + .input_textprotos = {""}, + .expected_unparsed = "true", + .eval_textproto = "", + }, + TestCase{ + .name = "EmptyMessagesListTest", + .input_textprotos = {}, + .expected_unparsed = "true", + .eval_textproto = "", + }, + TestCase{ + .name = "PrimitivesTest", + .input_textprotos = {R"pb( + int32_value: 42 string_value: "hello" + )pb"}, + .expected_unparsed = + "input.int32_value == 42 && input.string_value == \"hello\"", + .eval_textproto = R"pb( + int32_value: 42 string_value: "hello" + )pb", + }, + TestCase{ + .name = "AllPrimitivesTest", + .input_textprotos = {R"pb( + int32_value: 42 + int64_value: 43 + uint32_value: 44 + uint64_value: 45 + float_value: 46.5 + double_value: 47.5 + bool_value: true + enum_value: TEST_ENUM_1 + string_value: "hello" + bytes_value: "world" + )pb"}, + .expected_unparsed = + "input.int32_value == 42 && input.int64_value == 43 && " + "input.uint32_value == 44u && input.uint64_value == 45u && " + "input.float_value == 46.5 && input.double_value == 47.5 && " + "input.string_value == \"hello\" && " + "input.bytes_value == b\"world\" && " + "input.bool_value == true && " + "input.enum_value == 1", + .eval_textproto = R"pb( + int32_value: 42 + int64_value: 43 + uint32_value: 44 + uint64_value: 45 + float_value: 46.5 + double_value: 47.5 + bool_value: true + enum_value: TEST_ENUM_1 + string_value: "hello" + bytes_value: "world" + )pb", + }, + TestCase{ + .name = "NestedMessageTest", + .input_textprotos = {R"pb( + message_value: { int32_value: 42 } + )pb"}, + .expected_unparsed = "input.message_value.int32_value == 42", + .eval_textproto = R"pb( + message_value: { int32_value: 42 } + )pb", + }, + TestCase{ + .name = "RepeatedFieldTest", + .input_textprotos = {R"pb( + int32_list: [ 1, 2 ] + )pb"}, + .expected_unparsed = + "1 in input.int32_list && 2 in input.int32_list", + .eval_textproto = R"pb( + int32_list: [ 1, 2 ] + )pb", + }, + TestCase{ + .name = "RepeatedFieldSingleElementTest", + .input_textprotos = {R"pb( + int32_list: [ 42 ] + )pb"}, + .expected_unparsed = "42 in input.int32_list", + .eval_textproto = R"pb( + int32_list: [ 42 ] + )pb", + }, + TestCase{ + .name = "RepeatedFieldEmptyTest", + .input_textprotos = {R"pb( + int32_list: [] + )pb"}, + .expected_unparsed = "true", + .eval_textproto = R"pb( + int32_list: [] + )pb", + }, + TestCase{ + .name = "ListFieldEvalNegative", + .input_textprotos = {R"pb( + int32_list: [ 1, 2 ] + )pb"}, + .expected_unparsed = + "1 in input.int32_list && 2 in input.int32_list", + .eval_textproto = R"pb( + int32_list: [ 1, 3 ] + )pb", + .expected_eval_result = false, + }, + TestCase{ + .name = "SingleRepeatedFieldAllPrimitivesTest", + .input_textprotos = {R"pb( + int32_list: [ 42 ] + int64_list: [ 43 ] + uint32_list: [ 44 ] + uint64_list: [ 45 ] + float_list: [ 46.5 ] + double_list: [ 47.5 ] + bool_list: [ true ] + enum_list: [ TEST_ENUM_1 ] + string_list: [ "hello" ] + bytes_list: [ "world" ] + )pb"}, + .expected_unparsed = "42 in input.int32_list && " + "43 in input.int64_list && " + "44u in input.uint32_list && " + "45u in input.uint64_list && " + "46.5 in input.float_list && " + "47.5 in input.double_list && " + "\"hello\" in input.string_list && " + "b\"world\" in input.bytes_list && " + "true in input.bool_list && " + "1 in input.enum_list", + .eval_textproto = R"pb( + int32_list: [ 42 ] + int64_list: [ 43 ] + uint32_list: [ 44 ] + uint64_list: [ 45 ] + float_list: [ 46.5 ] + double_list: [ 47.5 ] + bool_list: [ true ] + enum_list: [ TEST_ENUM_1 ] + string_list: [ "hello" ] + bytes_list: [ "world" ] + )pb", + }, + TestCase{ + .name = "MultipleRepeatedFieldAllPrimitivesTest", + .input_textprotos = {R"pb( + int32_list: [ 42, 142 ] + int64_list: [ 43, 143 ] + uint32_list: [ 44, 144 ] + uint64_list: [ 45, 145 ] + float_list: [ 46.5, 146.5 ] + double_list: [ 47.5, 147.5 ] + bool_list: [ true, false ] + enum_list: [ TEST_ENUM_1, TEST_ENUM_2 ] + string_list: [ "hello", "universe" ] + bytes_list: [ "world", "space" ] + )pb"}, + .expected_unparsed = + "42 in input.int32_list && 142 in input.int32_list && " + "43 in input.int64_list && 143 in input.int64_list && " + "44u in input.uint32_list && 144u in input.uint32_list && " + "45u in input.uint64_list && 145u in input.uint64_list && " + "46.5 in input.float_list && 146.5 in input.float_list && " + "47.5 in input.double_list && 147.5 in input.double_list && " + "\"hello\" in input.string_list && \"universe\" in " + "input.string_list && " + "b\"world\" in input.bytes_list && b\"space\" in " + "input.bytes_list && " + "true in input.bool_list && false in input.bool_list && " + "1 in input.enum_list && 2 in input.enum_list", + .eval_textproto = R"pb( + int32_list: [ 42, 142 ] + int64_list: [ 43, 143 ] + uint32_list: [ 44, 144 ] + uint64_list: [ 45, 145 ] + float_list: [ 46.5, 146.5 ] + double_list: [ 47.5, 147.5 ] + bool_list: [ true, false ] + enum_list: [ TEST_ENUM_1, TEST_ENUM_2 ] + string_list: [ "hello", "universe" ] + bytes_list: [ "world", "space" ] + )pb", + }, + TestCase{ + .name = "MapFieldTest", + .input_textprotos = {R"pb( + string_int32_map: { key: "foo" value: 1 } + string_int32_map: { key: "bar" value: 2 } + )pb"}, + .expected_unparsed = "\"bar\" in input.string_int32_map && " + "input.string_int32_map[\"bar\"] == 2 && " + "\"foo\" in input.string_int32_map && " + "input.string_int32_map[\"foo\"] == 1", + .eval_textproto = R"pb( + string_int32_map: { key: "foo" value: 1 } + string_int32_map: { key: "bar" value: 2 } + )pb", + }, + TestCase{ + .name = "MapFieldEvalNegativeVal", + .input_textprotos = {R"pb( + string_int32_map: { key: "foo" value: 1 } + string_int32_map: { key: "bar" value: 2 } + )pb"}, + .expected_unparsed = "\"bar\" in input.string_int32_map && " + "input.string_int32_map[\"bar\"] == 2 && " + "\"foo\" in input.string_int32_map && " + "input.string_int32_map[\"foo\"] == 1", + .eval_textproto = R"pb( + string_int32_map: { key: "foo" value: 1 } + string_int32_map: { key: "bar" value: 3 } + )pb", + .expected_eval_result = false, + }, + TestCase{ + .name = "MapFieldEvalNegativeNoKey", + .input_textprotos = {R"pb( + string_int32_map: { key: "foo" value: 1 } + string_int32_map: { key: "bar" value: 2 } + )pb"}, + .expected_unparsed = "\"bar\" in input.string_int32_map && " + "input.string_int32_map[\"bar\"] == 2 && " + "\"foo\" in input.string_int32_map && " + "input.string_int32_map[\"foo\"] == 1", + .eval_textproto = R"pb( + string_int32_map: { key: "foo" value: 1 } + )pb", + .expected_eval_result = false, + }, + TestCase{ + .name = "MapFieldIntKeySortingTest", + .input_textprotos = {R"pb( + int32_int32_map: { key: 10 value: 100 } + int32_int32_map: { key: 5 value: 50 } + int32_int32_map: { key: 8 value: 80 } + )pb"}, + .expected_unparsed = "5 in input.int32_int32_map && " + "input.int32_int32_map[5] == 50 && " + "8 in input.int32_int32_map && " + "input.int32_int32_map[8] == 80 && " + "10 in input.int32_int32_map && " + "input.int32_int32_map[10] == 100", + .eval_textproto = R"pb( + int32_int32_map: { key: 5 value: 50 } + int32_int32_map: { key: 8 value: 80 } + int32_int32_map: { key: 10 value: 100 } + )pb", + }, + TestCase{ + .name = "MultipleMessagesTest", + .input_textprotos = {R"pb( + int32_value: 42 + )pb", + R"pb( + int32_value: 41 string_value: "hello" + )pb"}, + .expected_unparsed = + "input.int32_value == 42 || input.int32_value == 41 && " + "input.string_value == \"hello\"", + .eval_textproto = R"pb( + int32_value: 41 string_value: "hello" + )pb", + }, + TestCase{ + .name = "RepeatedMessageFieldTest", + .input_textprotos = {R"pb( + message_list: + [ { int32_value: 42 } + , { int32_value: 43 }] + )pb"}, + .expected_unparsed = "input.message_list.int32_value == 42 || " + "input.message_list.int32_value == 43", + .skip_eval = true, + }, + TestCase{ + .name = "RepeatedMessageSingleElementTest", + .input_textprotos = {R"pb( + message_list: + [ { int32_value: 42 }] + )pb"}, + .expected_unparsed = "input.message_list.int32_value == 42", + .skip_eval = true, + })); + +struct PolicyTestCase { + std::string name; + std::string json_input; + std::string expected_unparsed; +}; + +class PolicyJsonTest : public ::testing::TestWithParam {}; + +TEST_P(PolicyJsonTest, Conformance) { + const PolicyTestCase& param = GetParam(); + + cel::cpp::tools::Policy policy; + google::protobuf::json::ParseOptions options; + options.ignore_unknown_fields = true; + auto status = + google::protobuf::json::JsonStringToMessage(param.json_input, &policy, options); + ASSERT_THAT(status, IsOk()) << "Failed to parse JSON: " << param.json_input; + + absl::StatusOr ast_or; + std::vector ptr_messages; + ptr_messages.reserve(policy.destinations_size()); + for (const auto& dest : policy.destinations()) { + ptr_messages.push_back(&dest); + } + + if (ptr_messages.empty()) { + auto parsed_expr_or = google::api::expr::parser::Parse("false"); + ASSERT_THAT(parsed_expr_or, IsOk()); + auto ast_ptr_or = cel::CreateAstFromParsedExpr(*parsed_expr_or); + ASSERT_THAT(ast_ptr_or, IsOk()); + ast_or = std::move(**ast_ptr_or); + } else if (ptr_messages.size() == 1) { + ast_or = ProtoToPredicateAst("dest", *ptr_messages[0]); + } else { + ast_or = ProtoToPredicateAst("dest", absl::MakeSpan(ptr_messages)); + } + + ASSERT_THAT(ast_or, IsOk()); + cel::Ast ast = std::move(*ast_or); + + cel::expr::ParsedExpr parsed_expr; + ASSERT_THAT(cel::AstToParsedExpr(ast, &parsed_expr), IsOk()); + ASSERT_OK_AND_ASSIGN(auto unparsed, google::api::expr::Unparse(parsed_expr)); + + EXPECT_EQ(unparsed, param.expected_unparsed); +} + +INSTANTIATE_TEST_SUITE_P( + PolicyJsonSubCases, PolicyJsonTest, + testing::Values( + PolicyTestCase{ + .name = "SimpleMatch", + .json_input = + R"({ "destinations": [ { "agent": { "id": "agent-007" } } ] })", + .expected_unparsed = "dest.agent.name == \"agent-007\"", + }, + PolicyTestCase{ + .name = "MultipleFields", + .json_input = + R"({ "destinations": [ { + "tool": { + "name": "admin_tool", + "annotations": { + "read_only_hint": false + } + } + } + ] })", + .expected_unparsed = + "dest.tool.name == \"admin_tool\" && " + "dest.tool.annotations.read_only_hint == false", + }, + PolicyTestCase{ + .name = "RepeatedMessages", + .json_input = + R"({ "destinations": [ + { "agent": { "id": "worker-1" } }, + { "agent": { "id": "worker-2" } }, + ] })", + .expected_unparsed = "dest.agent.name == \"worker-1\" || " + "dest.agent.name == \"worker-2\"", + }, + PolicyTestCase{ + .name = "RepeatedPrimitiveArraySingleElement", + .json_input = + R"({ "destinations": [ { + "tool": { + "role_members": { + "admin": { + "principals": ["alice"] + } + } + } + } ] })", + .expected_unparsed = + "\"admin\" in dest.tool.role_members && " + "\"alice\" in dest.tool.role_members[\"admin\"].principals", + }, + PolicyTestCase{ + .name = "RepeatedArrayEmpty", + .json_input = R"({ "destinations": [ { "tool": { } } ] })", + .expected_unparsed = "true", + }, + PolicyTestCase{ + .name = "MapEquality", + .json_input = + R"({ "destinations": [ + { "tool": { + "name": "shell", + "labels": { + "cluster": "us-central1", + "project": "dev" + } + } + } ] })", + .expected_unparsed = + "dest.tool.name == \"shell\" && \"cluster\" in " + "dest.tool.labels && dest.tool.labels[\"cluster\"] == " + "\"us-central1\" && \"project\" in dest.tool.labels && " + "dest.tool.labels[\"project\"] == \"dev\"", + }, + PolicyTestCase{ + .name = "NestedMapEquality", + .json_input = + R"({ "destinations": [ + { "tool": { + "role_members": { + "admin": { + "all_users": true + } + } + } } + ] })", + .expected_unparsed = + "\"admin\" in dest.tool.role_members && " + "dest.tool.role_members[\"admin\"].all_users == true", + }, + PolicyTestCase{ + .name = "EmptyPolicy", + .json_input = "{}", + .expected_unparsed = "false", + })); + +} // namespace +} // namespace cel::tools diff --git a/tools/testdata/BUILD b/tools/testdata/BUILD index 493f0ff2f..c88c9c478 100644 --- a/tools/testdata/BUILD +++ b/tools/testdata/BUILD @@ -12,10 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -load( - "@com_github_google_flatbuffers//:build_defs.bzl", - "flatbuffer_library_public", -) +load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_library_public") +load("@com_google_protobuf//bazel:cc_proto_library.bzl", "cc_proto_library") +load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library") load("@rules_cc//cc:cc_library.bzl", "cc_library") licenses(["notice"]) @@ -46,3 +45,15 @@ cc_library( linkstatic = True, deps = ["@com_github_google_flatbuffers//:runtime_cc"], ) + +proto_library( + name = "test_policy_proto", + srcs = ["test_policy.proto"], + visibility = ["//tools:__subpackages__"], +) + +cc_proto_library( + name = "test_policy_cc_proto", + visibility = ["//tools:__subpackages__"], + deps = [":test_policy_proto"], +) diff --git a/tools/testdata/test_policy.proto b/tools/testdata/test_policy.proto new file mode 100644 index 000000000..b5d424c04 --- /dev/null +++ b/tools/testdata/test_policy.proto @@ -0,0 +1,73 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Test schema representing client-configured policies. +// It is used by the `proto_to_predicate` tool to translate Protobuf policies +// into CEL predicates. +edition = "2023"; + +package cel.cpp.tools; + +option cc_enable_arenas = true; + +// Represents the targeted client agent. +message Agent { + string name = 1 [json_name = "id"]; +} + +// Specifies additional metadata tool annotations. +message ToolAnnotations { + bool read_only_hint = 1; +} + +// Represents a mapped nested message entry value inside map fields. +message Members { + repeated string principals = 1; + + repeated string regions = 2; + + bool all_users = 3; + + bool all_authenticated_users = 4; +} + +// Represents a metadata tool block. +message Tool { + // The name of the tool. + string name = 1; + + // Additional metadata annotations for the tool. + ToolAnnotations annotations = 2; + + // A string-to-string map, transpiled as conjoined existence and equality + // checks. + map labels = 3; + + // A map with string keys representing roles and Member instances as values. + map role_members = 4; +} + +// Represents a policy mapping destination block. +message Target { + oneof kind { + Agent agent = 1; + Tool tool = 2; + } +} + +// Represents the top-level policy containing multiple alternate destination +// rules. +message Policy { + repeated Target destinations = 1; +} From 9597d49ef616726bdc8b2a553fd89fd46ad8b493 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 15 Jun 2026 11:43:00 -0700 Subject: [PATCH 63/87] Add support for context types env.yaml Add support for declaring a protobuf context message whose top level fields are declared as variables in the CEL environment. PiperOrigin-RevId: 932580314 --- env/BUILD | 1 + env/config.h | 6 ++++++ env/env.cc | 5 +++++ env/env_test.cc | 19 +++++++++++++++++++ env/env_yaml.cc | 30 ++++++++++++++++++++++++++++++ env/env_yaml_test.cc | 41 +++++++++++++++++++++++++++++++++++++++++ 6 files changed, 102 insertions(+) diff --git a/env/BUILD b/env/BUILD index 3035e11ac..bd82e8ec6 100644 --- a/env/BUILD +++ b/env/BUILD @@ -130,6 +130,7 @@ cc_library( "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/env/config.h b/env/config.h index e427832ff..68e4a1dd9 100644 --- a/env/config.h +++ b/env/config.h @@ -32,6 +32,11 @@ class Config { void SetName(std::string name) { name_ = std::move(name); } std::string GetName() const { return name_; } + void SetContextType(std::string context_type) { + context_type_ = std::move(context_type); + } + std::string GetContextType() const { return context_type_; } + struct ContainerConfig { std::string name; std::vector abbreviations; @@ -150,6 +155,7 @@ class Config { private: std::string name_; + std::string context_type_; ContainerConfig container_config_; std::vector extension_configs_; StandardLibraryConfig standard_library_config_; diff --git a/env/env.cc b/env/env.cc index 6cd3a3cdc..22d24295e 100644 --- a/env/env.cc +++ b/env/env.cc @@ -138,6 +138,11 @@ absl::StatusOr> Env::NewCompilerBuilder() { for (const auto& abbr : config_.GetContainerConfig().abbreviations) { CEL_RETURN_IF_ERROR(container.AddAbbreviation(abbr)); } + + if (!config_.GetContextType().empty()) { + CEL_RETURN_IF_ERROR( + checker_builder.AddContextDeclaration(config_.GetContextType())); + } for (const auto& alias : config_.GetContainerConfig().aliases) { CEL_RETURN_IF_ERROR(container.AddAlias(alias.alias, alias.qualified_name)); } diff --git a/env/env_test.cc b/env/env_test.cc index b599aa569..fda87dfab 100644 --- a/env/env_test.cc +++ b/env/env_test.cc @@ -344,6 +344,25 @@ TEST(ContainerConfigTest, ContainerConfigWithAliases) { EXPECT_THAT(result.GetIssues(), IsEmpty()) << result.FormatError(); } +TEST(ContextVariableConfigTest, Basic) { + Env env; + env.SetDescriptorPool(internal::GetSharedTestingDescriptorPool()); + Config config; + config.SetContextType("cel.expr.conformance.proto3.TestAllTypes"); + env.SetConfig(config); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, env.NewCompiler()); + + // Top-level fields of TestAllTypes like "single_int32" should resolve + // successfully. + ASSERT_OK_AND_ASSIGN(auto result, compiler->Compile("single_int32 > 10")); + EXPECT_THAT(result.GetIssues(), IsEmpty()); + + ASSERT_OK_AND_ASSIGN(auto result_invalid, + compiler->Compile("non_existent_field > 10")); + EXPECT_THAT(result_invalid.GetIssues(), Not(IsEmpty())); +} + struct VariableConfigWithValueTestCase { Config::VariableConfig variable_config; std::string validate_type_expr; diff --git a/env/env_yaml.cc b/env/env_yaml.cc index 1bbfe6b36..e7b8a7885 100644 --- a/env/env_yaml.cc +++ b/env/env_yaml.cc @@ -26,6 +26,7 @@ #include "absl/base/no_destructor.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/escaping.h" @@ -1245,6 +1246,34 @@ void EmitFunctionConfigs(const Config& env_config, YAML::Emitter& out, } out << YAML::EndSeq; } + +absl::Status ParseContextVariableConfig(Config& config, absl::string_view yaml, + const YAML::Node& root) { + const YAML::Node context_variable = root["context_variable"]; + if (!context_variable.IsDefined()) { + return absl::OkStatus(); + } + if (!context_variable.IsMap()) { + return YamlError(yaml, context_variable, + "Node 'context_variable' is not a map"); + } + + const YAML::Node type_name = context_variable["type_name"]; + const YAML::Node type = context_variable["type"]; + const YAML::Node* type_node = nullptr; + if (type.IsDefined() && type.IsScalar()) { + type_node = &type; + } else if (type_name.IsDefined() && type_name.IsScalar()) { + type_node = &type_name; + } else { + return YamlError(yaml, context_variable, + "Node 'context_variable' does not have a valid type"); + } + ABSL_DCHECK(type_node != nullptr); + config.SetContextType(GetString(yaml, *type_node)); + return absl::OkStatus(); +} + } // namespace absl::StatusOr EnvConfigFromYaml(const std::string& yaml) { @@ -1263,6 +1292,7 @@ absl::StatusOr EnvConfigFromYaml(const std::string& yaml) { CEL_RETURN_IF_ERROR(ParseContainerConfig(config, yaml, root)); CEL_RETURN_IF_ERROR(ParseExtensionConfigs(config, yaml, root)); CEL_RETURN_IF_ERROR(ParseStandardLibraryConfig(config, yaml, root)); + CEL_RETURN_IF_ERROR(ParseContextVariableConfig(config, yaml, root)); CEL_RETURN_IF_ERROR(ParseVariableConfigs(config, yaml, root)); CEL_RETURN_IF_ERROR(ParseFunctionConfigs(config, yaml, root)); return config; diff --git a/env/env_yaml_test.cc b/env/env_yaml_test.cc index a60048617..9c5b3f04f 100644 --- a/env/env_yaml_test.cc +++ b/env/env_yaml_test.cc @@ -216,6 +216,47 @@ TEST(EnvYamlTest, ParseVariableConfigWithTypeParams) { EXPECT_THAT(type_info.params[1].params, IsEmpty()); } +TEST(EnvYamlTest, ParseContextVariableConfig) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + context_variable: + type_name: "cel.expr.conformance.proto3.TestAllTypes" + )yaml")); + + EXPECT_EQ(config.GetContextType(), + "cel.expr.conformance.proto3.TestAllTypes"); +} + +TEST(EnvYamlTest, ParseContextVariableConfigAlternativeSyntax) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + context_variable: + type: "cel.expr.conformance.proto3.TestAllTypes" + )yaml")); + + EXPECT_EQ(config.GetContextType(), + "cel.expr.conformance.proto3.TestAllTypes"); +} + +TEST(EnvYamlTest, ParseContextVariableMalformedContextVariable) { + EXPECT_THAT(EnvConfigFromYaml(R"yaml( + context_variable: 123 + + )yaml"), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Node 'context_variable' is not a map"))); +} + +TEST(EnvYamlTest, ParseContextVariableMalformedContextVariable2) { + EXPECT_THAT( + EnvConfigFromYaml(R"yaml( + context_variable: + type: + foo: bar + )yaml"), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr("Node 'context_variable' does not have a valid type"))); +} + TEST(EnvYamlTest, ParseVariableConfigWithTypeParamsLegacySyntax) { ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( variables: From 01e6b6379d8100f6f6dcb6e41627037d1b2a7112 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 16 Jun 2026 06:44:51 -0700 Subject: [PATCH 64/87] Export policy compiler. Co-authored-by: Dmitri Plotnikov PiperOrigin-RevId: 933062225 --- MODULE.bazel | 14 + conformance/policy/BUILD | 78 ++ .../policy/policy_conformance_test.bzl | 46 + conformance/policy/policy_conformance_test.cc | 659 ++++++++++ internal/BUILD | 1 + internal/runfiles.cc | 15 +- internal/runfiles.h | 6 + policy/BUILD | 239 ++++ policy/cel_policy.cc | 273 +++++ policy/cel_policy.h | 320 +++++ policy/cel_policy_parse_context.cc | 49 + policy/cel_policy_parse_context.h | 65 + policy/cel_policy_parse_result.cc | 91 ++ policy/cel_policy_parse_result.h | 105 ++ policy/cel_policy_parser.h | 40 + policy/cel_policy_test.cc | 220 ++++ policy/cel_policy_validation_result.cc | 32 + policy/cel_policy_validation_result.h | 84 ++ policy/compiler.cc | 1058 +++++++++++++++++ policy/compiler.h | 50 + policy/compiler_test.cc | 946 +++++++++++++++ policy/internal/BUILD | 68 ++ policy/internal/issue_reporter.cc | 45 + policy/internal/issue_reporter.h | 57 + policy/internal/optimizer_expr_factory.cc | 373 ++++++ policy/internal/optimizer_expr_factory.h | 419 +++++++ .../internal/optimizer_expr_factory_test.cc | 570 +++++++++ policy/test_custom_yaml_policy_parser.cc | 188 +++ policy/test_util.cc | 221 ++++ policy/test_util.h | 33 + policy/testdata/BUILD | 19 + policy/testdata/cel_policy.yaml | 42 + policy/testdata/cel_policy_parser.baseline | 89 ++ policy/testdata/custom_policy_format.yaml | 29 + .../custom_policy_format_parser.baseline | 75 ++ .../custom_policy_format_with_errors.yaml | 33 + ..._policy_format_with_errors_parser.baseline | 16 + policy/testdata/nested_rule.yaml | 37 + policy/testdata/nested_rule_parser.baseline | 84 ++ policy/yaml_policy_parser.cc | 411 +++++++ policy/yaml_policy_parser.h | 135 +++ policy/yaml_policy_parser_test.cc | 305 +++++ 42 files changed, 7639 insertions(+), 1 deletion(-) create mode 100644 conformance/policy/BUILD create mode 100644 conformance/policy/policy_conformance_test.bzl create mode 100644 conformance/policy/policy_conformance_test.cc create mode 100644 policy/BUILD create mode 100644 policy/cel_policy.cc create mode 100644 policy/cel_policy.h create mode 100644 policy/cel_policy_parse_context.cc create mode 100644 policy/cel_policy_parse_context.h create mode 100644 policy/cel_policy_parse_result.cc create mode 100644 policy/cel_policy_parse_result.h create mode 100644 policy/cel_policy_parser.h create mode 100644 policy/cel_policy_test.cc create mode 100644 policy/cel_policy_validation_result.cc create mode 100644 policy/cel_policy_validation_result.h create mode 100644 policy/compiler.cc create mode 100644 policy/compiler.h create mode 100644 policy/compiler_test.cc create mode 100644 policy/internal/BUILD create mode 100644 policy/internal/issue_reporter.cc create mode 100644 policy/internal/issue_reporter.h create mode 100644 policy/internal/optimizer_expr_factory.cc create mode 100644 policy/internal/optimizer_expr_factory.h create mode 100644 policy/internal/optimizer_expr_factory_test.cc create mode 100644 policy/test_custom_yaml_policy_parser.cc create mode 100644 policy/test_util.cc create mode 100644 policy/test_util.h create mode 100644 policy/testdata/BUILD create mode 100644 policy/testdata/cel_policy.yaml create mode 100644 policy/testdata/cel_policy_parser.baseline create mode 100644 policy/testdata/custom_policy_format.yaml create mode 100644 policy/testdata/custom_policy_format_parser.baseline create mode 100644 policy/testdata/custom_policy_format_with_errors.yaml create mode 100644 policy/testdata/custom_policy_format_with_errors_parser.baseline create mode 100644 policy/testdata/nested_rule.yaml create mode 100644 policy/testdata/nested_rule_parser.baseline create mode 100644 policy/yaml_policy_parser.cc create mode 100644 policy/yaml_policy_parser.h create mode 100644 policy/yaml_policy_parser_test.cc diff --git a/MODULE.bazel b/MODULE.bazel index 43d0485d2..187d68164 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -31,6 +31,7 @@ bazel_dep( name = "rules_python", version = "1.6.3", ) +bazel_dep(name = "rules_license", version = "1.0.0") bazel_dep( name = "protobuf", version = "34.1", @@ -96,3 +97,16 @@ bazel_dep( name = "yaml-cpp", version = "0.9.0", ) + +_CEL_POLICY_TAG = "ebfb2361f47080af643c14cf4da4c2b551a68740" + +_CEL_POLICY_SHA = "ea69e9c6b7bd5bc37d358148aebd2fcca38bc7c45a23feb635de72338e0327c1" + +http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +http_archive( + name = "cel_policy", + sha256 = _CEL_POLICY_SHA, + strip_prefix = "cel-policy-%s" % _CEL_POLICY_TAG, + url = "https://github.com/cel-expr/cel-policy/archive/%s.tar.gz" % _CEL_POLICY_TAG, +) diff --git a/conformance/policy/BUILD b/conformance/policy/BUILD new file mode 100644 index 000000000..29210e02d --- /dev/null +++ b/conformance/policy/BUILD @@ -0,0 +1,78 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load( + "//conformance/policy:policy_conformance_test.bzl", + "cel_policy_conformance_test", +) + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "policy_conformance_test_lib", + testonly = True, + srcs = ["policy_conformance_test.cc"], + deps = [ + "//common:ast", + "//common:source", + "//common:value", + "//common/internal:value_conversion", + "//compiler", + "//env", + "//env:config", + "//env:env_runtime", + "//env:env_std_extensions", + "//env:env_yaml", + "//env:runtime_std_extensions", + "//extensions/protobuf:bind_proto_to_activation", + "//extensions/protobuf:enum_adapter", + "//internal:runfiles", + "//internal:status_macros", + "//internal:testing_descriptor_pool", + "//internal:testing_no_main", + "//policy:cel_policy", + "//policy:cel_policy_parser", + "//policy:cel_policy_validation_result", + "//policy:compiler", + "//policy:test_util", + "//policy:yaml_policy_parser", + "//runtime", + "//runtime:activation", + "//runtime:function_adapter", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr:value_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/test:suite_cc_proto", + "@com_google_protobuf//:protobuf", + ], + alwayslink = True, +) + +cel_policy_conformance_test( + name = "policy_conformance_test", + example = "@cel_policy//conformance:testdata/nested_rule/policy.yaml", + skip_tests = [ + # TODO(b/506179116): Fix these. + # Need to add k8s custom yaml parser and mock runtime. + "k8s", + ], + test_files = [ + "@cel_policy//conformance:testdata", + ], +) diff --git a/conformance/policy/policy_conformance_test.bzl b/conformance/policy/policy_conformance_test.bzl new file mode 100644 index 000000000..0b4d1a4c6 --- /dev/null +++ b/conformance/policy/policy_conformance_test.bzl @@ -0,0 +1,46 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module contains build rules for generating policy conformance test targets. +""" + +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +def cel_policy_conformance_test(name, test_files, example, skip_tests = [], **kwargs): + """Generates a policy conformance test target. + + Args: + name: Name of the test target. + test_files: List of targets or files representing the test data. + example: A specific example file from test_files used for runfiles resolution. + skip_tests: List of test cases to skip. + testdata_dir: Path to testdata directory under runfiles. + **kwargs: Additional arguments passed to the underlying cc_test. + """ + args = ["--gunit_fail_if_no_test_linked"] + args.append("--testdata_example='$(rlocationpath {})'".format(example)) + + if skip_tests: + args.append("--skip_tests=" + ",".join(skip_tests)) + + cc_test( + name = name, + data = test_files + [example], + deps = [ + "//conformance/policy:policy_conformance_test_lib", + ], + args = args, + **kwargs + ) diff --git a/conformance/policy/policy_conformance_test.cc b/conformance/policy/policy_conformance_test.cc new file mode 100644 index 000000000..0d68f8abf --- /dev/null +++ b/conformance/policy/policy_conformance_test.cc @@ -0,0 +1,659 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +// NOLINTNEXTLINE(build/c++17) for OSS compatibility +#include + +#include "cel/expr/eval.pb.h" +#include "absl/flags/flag.h" +#include "absl/log/absl_check.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "common/internal/value_conversion.h" +#include "common/source.h" +#include "common/value.h" +#include "compiler/compiler.h" +#include "env/config.h" +#include "env/env.h" +#include "env/env_runtime.h" +#include "env/env_std_extensions.h" +#include "env/env_yaml.h" +#include "env/runtime_std_extensions.h" +#include "extensions/protobuf/bind_proto_to_activation.h" +#include "extensions/protobuf/enum_adapter.h" +#include "internal/runfiles.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" +#include "policy/cel_policy_validation_result.h" +#include "policy/compiler.h" +#include "policy/test_util.h" +#include "policy/yaml_policy_parser.h" +#include "runtime/activation.h" +#include "runtime/function_adapter.h" +#include "runtime/runtime.h" +#include "cel/expr/conformance/test/suite.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/dynamic_message.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" + +// Use a specific file to handle bazel runfiles resolution correctly. We find +// parent directory named 'testdata' to use as the root of the test cases. +ABSL_FLAG(std::string, testdata_example, "", + "Path to a specific example file."); +ABSL_FLAG(std::vector, skip_tests, {}, + "Comma-separated list of tests to skip."); + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::cel::expr::conformance::test::TestSuite; +using ::cel::internal::GetSharedTestingDescriptorPool; +using ::testing::HasSubstr; + +// Implementations for extension functions referenced in conformance tests. +cel::Value LocationCode(const cel::StringValue& ip, + const google::protobuf::DescriptorPool* pool, + google::protobuf::MessageFactory* factory, google::protobuf::Arena* arena) { + std::string ip_str = ip.ToString(); + if (ip_str == "10.0.0.1") return cel::StringValue(arena, "us"); + if (ip_str == "10.0.0.2") return cel::StringValue(arena, "de"); + return cel::StringValue(arena, "ir"); +} + +// TODO(uncreated-issue/92): This should be migrated to use the testrunner utility +// after adding support for reading the yaml specification for envs/tests. +class InputEvaluator { + public: + static absl::StatusOr> Create( + const std::shared_ptr& pool) { + cel::Env env; + env.SetDescriptorPool(pool); + cel::RegisterStandardExtensions(env); + + cel::EnvRuntime env_runtime; + env_runtime.SetDescriptorPool(pool); + cel::RegisterStandardExtensions(env_runtime); + env_runtime.mutable_runtime_options().enable_qualified_type_identifiers = + true; + + // Enable default extensions (optional, bindings) + cel::Config config; + CEL_RETURN_IF_ERROR(config.AddExtensionConfig( + "optional", cel::Config::ExtensionConfig::kLatest)); + CEL_RETURN_IF_ERROR(config.AddExtensionConfig( + "bindings", cel::Config::ExtensionConfig::kLatest)); + env.SetConfig(config); + env_runtime.SetConfig(config); + + auto compiler_builder_or = env.NewCompilerBuilder(); + CEL_ASSIGN_OR_RETURN(auto compiler_builder, std::move(compiler_builder_or)); + compiler_builder->GetParserBuilder().GetOptions().enable_optional_syntax = + true; + CEL_ASSIGN_OR_RETURN(auto compiler, compiler_builder->Build()); + + auto runtime_builder_or = env_runtime.CreateRuntimeBuilder(); + CEL_ASSIGN_OR_RETURN(auto runtime_builder, std::move(runtime_builder_or)); + + // Register conformance enums + for (const auto& enum_name : + {"cel.expr.conformance.proto2.GlobalEnum", + "cel.expr.conformance.proto3.GlobalEnum", + "cel.expr.conformance.proto2.TestAllTypes.NestedEnum", + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"}) { + auto* enum_desc = pool->FindEnumTypeByName(enum_name); + if (enum_desc != nullptr) { + CEL_RETURN_IF_ERROR(cel::extensions::RegisterProtobufEnum( + runtime_builder.type_registry(), enum_desc)); + } + } + + CEL_ASSIGN_OR_RETURN(auto runtime, std::move(runtime_builder).Build()); + + return absl::WrapUnique( + new InputEvaluator(std::move(compiler), std::move(runtime))); + } + + absl::StatusOr Evaluate( + absl::string_view expr_str, google::protobuf::Arena* arena, + google::protobuf::MessageFactory* message_factory) const { + CEL_ASSIGN_OR_RETURN(auto validation_result, compiler_->Compile(expr_str)); + if (!validation_result.IsValid()) { + return absl::InvalidArgumentError( + absl::StrCat("Failed to compile input expr: ", expr_str)); + } + CEL_ASSIGN_OR_RETURN(auto ast, validation_result.ReleaseAst()); + CEL_ASSIGN_OR_RETURN( + auto program, + runtime_->CreateProgram(std::make_unique(std::move(*ast)))); + cel::Activation activation; + EvaluateOptions options; + options.message_factory = message_factory; + return program->Evaluate(arena, activation, options); + } + + private: + InputEvaluator(std::unique_ptr compiler, + std::unique_ptr runtime) + : compiler_(std::move(compiler)), runtime_(std::move(runtime)) {} + + std::unique_ptr compiler_; + std::unique_ptr runtime_; +}; + +absl::StatusOr EvaluateInputValue( + const cel::expr::conformance::test::InputValue& input_val, + const InputEvaluator& evaluator, + const google::protobuf::DescriptorPool* descriptor_pool, + google::protobuf::MessageFactory* message_factory, google::protobuf::Arena* arena) { + if (input_val.has_expr()) { + return evaluator.Evaluate(input_val.expr(), arena, message_factory); + } + if (input_val.has_value()) { + return cel::test::FromExprValue(input_val.value(), descriptor_pool, + message_factory, arena); + } + return absl::InvalidArgumentError("Empty InputValue"); +} + +class CelValueMatcherImpl + : public testing::MatcherInterface { + public: + CelValueMatcherImpl(cel::Value expected_val, + const google::protobuf::DescriptorPool* pool, + google::protobuf::MessageFactory* message_factory, + google::protobuf::Arena* arena) + : expected_val_(std::move(expected_val)), + pool_(pool), + message_factory_(message_factory), + arena_(arena) {} + + bool MatchAndExplain(const cel::Value& actual_val, + testing::MatchResultListener* listener) const override { + cel::Value actual = actual_val; + if (actual.IsOptional() && !expected_val_.IsOptional()) { + auto opt_val = actual.AsOptional(); + if (opt_val->HasValue()) { + actual = opt_val->Value(); + } + } + cel::Value eq_result; + auto eq_status = actual.Equal(expected_val_, pool_, message_factory_, + arena_, &eq_result); + if (!eq_status.ok()) { + *listener << "equality check failed with status: " << eq_status; + return false; + } + if (!eq_result.IsTrue()) { + *listener << "expected: " << expected_val_.DebugString() + << "\nactual: " << actual.DebugString(); + return false; + } + return true; + } + + void DescribeTo(std::ostream* os) const override { + *os << "is equal to " << expected_val_.DebugString(); + } + + void DescribeNegationTo(std::ostream* os) const override { + *os << "is not equal to " << expected_val_.DebugString(); + } + + private: + cel::Value expected_val_; + const google::protobuf::DescriptorPool* pool_; + google::protobuf::MessageFactory* message_factory_; + google::protobuf::Arena* arena_; +}; + +absl::StatusOr> MakeExpectedValueMatcher( + const cel::expr::conformance::test::TestOutput& output, + const InputEvaluator& input_evaluator, const google::protobuf::DescriptorPool* pool, + google::protobuf::MessageFactory* message_factory, google::protobuf::Arena* arena) { + cel::Value expected_val; + if (output.has_result_expr()) { + CEL_ASSIGN_OR_RETURN( + expected_val, + input_evaluator.Evaluate(output.result_expr(), arena, message_factory)); + } else if (output.has_result_value()) { + CEL_ASSIGN_OR_RETURN(expected_val, + cel::test::FromExprValue(output.result_value(), pool, + message_factory, arena)); + } else { + return absl::InvalidArgumentError("Unsupported output kind"); + } + return testing::Matcher( + new CelValueMatcherImpl(expected_val, pool, message_factory, arena)); +} + +bool ShouldRunTest(absl::string_view test_name, + const std::vector& skip_tests) { + for (const std::string& skip : skip_tests) { + if (absl::StartsWith(test_name, skip)) { + return false; + } + } + return true; +} + +absl::Status PopulateActivation( + const cel::expr::conformance::test::TestCase& test, + const InputEvaluator& input_evaluator, + const google::protobuf::DescriptorPool* descriptor_pool, + google::protobuf::MessageFactory* message_factory, + absl::string_view context_msg_type_name, google::protobuf::Arena* arena, + Activation& activation) { + if (!test.has_input_context()) { + for (const auto& [var_name, input_val] : test.input()) { + CEL_ASSIGN_OR_RETURN( + auto val, + EvaluateInputValue(input_val, input_evaluator, descriptor_pool, + message_factory, arena)); + activation.InsertOrAssignValue(var_name, std::move(val)); + } + return absl::OkStatus(); + } + + const auto& input_context = test.input_context(); + const google::protobuf::Message* context_message = nullptr; + + if (input_context.has_context_message()) { + const google::protobuf::Any& any_msg = input_context.context_message(); + const google::protobuf::Descriptor* msg_descriptor = + descriptor_pool->FindMessageTypeByName(context_msg_type_name); + if (msg_descriptor == nullptr) { + return absl::NotFoundError(absl::StrCat( + "Failed to find message descriptor for: ", context_msg_type_name)); + } + const google::protobuf::Message* prototype = + message_factory->GetPrototype(msg_descriptor); + if (prototype == nullptr) { + return absl::NotFoundError( + absl::StrCat("Failed to get prototype for: ", context_msg_type_name)); + } + auto* buf = prototype->New(arena); + if (!any_msg.UnpackTo(buf)) { + return absl::InvalidArgumentError(absl::StrCat( + "Failed to unpack context message to ", context_msg_type_name)); + } + context_message = buf; + } else if (input_context.has_context_expr() && + !context_msg_type_name.empty()) { + CEL_ASSIGN_OR_RETURN(cel::Value evaluated_val, + input_evaluator.Evaluate(input_context.context_expr(), + arena, message_factory)); + + if (!evaluated_val.IsParsedMessage()) { + return absl::InvalidArgumentError( + absl::StrCat("Context expression did not evaluate to a message: ", + input_context.context_expr())); + } + if (evaluated_val.GetParsedMessage().GetDescriptor()->full_name() != + context_msg_type_name) { + return absl::InvalidArgumentError(absl::StrCat( + "Context expression evaluated to a message of type ", + evaluated_val.GetParsedMessage().GetDescriptor()->full_name(), + " which does not match the expected type ", context_msg_type_name)); + } + context_message = static_cast( + evaluated_val.GetParsedMessage().operator->()); + } + if (context_message == nullptr) { + return absl::InvalidArgumentError( + "Failed to resolve context message for test case"); + } + + return cel::extensions::BindProtoToActivation( + *context_message, + cel::extensions::BindProtoUnsetFieldBehavior::kBindDefaultValue, + descriptor_pool, message_factory, arena, &activation); +} + +class PolicyTestSuiteRunner { + public: + PolicyTestSuiteRunner(std::string suite_name, + std::unique_ptr compiler, + std::unique_ptr runtime, + std::shared_ptr policy_source, + CelPolicyValidationResult compile_result, + std::shared_ptr pool, + std::shared_ptr message_factory, + std::shared_ptr input_evaluator, + std::string context_msg_type_name, + bool expect_compile_fail = false) + : suite_name_(std::move(suite_name)), + compiler_(std::move(compiler)), + runtime_(std::move(runtime)), + policy_source_(std::move(policy_source)), + compile_result_(std::move(compile_result)), + pool_(std::move(pool)), + message_factory_(std::move(message_factory)), + input_evaluator_(std::move(input_evaluator)), + context_msg_type_name_(std::move(context_msg_type_name)), + expect_compile_fail_(expect_compile_fail) {} + + void RunTest(const cel::expr::conformance::test::TestCase& test, + absl::string_view full_test_name) { + const auto& output = test.output(); + + if (expect_compile_fail_) { + ASSERT_FALSE(compile_result_.IsValid()) + << "Expected compilation to fail in " << full_test_name; + ASSERT_TRUE(output.has_eval_error()) + << "Expected eval_error to be present in compile error test " + << full_test_name; + std::string err_msg = compile_result_.FormatIssues(); + for (const auto& expected_err : output.eval_error().errors()) { + EXPECT_THAT(err_msg, HasSubstr(expected_err.message())) + << "Did not find expected compile time error"; + } + return; + } + + // Compilation should have succeeded for evaluation tests + ASSERT_TRUE(compile_result_.IsValid()) + << "Compilation has validation errors in " << full_test_name << ": " + << compile_result_.FormatIssues(); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr program, + runtime_->CreateProgram(std::make_unique( + *compile_result_.GetAst()))); + + // Parse Inputs and evaluate them + google::protobuf::Arena arena; + Activation activation; + ASSERT_THAT(PopulateActivation(test, *input_evaluator_, pool_.get(), + message_factory_.get(), + context_msg_type_name_, &arena, activation), + IsOk()); + + // Evaluate Policy + auto eval_result_or = program->Evaluate(&arena, activation); + ASSERT_THAT(eval_result_or.status(), IsOk()) + << "Evaluation failed in " << full_test_name; + cel::Value actual_val = *eval_result_or; + + ASSERT_OK_AND_ASSIGN( + auto matcher, + MakeExpectedValueMatcher(output, *input_evaluator_, pool_.get(), + message_factory_.get(), &arena)); + + // Apply matcher to the output of evaluation + EXPECT_THAT(actual_val, matcher) << "Test failed: " << full_test_name; + } + + private: + std::string suite_name_; + std::unique_ptr compiler_; + std::unique_ptr runtime_; + std::shared_ptr policy_source_; + CelPolicyValidationResult compile_result_; + std::shared_ptr pool_; + std::shared_ptr message_factory_; + std::shared_ptr input_evaluator_; + std::string context_msg_type_name_; + bool expect_compile_fail_; +}; + +class CelPolicyTest : public testing::Test { + public: + explicit CelPolicyTest(std::shared_ptr runner, + cel::expr::conformance::test::TestCase test_case, + std::string full_test_name, bool skip) + : runner_(std::move(runner)), + test_case_(std::move(test_case)), + full_test_name_(std::move(full_test_name)), + skip_(skip) {} + + void TestBody() override { + if (skip_) { + GTEST_SKIP() << "Skipping test: " << full_test_name_; + } + EXPECT_NO_FATAL_FAILURE(runner_->RunTest(test_case_, full_test_name_)); + } + + private: + std::shared_ptr runner_; + cel::expr::conformance::test::TestCase test_case_; + std::string full_test_name_; + bool skip_; +}; + + +absl::Status RegisterTestSuite( + const std::filesystem::path& dir_path, const std::string& suite_name, + const std::shared_ptr& input_evaluator, + const std::shared_ptr& pool, + const std::shared_ptr& message_factory, + const std::vector& skip_tests) { + // Check if the entire suite should be skipped (prefix match) + for (const auto& skip : skip_tests) { + if (suite_name == skip || + absl::StartsWith(suite_name, absl::StrCat(skip, "/"))) { + std::cout << "[ SKIPPED SUITE ] " << suite_name << std::endl; + return absl::OkStatus(); + } + } + + std::filesystem::path policy_path = dir_path / "policy.yaml"; + std::filesystem::path tests_path = dir_path / "tests.yaml"; + bool is_yaml = true; + if (!std::filesystem::exists(tests_path)) { + tests_path = dir_path / "tests.textproto"; + is_yaml = false; + } + std::filesystem::path config_path = dir_path / "config.yaml"; + + if (!std::filesystem::exists(policy_path) || + !std::filesystem::exists(tests_path)) { + // Not a valid test suite, assume it's a directory we don't care about. + return absl::OkStatus(); + } + + // Parse Environment Config + cel::Config config; + if (std::filesystem::exists(config_path)) { + std::string config_content; + CEL_RETURN_IF_ERROR( + cel::internal::GetFileContents(config_path.string(), &config_content)); + CEL_ASSIGN_OR_RETURN(config, cel::EnvConfigFromYaml(config_content)); + } + + // Enable default extensions (optional, bindings) in the config + CEL_RETURN_IF_ERROR(config.AddExtensionConfig( + "optional", cel::Config::ExtensionConfig::kLatest)); + CEL_RETURN_IF_ERROR(config.AddExtensionConfig( + "bindings", cel::Config::ExtensionConfig::kLatest)); + + // Set up compiler & runtime environments + cel::Env env; + env.SetDescriptorPool(pool); + cel::RegisterStandardExtensions(env); + env.SetConfig(config); + + cel::EnvRuntime env_runtime; + env_runtime.SetDescriptorPool(pool); + cel::RegisterStandardExtensions(env_runtime); + env_runtime.SetConfig(config); + env_runtime.mutable_runtime_options().enable_qualified_type_identifiers = + true; + + CEL_ASSIGN_OR_RETURN(auto compiler_builder, env.NewCompilerBuilder()); + compiler_builder->GetParserBuilder().GetOptions().enable_optional_syntax = + true; + + CEL_ASSIGN_OR_RETURN(auto compiler, compiler_builder->Build()); + + CEL_ASSIGN_OR_RETURN(auto runtime_builder, + env_runtime.CreateRuntimeBuilder()); + + // Register conformance enums + for (const auto& enum_name : + {"cel.expr.conformance.proto2.GlobalEnum", + "cel.expr.conformance.proto3.GlobalEnum", + "cel.expr.conformance.proto2.TestAllTypes.NestedEnum", + "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"}) { + auto* enum_desc = pool->FindEnumTypeByName(enum_name); + if (enum_desc != nullptr) { + CEL_RETURN_IF_ERROR(cel::extensions::RegisterProtobufEnum( + runtime_builder.type_registry(), enum_desc)); + } + } + + // Register locationCode in runtime + CEL_RETURN_IF_ERROR( + (cel::UnaryFunctionAdapter:: + RegisterGlobalOverload("locationCode", LocationCode, + runtime_builder.function_registry()))); + + CEL_ASSIGN_OR_RETURN(auto runtime, std::move(runtime_builder).Build()); + + // Parse Policy + std::string policy_content; + CEL_RETURN_IF_ERROR( + cel::internal::GetFileContents(policy_path.string(), &policy_content)); + CEL_ASSIGN_OR_RETURN(auto source, + cel::NewSource(policy_content, "policy.yaml")); + auto policy_source = std::make_shared(std::move(source)); + CEL_ASSIGN_OR_RETURN(CelPolicyParseResult parse_result, + cel::ParseYamlCelPolicy(policy_source)); + if (!parse_result.IsValid()) { + return absl::InvalidArgumentError( + absl::StrCat("Failed to parse policy.yaml in ", suite_name, + "\nIssues:\n", parse_result.FormattedIssues())); + } + const CelPolicy* policy = parse_result.GetPolicy(); + + // Compile Policy (unexpected non-ok status represents a bug) + CEL_ASSIGN_OR_RETURN(CelPolicyValidationResult compile_result, + CompilePolicy(*compiler, *policy)); + + std::string tests_content; + CEL_RETURN_IF_ERROR( + cel::internal::GetFileContents(tests_path.string(), &tests_content)); + TestSuite test_suite; + if (is_yaml) { + CEL_ASSIGN_OR_RETURN(test_suite, + cel::test::ParsePolicyTestSuiteYaml(tests_content)); + } else { + if (!google::protobuf::TextFormat::ParseFromString(tests_content, &test_suite)) { + return absl::InvalidArgumentError( + absl::StrCat("Failed to parse text proto in ", tests_path.string())); + } + } + + auto runner = std::make_shared( + suite_name, std::move(compiler), std::move(runtime), + std::move(policy_source), std::move(compile_result), pool, + message_factory, input_evaluator, config.GetContextType(), + /*expect_compile_fail=*/absl::StrContains(suite_name, "compile_errors")); + + for (const auto& section : test_suite.sections()) { + std::string section_name = section.name(); + for (const auto& test : section.tests()) { + std::string test_name = test.name(); + std::string full_test_name = + absl::StrCat(suite_name, "/", section_name, "/", test_name); + + bool skip = !ShouldRunTest(full_test_name, skip_tests); + + testing::RegisterTest( + suite_name.c_str(), + absl::StrCat(section_name, "/", test_name).c_str(), nullptr, + test_name.c_str(), __FILE__, __LINE__, + [runner, test, full_test_name, skip]() -> CelPolicyTest* { + return new CelPolicyTest(runner, test, full_test_name, skip); + }); + } + } + return absl::OkStatus(); +} + +void RegisterAllTests() { + // cel::google3-end + std::string testdata_example_flag = absl::GetFlag(FLAGS_testdata_example); + std::vector skip_tests = absl::GetFlag(FLAGS_skip_tests); + + std::string abs_testdata_example = + cel::internal::ResolveRunfilesPath(testdata_example_flag); + ABSL_CHECK(!abs_testdata_example.empty()) + << "Could not find testdata directory: " << testdata_example_flag; + + std::shared_ptr pool = + GetSharedTestingDescriptorPool(); + auto message_factory = + std::make_shared(pool.get()); + message_factory->SetDelegateToGeneratedFactory(true); + auto evaluator_or = InputEvaluator::Create(pool); + ABSL_CHECK_OK(evaluator_or.status()) << "Failed to create input evaluator"; + std::shared_ptr evaluator = std::move(evaluator_or.value()); + + std::filesystem::path testdata_path(abs_testdata_example); + ABSL_CHECK(std::filesystem::exists(testdata_path)) + << "Testdata path does not exist: " << testdata_path; + // walk up to find 'testdata' parent. A work around to portably + // get the expected directory from bazel. + while (!absl::EndsWith(testdata_path.string(), "testdata")) { + testdata_path = testdata_path.parent_path(); + ABSL_CHECK(testdata_path.string().size() > sizeof("testdata")) + << "could not resolve testdata directory"; + } + + for (const auto& entry : + std::filesystem::recursive_directory_iterator(testdata_path)) { + if (!entry.is_directory()) { + continue; + } + std::filesystem::path dir_path = entry.path(); + // Check if this directory has policy.yaml and tests.yaml (or + // tests.textproto) + if (std::filesystem::exists(dir_path / "policy.yaml") && + (std::filesystem::exists(dir_path / "tests.yaml") || + std::filesystem::exists(dir_path / "tests.textproto"))) { + std::string suite_name = absl::StrReplaceAll( + std::filesystem::relative(dir_path, testdata_path).string(), + {{"\\", "/"}}); + + ABSL_CHECK_OK(RegisterTestSuite(dir_path, suite_name, evaluator, pool, + message_factory, skip_tests)); + } + } +} + +} // namespace +} // namespace cel + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + + cel::RegisterAllTests(); + return RUN_ALL_TESTS(); +} diff --git a/internal/BUILD b/internal/BUILD index 0ac5c4e46..6d0efab72 100644 --- a/internal/BUILD +++ b/internal/BUILD @@ -92,6 +92,7 @@ cc_library( hdrs = ["runfiles.h"], deps = [ "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@rules_cc//cc/runfiles", ], diff --git a/internal/runfiles.cc b/internal/runfiles.cc index 259e2e7ca..bffbfa9d1 100644 --- a/internal/runfiles.cc +++ b/internal/runfiles.cc @@ -14,11 +14,14 @@ #include "internal/runfiles.h" +#include +#include #include #include "rules_cc/cc/runfiles/runfiles.h" - #include "absl/log/absl_check.h" + +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" @@ -37,4 +40,14 @@ std::string ResolveRunfilesPath(absl::string_view path) { return runfiles->Rlocation(std::string(path)); } +absl::Status GetFileContents(absl::string_view path, std::string* out) { + std::ifstream file{std::string(path)}; + if (!file.is_open()) { + return absl::NotFoundError(absl::StrCat("Failed to open file: ", path)); + } + out->append((std::istreambuf_iterator(file)), + std::istreambuf_iterator()); + return absl::OkStatus(); +} + } // namespace cel::internal diff --git a/internal/runfiles.h b/internal/runfiles.h index 643c677b4..11fdcf337 100644 --- a/internal/runfiles.h +++ b/internal/runfiles.h @@ -11,12 +11,15 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +// +// Utilities for working with bazel runfiles. #ifndef THIRD_PARTY_CEL_CPP_INTERNAL_RUNFILES_H_ #define THIRD_PARTY_CEL_CPP_INTERNAL_RUNFILES_H_ #include +#include "absl/status/status.h" #include "absl/strings/string_view.h" namespace cel::internal { @@ -25,6 +28,9 @@ namespace cel::internal { // Intended for resolving test cases from cel-spec and cel-policy. std::string ResolveRunfilesPath(absl::string_view path); +// Read contents of a file at a resolved path to a string. +absl::Status GetFileContents(absl::string_view path, std::string* out); + } // namespace cel::internal #endif // THIRD_PARTY_CEL_CPP_INTERNAL_RUNFILES_H_ diff --git a/policy/BUILD b/policy/BUILD new file mode 100644 index 000000000..19195be2b --- /dev/null +++ b/policy/BUILD @@ -0,0 +1,239 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "cel_policy", + srcs = [ + "cel_policy.cc", + ], + hdrs = [ + "cel_policy.h", + ], + deps = [ + "//common:source", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "cel_policy_test", + srcs = ["cel_policy_test.cc"], + deps = [ + ":cel_policy", + "//common:source", + "//internal:testing", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "cel_policy_parser", + srcs = [ + "cel_policy_parse_context.cc", + "cel_policy_parse_result.cc", + ], + hdrs = [ + "cel_policy_parse_context.h", + "cel_policy_parse_result.h", + "cel_policy_parser.h", + ], + deps = [ + ":cel_policy", + "//common:source", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "yaml_policy_parser", + srcs = [ + "yaml_policy_parser.cc", + ], + hdrs = ["yaml_policy_parser.h"], + copts = ["-fexceptions"], + features = ["-use_header_modules"], + deps = [ + ":cel_policy", + ":cel_policy_parser", + "//common:source", + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@yaml-cpp", + ], +) + +cc_library( + name = "cel_policy_validation_result", + srcs = [ + "cel_policy_validation_result.cc", + ], + hdrs = [ + "cel_policy_validation_result.h", + ], + deps = [ + ":cel_policy", + ":cel_policy_parser", + "//common:ast", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "compiler", + srcs = ["compiler.cc"], + hdrs = ["compiler.h"], + deps = [ + ":cel_policy", + ":cel_policy_parser", + ":cel_policy_validation_result", + "//checker:type_check_issue", + "//checker:validation_result", + "//common:ast", + "//common:ast_rewrite", + "//common:constant", + "//common:container", + "//common:decl", + "//common:expr", + "//common:format_type_name", + "//common:navigable_ast", + "//common:source", + "//common:type", + "//common:type_kind", + "//compiler", + "//internal:status_macros", + "//policy/internal:issue_reporter", + "//policy/internal:optimizer_expr_factory", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "yaml_policy_parser_test", + srcs = [ + "test_custom_yaml_policy_parser.cc", + "yaml_policy_parser_test.cc", + ], + data = [ + "//policy/testdata:policy_testdata", + ], + deps = [ + ":cel_policy", + ":cel_policy_parser", + ":yaml_policy_parser", + "//common:source", + "//internal:runfiles", + "//internal:status_macros", + "//internal:testing", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@yaml-cpp", + ], +) + +cc_test( + name = "compiler_test", + srcs = ["compiler_test.cc"], + data = [ + "//policy/testdata:policy_testdata", + ], + deps = [ + ":cel_policy", + ":cel_policy_parser", + ":cel_policy_validation_result", + ":compiler", + ":yaml_policy_parser", + "//common:ast", + "//common:decl", + "//common:navigable_ast", + "//common:source", + "//common:type", + "//common:value", + "//common:value_testing", + "//compiler", + "//compiler:compiler_factory", + "//compiler:optional", + "//compiler:standard_library", + "//extensions:bindings_ext", + "//internal:runfiles", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//runtime", + "//runtime:activation", + "//runtime:optional_types", + "//runtime:runtime_builder", + "//runtime:runtime_options", + "//runtime:standard_runtime_builder_factory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "test_util", + testonly = True, + srcs = ["test_util.cc"], + hdrs = ["test_util.h"], + copts = ["-fexceptions"], + features = ["-use_header_modules"], + deps = [ + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/cel/expr:value_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/test:suite_cc_proto", + "@com_google_protobuf//:struct_cc_proto", + "@yaml-cpp", + ], +) diff --git a/policy/cel_policy.cc b/policy/cel_policy.cc new file mode 100644 index 000000000..c2d97edeb --- /dev/null +++ b/policy/cel_policy.cc @@ -0,0 +1,273 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/cel_policy.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "common/source.h" + +namespace cel { + +namespace { + +std::string IdDebugString(CelPolicyElementId id) { + if (id == -1) { + return ""; + } + return absl::StrCat("#", id, "> "); +} + +std::string IndentBlock(absl::string_view text) { + if (text.empty()) { + return ""; + } + std::vector lines; + for (absl::string_view line : absl::StrSplit(text, '\n')) { + if (line.empty()) { + lines.push_back(""); + } else { + lines.push_back(absl::StrCat(" ", line)); + } + } + return absl::StrJoin(lines, "\n"); +} + +} // namespace + +void CelPolicySource::NoteSourcePosition(CelPolicyElementId id, + SourcePosition position) { + source_positions_[id] = position; +} + +std::optional CelPolicySource::GetSourcePosition( + CelPolicyElementId id) const { + auto it = source_positions_.find(id); + if (it == source_positions_.end()) { + return std::nullopt; + } + return it->second; +} + +std::optional CelPolicySource::GetSourceLocation( + CelPolicyElementId id) const { + auto it = source_positions_.find(id); + if (it == source_positions_.end()) { + return std::nullopt; + } + return policy_source_->GetLocation(it->second); +} + +std::string CelPolicySource::DebugString() const { + std::string result; + + // Sort the source elements in descending order of position + std::vector> sorted_positions; + for (const auto& pair : source_positions_) { + sorted_positions.push_back(pair); + } + std::sort(sorted_positions.begin(), sorted_positions.end(), + [](const auto& a, const auto& b) { + if (a.second == b.second) { + return a.first < b.first; + } + return a.second > b.second; + }); + + result = policy_source_->content().ToString(); + for (const auto& [id, position] : sorted_positions) { + result.insert(position, IdDebugString(id)); + } + return result; +} + +std::string ValueString::DebugString() const { + return absl::StrCat(IdDebugString(id_), "\"", value_, "\""); +} + +std::string Import::DebugString() const { + std::string result; + absl::StrAppend(&result, IdDebugString(id_), "name: ", name_.DebugString()); + return result; +} + +std::string OutputBlock::DebugString() const { + std::string result; + absl::StrAppend(&result, "output: ", output_.DebugString()); + if (explanation_.has_value()) { + absl::StrAppend(&result, "\nexplanation: ", explanation_->DebugString()); + } + return result; +} + +Match::Match(const Match& other) + : id_(other.id_), condition_(other.condition_) { + if (std::holds_alternative(other.result_)) { + result_ = std::get(other.result_); + } else if (std::holds_alternative(other.result_)) { + result_ = std::get(other.result_); + } else { + result_ = + std::make_unique(*std::get>(other.result_)); + } +} + +Match& Match::operator=(const Match& other) { + if (this != &other) { + id_ = other.id_; + condition_ = other.condition_; + if (std::holds_alternative(other.result_)) { + result_ = std::get(other.result_); + } else if (std::holds_alternative(other.result_)) { + result_ = std::get(other.result_); + } else { + result_ = std::make_unique( + *std::get>(other.result_)); + } + } + return *this; +} + +std::string Match::DebugString() const { + std::string result; + absl::StrAppend(&result, IdDebugString(id_), "match: {\n"); + if (condition_.has_value()) { + absl::StrAppend(&result, " condition: ", condition_->DebugString(), "\n"); + } + if (has_rule()) { + absl::StrAppend(&result, " result:\n", + IndentBlock(IndentBlock(rule().DebugString())), "\n"); + } else { + absl::StrAppend(&result, " result: {\n", + IndentBlock(IndentBlock(output_block().DebugString())), + "\n }\n"); + } + absl::StrAppend(&result, "}"); + return result; +} + +std::string Variable::DebugString() const { + std::string result; + absl::StrAppend(&result, "variable: {\n"); + absl::StrAppend(&result, " name: ", name_.DebugString(), "\n"); + absl::StrAppend(&result, " expression: ", expression_.DebugString(), "\n"); + if (description_.has_value()) { + absl::StrAppend(&result, " description: ", description_->DebugString(), + "\n"); + } + if (display_name_.has_value()) { + absl::StrAppend(&result, " display_name: ", display_name_->DebugString(), + "\n"); + } + absl::StrAppend(&result, "}"); + return result; +} + +std::string Rule::DebugString() const { + std::string result; + absl::StrAppend(&result, IdDebugString(id_), "rule: {\n"); + if (rule_id_.has_value()) { + absl::StrAppend(&result, " rule_id: ", rule_id_->DebugString(), "\n"); + } + if (description_.has_value()) { + absl::StrAppend(&result, " description: ", description_->DebugString(), + "\n"); + } + for (const Variable& variable : variables_) { + absl::StrAppend(&result, IndentBlock(variable.DebugString()), "\n"); + } + for (const Match& match : matches_) { + absl::StrAppend(&result, IndentBlock(match.DebugString()), "\n"); + } + absl::StrAppend(&result, "}"); + return result; +} + +std::string MetadataValueDebugString(std::any value) { + if (value.type() == typeid(std::monostate)) { + return "null"; + } + if (value.type() == typeid(ValueString)) { + return std::any_cast(value).DebugString(); + } + if (value.type() == typeid(bool)) { + return std::any_cast(value) ? "true" : "false"; + } + if (value.type() == typeid(int)) { + return absl::StrCat(std::any_cast(value)); + } + if (value.type() == typeid(std::string)) { + return std::any_cast(value); + } + return absl::StrCat("typeid: ", value.type().name()); +} + +std::string CelPolicy::DebugString() const { + std::string result; + absl::StrAppend(&result, "CelPolicy{\n"); + absl::StrAppend( + &result, + " ===========================================================\n"); + absl::StrAppend(&result, IndentBlock(IndentBlock(source_->DebugString())), + "\n"); + absl::StrAppend( + &result, + " ===========================================================\n"); + absl::StrAppend(&result, " name: ", name_.DebugString(), "\n"); + if (description_.has_value()) { + absl::StrAppend(&result, " description: ", description_->DebugString(), + "\n"); + } + if (display_name_.has_value()) { + absl::StrAppend(&result, " display_name: ", display_name_->DebugString(), + "\n"); + } + if (!metadata_.empty()) { + std::vector sorted_keys; + for (const auto& [key, _] : metadata_) { + sorted_keys.push_back(key); + } + std::sort(sorted_keys.begin(), sorted_keys.end()); + + absl::StrAppend(&result, " metadata: {\n"); + for (const auto& key : sorted_keys) { + const auto& value = metadata_.at(key); + absl::StrAppend(&result, " ", key, ": ", + MetadataValueDebugString(value), "\n"); + } + absl::StrAppend(&result, " }\n"); + } + if (!imports_.empty()) { + absl::StrAppend(&result, " imports:\n"); + for (const Import& import : imports_) { + absl::StrAppend(&result, " ", import.DebugString(), "\n"); + } + } + absl::StrAppend(&result, IndentBlock(rule_.DebugString()), "\n"); + absl::StrAppend(&result, "}"); + return result; +} + +} // namespace cel diff --git a/policy/cel_policy.h b/policy/cel_policy.h new file mode 100644 index 000000000..af8f7c977 --- /dev/null +++ b/policy/cel_policy.h @@ -0,0 +1,320 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/absl_check.h" +#include "absl/strings/string_view.h" +#include "common/source.h" + +namespace cel { + +using CelPolicyElementId = int32_t; + +class CelPolicySource { + public: + explicit CelPolicySource(cel::SourcePtr policy_source) + : policy_source_(std::move(policy_source)) {} + + const Source* absl_nonnull content() const { return policy_source_.get(); } + + void NoteSourcePosition(CelPolicyElementId id, SourcePosition position); + + std::optional GetSourcePosition(CelPolicyElementId id) const; + + std::optional GetSourceLocation(CelPolicyElementId id) const; + + std::string DebugString() const; + + private: + cel::SourcePtr policy_source_; + absl::flat_hash_map source_positions_; +}; + +class ValueString { + public: + ValueString() : id_(-1) {} + + explicit ValueString(CelPolicyElementId id, absl::string_view value) + : id_(id), value_(value) {} + + CelPolicyElementId id() const { return id_; } + absl::string_view value() const { return value_; } + + std::string DebugString() const; + + private: + CelPolicyElementId id_; + std::string value_; +}; + +class Import { + public: + Import(CelPolicyElementId id, ValueString name) + : id_(id), name_(std::move(name)) {} + CelPolicyElementId id() const { return id_; } + const ValueString& name() const { return name_; } + + std::string DebugString() const; + + private: + CelPolicyElementId id_; + ValueString name_; +}; + +// Defines a variable that can be used in CEL expressions within the policy. +// Variables are evaluated once and stored in the activation context. +class Variable { + public: + const ValueString& name() const { return name_; } + void set_name(ValueString name) { name_ = std::move(name); } + + const ValueString& expression() const { return expression_; } + void set_expression(ValueString expression) { + expression_ = std::move(expression); + } + + std::optional description() const { return description_; } + void set_description(ValueString description) { + description_ = std::move(description); + } + + std::optional display_name() const { return display_name_; } + void set_display_name(ValueString display_name) { + display_name_ = std::move(display_name); + } + + std::string DebugString() const; + + private: + ValueString name_; + ValueString expression_; + std::optional description_; + std::optional display_name_; +}; + +class Rule; + +class OutputBlock { + public: + OutputBlock() = default; + OutputBlock(ValueString output, std::optional explanation) + : output_(std::move(output)), explanation_(std::move(explanation)) {} + + const ValueString& output() const { return output_; } + void set_output(ValueString output) { output_ = std::move(output); } + + const std::optional& explanation() const { return explanation_; } + void set_explanation(ValueString explanation) { + explanation_ = std::move(explanation); + } + + std::string DebugString() const; + + private: + ValueString output_; + std::optional explanation_; +}; + +// Defines a match condition and result. +// If the result is a Rule, it is considered a sub-rule and will be evaluated +// only if the match condition evaluates to true. +class Match { + public: + Match() = default; + Match(const Match& other); + Match& operator=(const Match& other); + + CelPolicyElementId id() const; + void set_id(CelPolicyElementId id); + + bool has_condition() const; + std::optional condition() const; + void set_condition(ValueString condition); + + bool has_output_block() const; + const OutputBlock& output_block() const; + OutputBlock& mutable_output_block(); + + bool has_rule() const; + const Rule& rule() const; + Rule& mutable_rule(); + + void set_result(OutputBlock result); + void set_result(std::unique_ptr result); + + std::string DebugString() const; + + private: + CelPolicyElementId id_ = -1; + std::optional condition_; + std::variant> result_; +}; + +// Rule is the body of the policy and contains a list of variables and matches. +// Variables are evaluated once and stored in the activation context. +// Matches are evaluated in order and the first match is returned. If the +// match contains a sub-rule, the sub-rule is evaluated only if the match +// condition evaluates to true. +class Rule { + public: + Rule() = default; + Rule(const Rule& other) = default; + + CelPolicyElementId id() const { return id_; } + void set_id(CelPolicyElementId id) { id_ = id; } + + const std::optional& rule_id() const { return rule_id_; } + void set_rule_id(ValueString rule_id) { rule_id_ = std::move(rule_id); } + + const std::optional& description() const { return description_; } + void set_description(ValueString description) { + description_ = std::move(description); + } + + const std::vector& variables() const { return variables_; } + std::vector& mutable_variables() { return variables_; } + + const std::vector& matches() const { return matches_; } + std::vector& mutable_matches() { return matches_; } + + std::string DebugString() const; + + private: + CelPolicyElementId id_ = -1; + std::optional rule_id_; + std::optional description_; + std::vector variables_; + std::vector matches_; +}; + +// CelPolicy is the top-level policy object. +// It contains a source, name, description, display name, imports, and a rule. +// The source is the CEL policy source code. +// The name, description, and display name are metadata about the policy. +// The rule is the main body of the policy. +class CelPolicy { + public: + explicit CelPolicy(std::shared_ptr source) + : source_(std::move(source)) {} + + CelPolicy(const CelPolicy& other) = default; + CelPolicy& operator=(const CelPolicy& other) = default; + + const CelPolicySource* absl_nullable source() const { return source_.get(); } + const std::shared_ptr& source_ptr() const { return source_; } + + const ValueString& name() const { return name_; } + void set_name(ValueString name) { name_ = std::move(name); } + + std::optional description() const { return description_; } + void set_description(ValueString description) { + description_ = std::move(description); + } + std::optional display_name() const { return display_name_; } + void set_display_name(ValueString display_name) { + display_name_ = std::move(display_name); + } + const absl::flat_hash_map& metadata() const { + return metadata_; + } + absl::flat_hash_map& mutable_metadata() { + return metadata_; + } + const std::vector& imports() const { return imports_; } + std::vector& mutable_imports() { return imports_; } + + const Rule& rule() const { return rule_; } + Rule& mutable_rule() { return rule_; } + + std::string DebugString() const; + + private: + std::shared_ptr source_; + ValueString name_; + std::optional description_; + std::optional display_name_; + absl::flat_hash_map metadata_; + std::vector imports_; + Rule rule_; +}; + +// Implementation details. + +inline CelPolicyElementId Match::id() const { return id_; } +inline void Match::set_id(CelPolicyElementId id) { id_ = id; } + +inline bool Match::has_condition() const { return condition_.has_value(); } + +inline std::optional Match::condition() const { + return condition_; +} + +inline void Match::set_condition(ValueString condition) { + condition_ = std::move(condition); +} + +inline bool Match::has_output_block() const { + return std::holds_alternative(result_); +} + +inline const OutputBlock& Match::output_block() const { + ABSL_DCHECK(std::holds_alternative(result_)); + return std::get(result_); +} + +inline OutputBlock& Match::mutable_output_block() { + if (!std::holds_alternative(result_)) { + result_ = OutputBlock(); + } + return std::get(result_); +} + +inline bool Match::has_rule() const { + return std::holds_alternative>(result_); +} + +inline const Rule& Match::rule() const { + ABSL_DCHECK(std::holds_alternative>(result_)); + return *std::get>(result_); +} + +inline Rule& Match::mutable_rule() { + ABSL_DCHECK(std::holds_alternative>(result_)); + return *std::get>(result_); +} + +inline void Match::set_result(OutputBlock result) { + result_ = std::move(result); +} + +inline void Match::set_result(std::unique_ptr result) { + result_ = std::move(result); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_H_ diff --git a/policy/cel_policy_parse_context.cc b/policy/cel_policy_parse_context.cc new file mode 100644 index 000000000..66861d085 --- /dev/null +++ b/policy/cel_policy_parse_context.cc @@ -0,0 +1,49 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/cel_policy_parse_context.h" + +#include +#include +#include +#include + +#include "absl/log/absl_check.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" + +namespace cel { + +CelPolicy& CelPolicyParseContext::policy() const { + ABSL_CHECK(policy_ != nullptr) + << "CelPolicyParseContext::policy() called after GetResult()"; + return *policy_; +} + +CelPolicyParseResult CelPolicyParseContext::GetResult() { + if (policy_ != nullptr && issues_.empty()) { + return CelPolicyParseResult(std::move(policy_source_), std::move(policy_), + std::move(issues_)); + } + policy_.reset(); + return CelPolicyParseResult(std::move(policy_source_), nullptr, + std::move(issues_)); +} + +void CelPolicyParseContext::ReportError(CelPolicyElementId element_id, + std::string_view message) { + issues_.push_back(CelPolicyIssue(element_id, std::string(message))); +} + +} // namespace cel diff --git a/policy/cel_policy_parse_context.h b/policy/cel_policy_parse_context.h new file mode 100644 index 000000000..6482fa1ae --- /dev/null +++ b/policy/cel_policy_parse_context.h @@ -0,0 +1,65 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSE_CONTEXT_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSE_CONTEXT_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" + +namespace cel { + +// A mutable context for parsing a CelPolicy. An instance of this class is +// created for each policy parse and is passed to the parser, which is meant to +// be stateless. +// +// Parsers call methods on this class to report issues and populate the policy +// being parsed. Call GetResult() to obtain the resulting CelPolicyParseResult, +// which takes ownership of the parsed policy. Do not use the context after +// calling GetResult(). +class CelPolicyParseContext { + public: + explicit CelPolicyParseContext(std::shared_ptr policy_source) + : policy_source_(std::move(policy_source)), + policy_(std::make_unique(policy_source_)) {} + + CelPolicySource& policy_source() const { return *policy_source_; } + + // Returns the policy being parsed. It should not be used after + // calling GetResult(). + CelPolicy& policy() const; + + // The context should not be used after calling GetResult(). + CelPolicyParseResult GetResult(); + + // Reports an error for the given element with the given error message. + void ReportError(CelPolicyElementId id, std::string_view message); + + CelPolicyElementId next_element_id() { return next_element_id_++; } + + private: + std::shared_ptr policy_source_; + CelPolicyElementId next_element_id_ = 0; + std::vector issues_; + std::unique_ptr policy_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSE_CONTEXT_H_ diff --git a/policy/cel_policy_parse_result.cc b/policy/cel_policy_parse_result.cc new file mode 100644 index 000000000..32d6431bb --- /dev/null +++ b/policy/cel_policy_parse_result.cc @@ -0,0 +1,91 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/cel_policy_parse_result.h" + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/source.h" +#include "policy/cel_policy.h" + +namespace cel { +namespace { + +absl::string_view SeverityString(CelPolicyIssue::Severity severity) { + switch (severity) { + case CelPolicyIssue::Severity::kInformation: + return "INFORMATION"; + case CelPolicyIssue::Severity::kWarning: + return "WARNING"; + case CelPolicyIssue::Severity::kError: + return "ERROR"; + case CelPolicyIssue::Severity::kDeprecated: + return "DEPRECATED"; + default: + return "SEVERITY_UNSPECIFIED"; + } +} + +} // namespace + +std::string CelPolicyIssue::ToDisplayString( + const CelPolicySource* absl_nullable source) const { + SourceLocation location; + std::string description; + std::string snippet; + if (source != nullptr) { + if (relative_position_) { + std::optional base = + source->GetSourcePosition(element_id_); + if (element_id_ == -1) { + base.emplace(0); + } + if (base) { + location = source->content() + ->GetLocation(*base + *relative_position_) + .value_or(SourceLocation{}); + } + } else { + location = + source->GetSourceLocation(element_id_).value_or(SourceLocation{}); + } + description = std::string(source->content()->description()); + snippet = source->content()->DisplayErrorLocation(location); + } + + const int display_column = location.column >= 0 ? location.column + 1 : -1; + + return absl::StrFormat("%s: %s:%d:%d: %s%s", SeverityString(severity_), + description, location.line, display_column, message_, + snippet); +} + +std::string CelPolicyParseResult::FormattedIssues() const { + std::string formatted_issues; + for (const CelPolicyIssue& issue : issues_) { + if (!formatted_issues.empty()) { + absl::StrAppend(&formatted_issues, "\n"); + } + absl::StrAppend(&formatted_issues, issue.ToDisplayString(*policy_source_)); + } + return formatted_issues; +} + +} // namespace cel diff --git a/policy/cel_policy_parse_result.h b/policy/cel_policy_parse_result.h new file mode 100644 index 000000000..2bf80b1ce --- /dev/null +++ b/policy/cel_policy_parse_result.h @@ -0,0 +1,105 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSE_RESULT_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSE_RESULT_H_ + +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/source.h" +#include "policy/cel_policy.h" + +namespace cel { + +class CelPolicyIssue { + public: + enum class Severity { kInformation, kDeprecated, kWarning, kError }; + + CelPolicyIssue(CelPolicyElementId element_id, absl::string_view message) + : element_id_(element_id), message_(message) {} + CelPolicyIssue(CelPolicyElementId element_id, Severity severity, + absl::string_view message) + : element_id_(element_id), severity_(severity), message_(message) {} + CelPolicyIssue(CelPolicyElementId element_id, + SourcePosition relative_position, absl::string_view message) + : element_id_(element_id), + relative_position_(relative_position), + message_(message) {} + CelPolicyIssue(CelPolicyElementId element_id, + SourcePosition relative_position, Severity severity, + absl::string_view message) + : element_id_(element_id), + relative_position_(relative_position), + severity_(severity), + message_(message) {} + + std::string ToDisplayString( + const CelPolicySource* absl_nullable source) const; + std::string ToDisplayString(const CelPolicySource& source) const { + return ToDisplayString(&source); + } + + Severity severity() const { return severity_; } + absl::string_view message() const { return message_; } + + private: + CelPolicyElementId element_id_; + std::optional relative_position_; + Severity severity_ = Severity::kError; + std::string message_; +}; + +class CelPolicyParseResult { + public: + explicit CelPolicyParseResult(std::shared_ptr policy_source, + std::unique_ptr policy, + std::vector issues) + : policy_source_(std::move(policy_source)), + policy_(std::move(policy)), + issues_(std::move(issues)) {} + + bool IsValid() const { return policy_ != nullptr; } + + const CelPolicy* absl_nullable GetPolicy() const { return policy_.get(); } + + absl::StatusOr> ReleasePolicy() { + if (policy_ == nullptr) { + return absl::FailedPreconditionError( + "CelPolicyParseResult is empty. Check for Issues."); + } + return std::move(policy_); + } + + absl::Span GetIssues() const { return issues_; } + + std::string FormattedIssues() const; + + private: + std::shared_ptr policy_source_; + absl_nullable std::unique_ptr policy_; + std::vector issues_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSE_RESULT_H_ diff --git a/policy/cel_policy_parser.h b/policy/cel_policy_parser.h new file mode 100644 index 000000000..0a11c9e68 --- /dev/null +++ b/policy/cel_policy_parser.h @@ -0,0 +1,40 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSER_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSER_H_ + +#include "absl/status/status.h" +#include "policy/cel_policy_parse_context.h" + +namespace cel { + +// A policy parser for a given policy format. The type `T` parameter is the +// representation of the input file format, such as `` for YAML. +// +// Parsers are intended to be stateless: all state, including the resulting +// policy and any issues encountered, should be kept in the context passed to +// the `ParsePolicy` method. +template +class CelPolicyParser { + public: + virtual ~CelPolicyParser() = default; + + // Parses the input and populates a CelPolicy in the context. + virtual absl::Status ParsePolicy(CelPolicyParseContext& ctx) const = 0; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_PARSER_H_ diff --git a/policy/cel_policy_test.cc b/policy/cel_policy_test.cc new file mode 100644 index 000000000..640247e7f --- /dev/null +++ b/policy/cel_policy_test.cc @@ -0,0 +1,220 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/cel_policy.h" + +#include +#include +#include +#include + +#include "absl/strings/str_replace.h" +#include "common/source.h" +#include "internal/testing.h" + +namespace cel { +namespace { + +using testing::Field; +using testing::Optional; +using testing::SizeIs; + +TEST(CelPolicyBuilderTest, Build) { + CelPolicyElementId next_id = 1; + ASSERT_OK_AND_ASSIGN(SourcePtr source, NewSource("CEL\n policy\n source")); + std::shared_ptr policy_source = + std::make_shared(std::move(source)); + CelPolicy policy(policy_source); + policy.set_name(ValueString(next_id++, "test_policy")); + policy.set_description(ValueString(next_id++, "test_description")); + policy.set_display_name(ValueString(next_id++, "test_display_name")); + ValueString import1_name = ValueString(next_id++, "test_import1"); + policy.mutable_imports().push_back(Import(next_id++, import1_name)); + ValueString import2_name = ValueString(next_id++, "test_import2"); + policy.mutable_imports().push_back(Import(next_id++, import2_name)); + + Rule& rule = policy.mutable_rule(); + rule.set_id(next_id++); + rule.set_rule_id(ValueString(next_id++, "test_rule_id")); + rule.set_description(ValueString(next_id++, "test_rule_description")); + + Variable variable; + variable.set_name(ValueString(next_id++, "test_variable")); + variable.set_expression(ValueString(next_id++, "test_expression")); + variable.set_description(ValueString(next_id++, "test_variable_description")); + variable.set_display_name( + ValueString(next_id++, "test_variable_display_name")); + + Match match1; + match1.set_id(next_id++); + match1.set_condition(ValueString(next_id++, "test_condition")); + CelPolicyElementId output_id = next_id++; + CelPolicyElementId explanation_id = next_id++; + match1.set_result( + OutputBlock(ValueString(output_id, "test_result"), + ValueString(explanation_id, "test_explanation"))); + + Match match2; + match2.set_id(next_id++); + match2.set_condition(ValueString(next_id++, "test_condition2")); + + auto sub_rule = std::make_unique(); + sub_rule->set_id(next_id++); + sub_rule->set_rule_id(ValueString(next_id++, "sub_rule_id")); + sub_rule->set_description(ValueString(next_id++, "sub_rule_description")); + Match sub_rule_match; + sub_rule_match.set_id(next_id++); + sub_rule_match.set_condition(ValueString(next_id++, "sub_rule_condition")); + sub_rule_match.set_result( + OutputBlock(ValueString(next_id++, "sub_rule_result"), std::nullopt)); + sub_rule->mutable_matches().push_back(sub_rule_match); + + match2.set_result(std::move(sub_rule)); + + rule.mutable_variables().push_back(variable); + rule.mutable_matches().push_back(match1); + rule.mutable_matches().push_back(match2); + + EXPECT_EQ(policy.name().value(), "test_policy"); + ASSERT_TRUE(policy.description().has_value()); + EXPECT_EQ(policy.description()->value(), "test_description"); + ASSERT_TRUE(policy.display_name().has_value()); + EXPECT_EQ(policy.display_name()->value(), "test_display_name"); + + ASSERT_THAT(policy.imports(), SizeIs(2)); + + EXPECT_EQ(policy.imports()[0].name().value(), "test_import1"); + EXPECT_EQ(policy.imports()[1].name().value(), "test_import2"); + ASSERT_TRUE(policy.rule().rule_id().has_value()); + EXPECT_EQ(policy.rule().rule_id()->value(), "test_rule_id"); + ASSERT_TRUE(policy.rule().description().has_value()); + EXPECT_EQ(policy.rule().description()->value(), "test_rule_description"); + + ASSERT_THAT(policy.rule().variables(), SizeIs(1)); + + EXPECT_EQ(policy.rule().variables()[0].name().value(), "test_variable"); + EXPECT_EQ(policy.rule().variables()[0].expression().value(), + "test_expression"); + ASSERT_TRUE(policy.rule().variables()[0].description().has_value()); + EXPECT_EQ(policy.rule().variables()[0].description()->value(), + "test_variable_description"); + ASSERT_TRUE(policy.rule().variables()[0].display_name().has_value()); + EXPECT_EQ(policy.rule().variables()[0].display_name()->value(), + "test_variable_display_name"); + + ASSERT_THAT(policy.rule().matches(), SizeIs(2)); + + EXPECT_EQ(policy.rule().matches()[0].condition().value().value(), + "test_condition"); + ASSERT_TRUE(policy.rule().matches()[0].has_output_block()); + EXPECT_EQ(policy.rule().matches()[0].output_block().output().value(), + "test_result"); + ASSERT_TRUE( + policy.rule().matches()[0].output_block().explanation().has_value()); + EXPECT_EQ(policy.rule().matches()[0].output_block().explanation()->value(), + "test_explanation"); + + EXPECT_EQ(policy.rule().matches()[1].condition().value().value(), + "test_condition2"); + ASSERT_TRUE(policy.rule().matches()[1].has_rule()); + ASSERT_TRUE(policy.rule().matches()[1].rule().rule_id().has_value()); + EXPECT_EQ(policy.rule().matches()[1].rule().rule_id()->value(), + "sub_rule_id"); + ASSERT_TRUE(policy.rule().matches()[1].rule().description().has_value()); + EXPECT_EQ(policy.rule().matches()[1].rule().description()->value(), + "sub_rule_description"); + ASSERT_THAT(policy.rule().matches()[1].rule().matches(), SizeIs(1)); + EXPECT_EQ(policy.rule() + .matches()[1] + .rule() + .matches()[0] + .condition() + .value() + .value(), + "sub_rule_condition"); + + std::string actual = policy.DebugString(); + EXPECT_EQ(actual, absl::StrReplaceAll(R"(CelPolicy{ + =========================================================== + CEL + policy + source + =========================================================== + name: #1> "test_policy" + description: #2> "test_description" + display_name: #3> "test_display_name" + imports: + #5> name: #4> "test_import1" + #7> name: #6> "test_import2" + #8> rule: { + rule_id: #9> "test_rule_id" + description: #10> "test_rule_description" + variable: { + name: #11> "test_variable" + expression: #12> "test_expression" + description: #13> "test_variable_description" + display_name: #14> "test_variable_display_name" + } + #15> match: { + condition: #16> "test_condition" + result: { + output: #17> "test_result" + explanation: #18> "test_explanation" + } + } + #19> match: { + condition: #20> "test_condition2" + result: + #21> rule: { + rule_id: #22> "sub_rule_id" + description: #23> "sub_rule_description" + #24> match: { + condition: #25> "sub_rule_condition" + result: { + output: #26> "sub_rule_result" + } + } + } + } + } + })", + {{"\n ", "\n"}})); +} + +TEST(CelPolicySourceTest, Build) { + std::string source = + "name: test_policy\n imports:\n - name: test_import\n"; + + ASSERT_OK_AND_ASSIGN(SourcePtr source_ptr, NewSource(source)); + CelPolicySource policy_source(std::move(source_ptr)); + policy_source.NoteSourcePosition(1, source.find("test_policy")); + policy_source.NoteSourcePosition(2, source.find("test_import")); + + EXPECT_THAT(policy_source.GetSourcePosition(1), Optional(6)); + EXPECT_THAT(policy_source.GetSourceLocation(1), + Optional(AllOf(Field(&SourceLocation::line, 1), + Field(&SourceLocation::column, 6)))); + EXPECT_THAT(policy_source.GetSourcePosition(2), Optional(44)); + EXPECT_THAT(policy_source.GetSourceLocation(2), + Optional(AllOf(Field(&SourceLocation::line, 3), + Field(&SourceLocation::column, 13)))); + EXPECT_EQ(policy_source.GetSourcePosition(3), std::nullopt); + EXPECT_EQ(policy_source.GetSourceLocation(3), std::nullopt); + EXPECT_EQ( + policy_source.DebugString(), + "name: #1> test_policy\n imports:\n - name: #2> test_import\n"); +} + +} // namespace +} // namespace cel diff --git a/policy/cel_policy_validation_result.cc b/policy/cel_policy_validation_result.cc new file mode 100644 index 000000000..e257f064c --- /dev/null +++ b/policy/cel_policy_validation_result.cc @@ -0,0 +1,32 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/cel_policy_validation_result.h" + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "policy/cel_policy_parse_result.h" + +namespace cel { + +std::string CelPolicyValidationResult::FormatIssues() const { + return absl::StrJoin( + issues_, "\n", [this](std::string* out, const CelPolicyIssue& issue) { + absl::StrAppend(out, issue.ToDisplayString(source_.get())); + }); +} + +} // namespace cel diff --git a/policy/cel_policy_validation_result.h b/policy/cel_policy_validation_result.h new file mode 100644 index 000000000..bddb9a3ca --- /dev/null +++ b/policy/cel_policy_validation_result.h @@ -0,0 +1,84 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_VALIDATION_RESULT_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_VALIDATION_RESULT_H_ + +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "common/ast.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" + +namespace cel { + +// CelPolicyValidationResult holds the result of policy compilation. +// +// Policy compilation/validation errors are captured in issues. +class CelPolicyValidationResult { + public: + CelPolicyValidationResult( + std::unique_ptr ast, std::vector issues, + std::shared_ptr source = nullptr) + : ast_(std::move(ast)), + issues_(std::move(issues)), + source_(std::move(source)) {} + + explicit CelPolicyValidationResult( + std::vector issues, + std::shared_ptr source = nullptr) + : ast_(nullptr), issues_(std::move(issues)), source_(std::move(source)) {} + + // Returns true if validation succeeded and an AST is present. + bool IsValid() const { return ast_ != nullptr; } + + // Returns the AST if validation was successful. + const Ast* absl_nullable GetAst() const { return ast_.get(); } + + // Moves out and returns the AST. + absl::StatusOr> ReleaseAst() { + if (ast_ == nullptr) { + return absl::FailedPreconditionError( + "CelPolicyValidationResult is empty. Check for CelPolicyIssues."); + } + return std::move(ast_); + } + + // Returns the list of issues encountered during compilation. + absl::Span GetIssues() const { return issues_; } + + // Returns the contained policy source, if any. + const CelPolicySource* absl_nullable GetSource() const { + return source_.get(); + } + + // Returns a formatted error string of the compiled issues. + std::string FormatIssues() const; + + private: + absl_nullable std::unique_ptr ast_; + std::vector issues_; + std::shared_ptr source_; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_CEL_POLICY_VALIDATION_RESULT_H_ diff --git a/policy/compiler.cc b/policy/compiler.cc new file mode 100644 index 000000000..7a892447c --- /dev/null +++ b/policy/compiler.cc @@ -0,0 +1,1058 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/compiler.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/nullability.h" +#include "absl/cleanup/cleanup.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "checker/type_check_issue.h" +#include "checker/validation_result.h" +#include "common/ast.h" +#include "common/ast_rewrite.h" +#include "common/constant.h" +#include "common/container.h" +#include "common/decl.h" +#include "common/expr.h" +#include "common/format_type_name.h" +#include "common/navigable_ast.h" +#include "common/source.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "compiler/compiler.h" +#include "internal/status_macros.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" +#include "policy/cel_policy_validation_result.h" +#include "policy/internal/issue_reporter.h" +#include "policy/internal/optimizer_expr_factory.h" +#include "google/protobuf/arena.h" + +namespace cel { +namespace { + +constexpr absl::string_view kCelBlock = "cel.@block"; + +enum class RuleSemantics { + // TODO(b/506179116): will also need "aggregate" or similar concept. + kFirstMatch, + + kNotForUseWithExhaustiveSwitchStatements, +}; + +template +void AbslStringify(Sink& s, RuleSemantics semantics) { + switch (semantics) { + case RuleSemantics::kFirstMatch: + s.Append("first_match"); + return; + default: + s.Append(""); + return; + } +} + +struct EmbeddedAst { + CelPolicyElementId id; + std::unique_ptr ast; +}; + +struct CompiledVariable { + std::string ident; + EmbeddedAst ast; +}; + +struct CompiledOutputBlock { + EmbeddedAst output_ast; + cel::Type result_type; + std::optional explanation_ast; +}; + +struct CompiledRule; + +struct CompiledMatch { + using Production = + std::variant absl_nonnull, + CompiledOutputBlock>; + + CelPolicyElementId id; + std::optional condition; + Production production; +}; + +struct CompiledRule { + CelPolicyElementId id; + std::vector variables; + std::vector matches; + // Not set if cannot be determined. + std::optional result_type; +}; + +std::optional GetOutputType( + const CompiledMatch::Production& production) { + return std::visit( + [](const auto& production) -> std::optional { + if constexpr (std::is_same_v, + CompiledOutputBlock>) { + return production.result_type; + } else if constexpr (std::is_same_v, + std::unique_ptr>) { + return production->result_type; + } + return std::nullopt; + }, + production); +} + +// Internal representation of the compiled policy elements. +// +// This is used for checking the component expression before composing into the +// final AST based on the provided rule semantics. +class IntermediateCompiledPolicy { + public: + CompiledRule& mutable_root_rule() { return root_rule_; } + + const CompiledRule& root_rule() const { return root_rule_; } + + void set_name(absl::string_view name) { name_ = name; } + absl::string_view name() const { return name_; } + void set_display_name(absl::string_view display_name) { + display_name_ = display_name; + } + absl::string_view display_name() const { return display_name_; } + void set_description(absl::string_view description) { + description_ = description; + } + absl::string_view description() const { return description_; } + + void set_semantics(RuleSemantics semantics) { semantics_ = semantics; } + RuleSemantics semantics() const { return semantics_; } + + private: + std::string name_; + std::string display_name_; + std::string description_; + RuleSemantics semantics_ = RuleSemantics::kFirstMatch; + + CompiledRule root_rule_; +}; + +CelPolicyIssue::Severity MapSeverity(cel::TypeCheckIssue::Severity severity) { + switch (severity) { + case cel::TypeCheckIssue::Severity::kError: + return CelPolicyIssue::Severity::kError; + case cel::TypeCheckIssue::Severity::kWarning: + return CelPolicyIssue::Severity::kWarning; + case cel::TypeCheckIssue::Severity::kDeprecated: + return CelPolicyIssue::Severity::kDeprecated; + default: + return CelPolicyIssue::Severity::kError; + } +} + +bool IsWrapperOf(cel::TypeKind wrapper_kind, cel::TypeKind primitive_kind) { + switch (wrapper_kind) { + case cel::TypeKind::kBoolWrapper: + return primitive_kind == cel::TypeKind::kBool; + case cel::TypeKind::kIntWrapper: + return primitive_kind == cel::TypeKind::kInt; + case cel::TypeKind::kUintWrapper: + return primitive_kind == cel::TypeKind::kUint; + case cel::TypeKind::kDoubleWrapper: + return primitive_kind == cel::TypeKind::kDouble; + case cel::TypeKind::kStringWrapper: + return primitive_kind == cel::TypeKind::kString; + case cel::TypeKind::kBytesWrapper: + return primitive_kind == cel::TypeKind::kBytes; + default: + return false; + } +} + +cel::Type FilterSpecialTypes(cel::Type type) { + if (type.IsTypeParam()) { + // Free type param should not appear in the output type, but if it does, + // force it to dyn. + return DynType(); + } + if (type.IsEnum()) { + return IntType{}; + } + if (type.IsError()) { + return DynType(); + } + if (type.IsType()) { + // drop parameters so all type types are compatible. + return TypeType{}; + } + return type; +} + +// Returns true if `from` is assignable to `to`. +// +// Slightly adjusted from the standard routine to cover some edge cases around +// null and wrappers. +// +// TODO(b/522391716): try to standardize assignability checks. +bool OutputTypeIsAssignable(cel::Type from, cel::Type to) { + from = FilterSpecialTypes(from); + to = FilterSpecialTypes(to); + + // Any and dyn are assignable to/from everything. + if (from.kind() == cel::TypeKind::kAny || + from.kind() == cel::TypeKind::kDyn || to.kind() == cel::TypeKind::kAny || + to.kind() == cel::TypeKind::kDyn) { + return true; + } + + // Wrappers auto-unwrap. + if (IsWrapperOf(from.kind(), to.kind()) || + IsWrapperOf(to.kind(), from.kind())) { + return true; + } + + // Null is assignable to anything that is message-like. + if (from.kind() == cel::TypeKind::kNull) { + switch (to.kind()) { + case cel::TypeKind::kNull: + case cel::TypeKind::kStruct: + case cel::TypeKind::kOpaque: + case cel::TypeKind::kTimestamp: + case cel::TypeKind::kDuration: + case cel::TypeKind::kBytesWrapper: + case cel::TypeKind::kBoolWrapper: + case cel::TypeKind::kIntWrapper: + case cel::TypeKind::kUintWrapper: + case cel::TypeKind::kDoubleWrapper: + case cel::TypeKind::kStringWrapper: + return true; + default: + return false; + } + } + + if (from.kind() != to.kind()) { + return false; + } + + if (from.name() != to.name()) { + return false; + } + + if (from.GetParameters().size() != to.GetParameters().size()) { + return false; + } + + for (int i = 0; i < from.GetParameters().size(); ++i) { + if (!OutputTypeIsAssignable(from.GetParameters()[i], + to.GetParameters()[i])) { + return false; + } + } + + return true; +} + +bool OutputTypeIsCompatible(cel::Type from, cel::Type to) { + // We don't handle widening like in a self-contained CEL expression, but + // permit some cases where one type is more specific than the other. + return OutputTypeIsAssignable(from, to) || OutputTypeIsAssignable(to, from); +} + +bool HasErrors(const policy_internal::IssueReporter& issues) { + for (const auto& issue : issues.issues()) { + if (issue.severity() == CelPolicyIssue::Severity::kError) { + return true; + } + } + return false; +} + +// Note on lifetime safety: +// +// The output policy will contain references to types that are owned by the +// arena member of this class. This is safe as long as the policy compiler lives +// as long as the output policies. +class PolicyCompiler { + public: + explicit PolicyCompiler(policy_internal::IssueReporter* issues, + std::unique_ptr base_compiler) + : issues_(*issues), base_compiler_(std::move(base_compiler)) {} + + absl::string_view GetSourceDescription() const { + if (src_ == nullptr) { + return ""; + } + return src_->content()->description(); + } + + void AdaptTypeCheckIssues(CelPolicyElementId id, const ValidationResult& r) { + const Source* source = r.GetSource(); + + for (const auto& iss : r.GetIssues()) { + std::optional offset; + if (source != nullptr) { + offset = source->GetPosition(iss.location()); + } + if (offset.has_value()) { + issues_.ReportOffsetIssue(id, offset.value(), + MapSeverity(iss.severity()), iss.message()); + continue; + } + issues_.ReportIssue(id, MapSeverity(iss.severity()), iss.message()); + } + } + + absl::StatusOr CompileOutputBlock( + const cel::OutputBlock& output_block, const Compiler* env) { + CompiledOutputBlock output; + CEL_ASSIGN_OR_RETURN(auto output_validation, + env->Compile(output_block.output().value(), + GetSourceDescription(), &arena_)); + AdaptTypeCheckIssues(output_block.output().id(), output_validation); + + cel::Type result_type = DynType(); + if (output_validation.IsValid()) { + CEL_ASSIGN_OR_RETURN(auto ast, output_validation.ReleaseAst()); + auto root_expr_id = ast->root_expr().id(); + output.output_ast = + EmbeddedAst{output_block.output().id(), std::move(ast)}; + if (auto it = output_validation.GetResolvedTypeMap().find(root_expr_id); + it != output_validation.GetResolvedTypeMap().end()) { + result_type = it->second; + } + } + if (output_block.explanation().has_value()) { + CEL_ASSIGN_OR_RETURN(auto explanation_validation, + env->Compile(output_block.explanation()->value(), + GetSourceDescription(), &arena_)); + AdaptTypeCheckIssues(output_block.explanation()->id(), + explanation_validation); + if (explanation_validation.IsValid()) { + CEL_ASSIGN_OR_RETURN(auto ast, explanation_validation.ReleaseAst()); + if (ast->GetReturnType().primitive() != PrimitiveType::kString) { + issues_.ReportError(output_block.explanation()->id(), + "explanation must evaluate to string"); + } else { + output.explanation_ast = + EmbeddedAst{output_block.explanation()->id(), std::move(ast)}; + } + } + } + output.result_type = result_type; + return output; + } + + absl::Status CompileMatch(const Match& match, const Compiler* env, + CompiledRule* out) { + CompiledMatch c_match; + c_match.id = match.id(); + if (match.condition().has_value()) { + CEL_ASSIGN_OR_RETURN(auto validation, + env->Compile(match.condition()->value(), + GetSourceDescription(), &arena_)); + AdaptTypeCheckIssues(match.condition()->id(), validation); + if (validation.IsValid()) { + CEL_ASSIGN_OR_RETURN(auto ast, validation.ReleaseAst()); + if (ast->GetReturnType().primitive() != PrimitiveType::kBool) { + issues_.ReportError(match.condition()->id(), + "condition must evaluate to bool"); + } + c_match.condition = + EmbeddedAst{match.condition()->id(), std::move(ast)}; + } + } + + if (match.has_output_block()) { + CEL_ASSIGN_OR_RETURN(c_match.production, + CompileOutputBlock(match.output_block(), env)); + } else if (match.has_rule()) { + auto rule = std::make_unique(); + CEL_RETURN_IF_ERROR(CompileRule(match.rule(), env, rule.get())); + c_match.production = std::move(rule); + } else { + issues_.ReportError(match.id(), "match must specify an output or rule"); + } + out->matches.push_back(std::move(c_match)); + return absl::OkStatus(); + } + + absl::Status CompileRule(const Rule& rule, const cel::Compiler* env, + CompiledRule* out) { + out->id = rule.id(); + std::unique_ptr buf; + + absl::flat_hash_set seen_variables; + for (const auto& variable : rule.variables()) { + std::string name(variable.name().value()); + if (!seen_variables.insert(name).second) { + issues_.ReportError( + variable.expression().id(), + absl::StrCat("overlapping identifier for name 'variables.", name, + "'")); + continue; + } + std::string ident = absl::StrCat("variables.", name); + CEL_ASSIGN_OR_RETURN(auto validation, + env->Compile(variable.expression().value(), + GetSourceDescription(), &arena_)); + AdaptTypeCheckIssues(variable.expression().id(), validation); + if (!validation.IsValid()) { + continue; + } + CEL_ASSIGN_OR_RETURN(auto ast, validation.ReleaseAst()); + cel::Type result_type = DynType(); + + if (auto it = validation.GetResolvedTypeMap().find(ast->root_expr().id()); + it != validation.GetResolvedTypeMap().end()) { + result_type = it->second; + } + out->variables.push_back(CompiledVariable{ + ident, + EmbeddedAst{variable.expression().id(), std::move(ast)}, + }); + auto next = env->ToBuilder(); + auto status = next->GetCheckerBuilder().AddOrReplaceVariable( + MakeVariableDecl(ident, result_type)); + if (!status.ok()) { + issues_.ReportError(variable.expression().id(), status.message()); + continue; + } + CEL_ASSIGN_OR_RETURN(buf, next->Build()); + env = buf.get(); + } + + std::optional overall_type; + for (const auto& match : rule.matches()) { + CEL_RETURN_IF_ERROR(CompileMatch(match, env, out)); + if (!overall_type.has_value()) { + overall_type = GetOutputType(out->matches.back().production); + continue; + } + + if (std::optional match_type = + GetOutputType(out->matches.back().production); + match_type.has_value()) { + if (!OutputTypeIsCompatible(*match_type, *overall_type)) { + issues_.ReportError( + match.id(), + absl::StrCat("incompatible output types: block has output type ", + FormatTypeName(*match_type), + ", but previous outputs have type ", + FormatTypeName(*overall_type))); + } + } + } + + out->result_type = overall_type; + return absl::OkStatus(); + } + + absl::Status CompilePolicy(const CelPolicy& policy, + IntermediateCompiledPolicy* out) { + src_ = policy.source(); + out->set_semantics(RuleSemantics::kFirstMatch); + out->set_name(policy.name().value()); + out->set_display_name( + policy.display_name().value_or(ValueString{}).value()); + out->set_description(policy.description().value_or(ValueString{}).value()); + + return CompileRule(policy.rule(), base_compiler_.get(), + &out->mutable_root_rule()); + } + + private: + google::protobuf::Arena arena_; + const CelPolicySource* absl_nullable src_; + policy_internal::IssueReporter& issues_; + std::unique_ptr base_compiler_; +}; + +bool IsExhaustive(const CompiledRule& rule); + +class FirstMatchComposer { + public: + FirstMatchComposer(const IntermediateCompiledPolicy& icp, + const Compiler& compiler, + policy_internal::IssueReporter& issues) + : issues_(issues), icp_(icp), compiler_(compiler) {} + + absl::Status Compose(); + + bool success() const { return ast_ != nullptr; } + + std::unique_ptr ReleaseAst() { return std::move(ast_); } + + private: + using VariableScope = absl::flat_hash_map; + + std::optional ResolvePolicyVariable(absl::string_view reference); + + absl::flat_hash_map ResolveBlockIndexes(const Ast& ast); + + bool CheckMatchStructure(const CompiledRule& rule); + + // Returns true if already optional wrapped. + absl::StatusOr ComposeRule(const CompiledRule& rule, Expr& init, + Expr& insertion_expr); + + // returns true if already optional wrapped. + absl::StatusOr ComposeProduction( + const CompiledRule& rule, const CompiledMatch::Production& production, + Expr& init, Expr& insertion_expr); + + void MapVariables(Ast& ast); + + void ComposeRuleVariables(const CompiledRule& rule, Expr& init, + Expr& insertion_expr); + + policy_internal::IssueReporter& issues_; + OptimizerExprFactory factory_; + const IntermediateCompiledPolicy& icp_; + const Compiler& compiler_; + std::vector scopes_; + bool optionalize_ = false; + std::unique_ptr ast_; +}; + +absl::Status FirstMatchComposer::Compose() { + ABSL_DCHECK(icp_.semantics() == RuleSemantics::kFirstMatch); + + factory_.mutable_ast().mutable_root_expr() = factory_.NewCall( + "cel.@block", factory_.NewList(), factory_.NewUnspecified()); + auto& block_init_list = factory_.mutable_ast() + .mutable_root_expr() + .mutable_call_expr() + .mutable_args()[0]; + auto& insertion_expr = factory_.mutable_ast() + .mutable_root_expr() + .mutable_call_expr() + .mutable_args()[1]; + optionalize_ = !IsExhaustive(icp_.root_rule()); + if (!CheckMatchStructure(icp_.root_rule())) { + return absl::OkStatus(); + } + CEL_ASSIGN_OR_RETURN( + bool optional_wrapped, + ComposeRule(icp_.root_rule(), block_init_list, insertion_expr)); + + if (optional_wrapped != optionalize_) { + return absl::InternalError( + "composition failed to handle non-exhaustive rules"); + } + + CEL_ASSIGN_OR_RETURN(cel::ValidationResult result, + compiler_.GetTypeChecker().Check(factory_.ast())); + if (!result.IsValid()) { + for (const auto& iss : result.GetIssues()) { + issues_.ReportError(icp_.root_rule().id, iss.message()); + } + return absl::OkStatus(); + } + + CEL_ASSIGN_OR_RETURN(ast_, result.ReleaseAst()); + + return absl::OkStatus(); +} + +bool IsTriviallyTrueCondition(const CompiledMatch& match) { + if (!match.condition.has_value() || match.condition->ast == nullptr) { + return true; + } + const cel::Expr& expr = match.condition->ast->root_expr(); + if (expr.has_const_expr()) { + const cel::Constant& const_expr = expr.const_expr(); + if (const_expr.has_bool_value() && const_expr.bool_value()) { + return true; + } + } + return false; +} + +bool IsExhaustive(const CompiledRule& rule); + +bool IsExhaustive(const CompiledMatch& match) { + if (std::holds_alternative(match.production)) { + return true; + } + + const auto* nested_rule_ptr = + std::get_if>(&match.production); + ABSL_DCHECK(nested_rule_ptr != nullptr); + const CompiledRule& nested_rule = **nested_rule_ptr; + return IsExhaustive(nested_rule); +} + +bool IsExhaustive(const CompiledRule& rule) { + if (rule.matches.empty()) { + // Validation should fail, but generalization would be false. + return false; + } + bool has_default = false; + for (const auto& match : rule.matches) { + if (IsTriviallyTrueCondition(match) && IsExhaustive(match)) { + // If this isn't the last match in the rule, it should get flagged + // during validation since it means there are trivially unreachable + // matches. + has_default = true; + } + if (!IsTriviallyTrueCondition(match) && !IsExhaustive(match)) { + // There is a nested rule that might return an optional.none(). + return false; + } + } + // Otherwise, everything in this branch is exhaustive so we can defer + // wrapping. + return has_default; +} + +bool FirstMatchComposer::CheckMatchStructure(const CompiledRule& rule) { + if (rule.matches.empty()) { + issues_.ReportError(rule.id, "rule does not specify match conditions"); + return false; + } + + bool valid = true; + bool seen_trivially_true = false; + + for (const auto& match : rule.matches) { + if (seen_trivially_true) { + if (std::holds_alternative(match.production)) { + issues_.ReportError(match.id, "match creates unreachable outputs"); + } else if (std::holds_alternative>( + match.production)) { + issues_.ReportError(match.id, "rule creates unreachable outputs"); + } + valid = false; + } + + if (IsTriviallyTrueCondition(match) && IsExhaustive(match)) { + seen_trivially_true = true; + } + + if (auto* nested_rule = + std::get_if>(&match.production); + nested_rule != nullptr) { + ABSL_DCHECK(*nested_rule != nullptr); + if (!CheckMatchStructure(**nested_rule)) { + valid = false; + } + } + } + + return valid; +} + +std::optional FirstMatchComposer::ResolvePolicyVariable( + absl::string_view reference) { + for (auto scope_iter = scopes_.rbegin(); scope_iter != scopes_.rend(); + ++scope_iter) { + if (auto it = scope_iter->find(reference); it != scope_iter->end()) { + return it->second; + } + } + return std::nullopt; +} + +class IndexRewrite : public AstRewriterBase { + public: + explicit IndexRewrite(absl::flat_hash_map expr_id_to_index, + OptimizerExprFactory& factory) + : expr_id_to_index_(std::move(expr_id_to_index)), factory_(factory) {} + + bool PreVisitRewrite(Expr& e) override { + if (auto it = expr_id_to_index_.find(e.id()); + it != expr_id_to_index_.end()) { + e.mutable_ident_expr().set_name(absl::StrCat("@index", it->second)); + factory_.RecordReplacement(e.id(), e); + return true; + } + return false; + } + + private: + absl::flat_hash_map expr_id_to_index_; + OptimizerExprFactory& factory_; +}; + +absl::StatusOr FirstMatchComposer::ComposeRule(const CompiledRule& rule, + Expr& init, + Expr& insertion_expr) { + scopes_.emplace_back(); + auto pop_scope = absl::MakeCleanup([this]() { scopes_.pop_back(); }); + ComposeRuleVariables(rule, init, insertion_expr); + Expr* insertion_point = &insertion_expr; + const bool has_default = IsTriviallyTrueCondition(rule.matches.back()); + const bool needs_wrap = !IsExhaustive(rule); + size_t end = rule.matches.size() - (has_default ? 1 : 0); + for (size_t i = 0; i < end; i++) { + const auto& match = rule.matches[i]; + if (IsTriviallyTrueCondition(match) && IsExhaustive(match)) { + return absl::InternalError("detected unreachable match after validation"); + } + + Expr production; + CEL_ASSIGN_OR_RETURN( + bool is_wrapped, + ComposeProduction(rule, match.production, init, production)); + if (needs_wrap && !is_wrapped) { + production = factory_.NewCall("optional.of", std::move(production)); + } + + if (!IsTriviallyTrueCondition(match)) { + Ast condition = *match.condition->ast; + MapVariables(condition); + factory_.StartCopyContext(); + auto copy = factory_.Copy(condition.root_expr()); + auto source_info = factory_.RemapSourceInfo(condition.source_info()); + factory_.MergeSourceInfo(source_info); + *insertion_point = factory_.NewCall("_?_:_", std::move(copy)); + insertion_point->mutable_call_expr().mutable_args().push_back( + std::move(production)); + ABSL_DCHECK(!(!needs_wrap && is_wrapped)) + << "unexpected wrapping in exhaustive policy."; + insertion_point = &insertion_point->mutable_call_expr().add_args(); + continue; + } + + if (!is_wrapped) { + return absl::InternalError( + "composition failed. expected optional wrapped rule but got a plain " + "value"); + } + auto fn = needs_wrap ? "or" : "orValue"; + *insertion_point = factory_.NewMemberCall(fn, std::move(production)); + insertion_point = &insertion_point->mutable_call_expr().add_args(); + } + + if (has_default) { + const auto& match = rule.matches.back(); + Expr production; + CEL_ASSIGN_OR_RETURN( + bool is_wrapped, + ComposeProduction(rule, match.production, init, production)); + if (needs_wrap && !is_wrapped) { + production = factory_.NewCall("optional.of", std::move(production)); + } + *insertion_point = std::move(production); + ABSL_DCHECK(!(!needs_wrap && is_wrapped)) + << "unexpected wrapping in exhaustive policy."; + + return needs_wrap; + } + + // Otherwise, we fell through a non-exhaustive rule. + *insertion_point = factory_.NewCall("optional.none"); + return true; +} + +absl::StatusOr FirstMatchComposer::ComposeProduction( + const CompiledRule& rule, const CompiledMatch::Production& production, + Expr& init, Expr& insertion_expr) { + if (auto* nested_rule = + std::get_if>(&production); + nested_rule != nullptr) { + return ComposeRule(**nested_rule, init, insertion_expr); + } + auto* output = std::get_if(&production); + if (output == nullptr) { + return absl::InternalError("unexpected rule production type"); + } + const EmbeddedAst& output_ast = output->output_ast; + Ast ast = *output_ast.ast; + MapVariables(ast); + factory_.StartCopyContext(); + Expr to_insert = factory_.Copy(ast.root_expr()); + auto source_info = factory_.RemapSourceInfo(ast.source_info()); + factory_.MergeSourceInfo(source_info); + insertion_expr = std::move(to_insert); + + return false; +} + +absl::flat_hash_map FirstMatchComposer::ResolveBlockIndexes( + const Ast& ast) { + absl::flat_hash_map out; + for (auto it = ast.reference_map().begin(); it != ast.reference_map().end(); + it++) { + const Reference& ref = it->second; + if (!it->second.overload_id().empty()) { + continue; + } + if (!absl::StartsWith(ref.name(), "variable")) { + continue; + } + if (auto index = ResolvePolicyVariable(ref.name()); index.has_value()) { + out[it->first] = *index; + } + } + return out; +} + +void FirstMatchComposer::MapVariables(Ast& ast) { + absl::flat_hash_map edit_map = ResolveBlockIndexes(ast); + IndexRewrite rewriter(std::move(edit_map), factory_); + AstRewrite(ast.mutable_root_expr(), rewriter); +} + +void FirstMatchComposer::ComposeRuleVariables(const CompiledRule& rule, + Expr& init, + Expr& insertion_expr) { + for (const auto& variable : rule.variables) { + Ast ast = *variable.ast.ast; + MapVariables(ast); + factory_.StartCopyContext(); + auto insertion = factory_.Copy(ast.root_expr()); + // TODO(b/506179116): apply the position offsets here. + auto info = factory_.RemapSourceInfo(ast.source_info()); + ABSL_DCHECK(init.has_list_expr()); + int index = init.mutable_list_expr().elements().size(); + init.mutable_list_expr().mutable_elements().push_back( + factory_.NewListElement(std::move(insertion))); + scopes_.back()[variable.ident] = index; + } +} + +bool HasComprehensionParent(const NavigableAstNode& node) { + const NavigableAstNode* curr = &node; + while (curr != nullptr) { + if (curr->node_kind() == NodeKind::kComprehension) { + return true; + } + curr = curr->parent(); + } + return false; +} + +// Unnester implementation. +class Unnester { + public: + Unnester(Ast ast, int height, policy_internal::IssueReporter& issues) + : factory_(std::move(ast)), height_(height), issues_(issues) {} + + // Run the unnesting. + // The class cannot be reused after this is called. + absl::StatusOr Unnest() { + if (height_ > 0) { + CEL_RETURN_IF_ERROR(Slice()); + } + CEL_RETURN_IF_ERROR(Cleanup()); + return std::move(factory_.mutable_ast()); + } + + private: + // The core unnest routine. + absl::Status Slice(); + // Fixup the AST post-unnesting. + absl::Status Cleanup(); + + void ReportErrorAtId(int64_t id, absl::string_view message); + + OptimizerExprFactory factory_; + int height_; + policy_internal::IssueReporter& issues_; +}; + +class UnnestRewriter : public AstRewriterBase { + public: + explicit UnnestRewriter(OptimizerExprFactory& f, Expr& block_list_expr, + absl::Span cuts) + : factory_(f), cuts_(cuts), block_list_expr_(block_list_expr) {} + + bool PostVisitRewrite(Expr& expr) override { + using std::swap; + // Post order so we always see children before parents. + // No need to copy metadata since we're only moving exprs or minting + // new ones. + if (absl::c_contains(cuts_, expr.id())) { + size_t idx = block_list_expr_.list_expr().elements().size(); + Expr value = factory_.NewIdent(absl::StrCat("@index", idx)); + factory_.RecordReplacement(expr.id(), value, /*keep_metadata=*/true); + swap(value, expr); + block_list_expr_.mutable_list_expr().mutable_elements().push_back( + factory_.NewListElement(std::move(value))); + return true; + } + return false; + } + + private: + OptimizerExprFactory& factory_; + absl::Span cuts_; + Expr& block_list_expr_; +}; + +absl::Status Unnester::Slice() { + Expr& root = factory_.mutable_ast().mutable_root_expr(); + if (root.call_expr().function() != kCelBlock || + root.call_expr().args().size() != 2 || + !root.call_expr().args()[0].has_list_expr()) { + return absl::InternalError("malformed AST detected during unnesting"); + } + // Two passes, we identify the slice points (bottom up), then cut + // and paste the leaves into the block list. + NavigableAst nav_ast = NavigableAst::Build(factory_.ast().root_expr()); + + ABSL_DCHECK(nav_ast.IdsAreUnique()); + bool can_cut = true; + std::vector cuts; + for (const NavigableAstNode& node : nav_ast.Root().DescendantsPostorder()) { + // Subsequent cuts will be height_ + 1 in the block, indices. Within the + // error margin we specified. + if (node.height() % height_ == 0) { + if (HasComprehensionParent(node)) { + ReportErrorAtId( + node.expr()->id(), + absl::StrCat( + "cannot unnest AST due to comprehension. cannot accommodate " + "height limit of ", + height_)); + can_cut = false; + continue; + } + if (&node == &nav_ast.Root()) { + // If evenly divisible by height, don't cut since it will net a taller + // AST. + continue; + } + cuts.push_back(node.expr()->id()); + } + } + + if (!can_cut || cuts.empty()) { + return absl::OkStatus(); + } + + Expr& block_list_expr = root.mutable_call_expr().mutable_args()[0]; + Expr& insertion_expr = root.mutable_call_expr().mutable_args()[1]; + + UnnestRewriter rewriter(factory_, block_list_expr, cuts); + AstRewrite(insertion_expr, rewriter); + + return absl::OkStatus(); +} + +absl::Status Unnester::Cleanup() { + using std::swap; + + const auto& ast = factory_.ast(); + if (ast.root_expr().call_expr().function() != kCelBlock || + ast.root_expr().call_expr().args().size() != 2 || + !ast.root_expr().call_expr().args()[0].has_list_expr()) { + return absl::InternalError("malformed AST detected during unnesting"); + } + if (ast.root_expr().call_expr().args()[0].list_expr().elements().empty()) { + Expr value = std::move(factory_.mutable_ast() + .mutable_root_expr() + .mutable_call_expr() + .mutable_args()[1]); + factory_.mutable_ast().mutable_root_expr() = std::move(value); + } + + return absl::OkStatus(); +} + +void Unnester::ReportErrorAtId(int64_t id, absl::string_view message) { + int32_t position = 0; + auto it = factory_.ast().source_info().positions().find(id); + if (it != factory_.ast().source_info().positions().end()) { + position = it->second; + } + issues_.ReportError(-1, position, message); +} +} // namespace + +// Compiles a CEL policy using the provided CEL compiler as a base environment. +absl::StatusOr CompilePolicy( + const Compiler& compiler, const CelPolicy& policy, + const CompilePolicyOptions& options) { + policy_internal::IssueReporter issues; + if (options.unnesting_height_limit != 0 && + options.unnesting_height_limit < 2) { + return absl::InvalidArgumentError( + "unnesting_height_limit must be at least 2"); + } + auto builder = compiler.ToBuilder(); + ExpressionContainer cont; + for (const auto& import : policy.imports()) { + auto status = cont.AddAbbreviation(import.name().value()); + if (!status.ok()) { + issues.ReportError( + import.name().id(), + absl::StrCat("'", import.name().value(), "': ", status.message())); + } + } + + builder->GetCheckerBuilder().SetExpressionContainer(cont); + CEL_ASSIGN_OR_RETURN(auto base_compiler, builder->Build()); + + PolicyCompiler policy_compiler(&issues, std::move(base_compiler)); + + IntermediateCompiledPolicy icp; + CEL_RETURN_IF_ERROR(policy_compiler.CompilePolicy(policy, &icp)); + + if (HasErrors(issues)) { + return CelPolicyValidationResult(issues.ReleaseIssues(), + policy.source_ptr()); + } + + CEL_ASSIGN_OR_RETURN(base_compiler, builder->Build()); + switch (icp.semantics()) { + case RuleSemantics::kFirstMatch: { + FirstMatchComposer composer(icp, *base_compiler, issues); + CEL_RETURN_IF_ERROR(composer.Compose()); + if (!composer.success()) { + return CelPolicyValidationResult(issues.ReleaseIssues(), + policy.source_ptr()); + } + + auto ast = composer.ReleaseAst(); + Unnester unnester(std::move(*ast), options.unnesting_height_limit, + issues); + CEL_ASSIGN_OR_RETURN(Ast unnested_ast, unnester.Unnest()); + + if (HasErrors(issues)) { + return CelPolicyValidationResult(issues.ReleaseIssues(), + policy.source_ptr()); + } + + return CelPolicyValidationResult( + std::make_unique(std::move(unnested_ast)), {}, + policy.source_ptr()); + } + default: + return absl::UnimplementedError( + absl::StrCat("Unsupported RuleSemantics: ", icp.semantics())); + } +} + +} // namespace cel diff --git a/policy/compiler.h b/policy/compiler.h new file mode 100644 index 000000000..0187bd1a2 --- /dev/null +++ b/policy/compiler.h @@ -0,0 +1,50 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_COMPILER_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_COMPILER_H_ + +#include "absl/status/statusor.h" +#include "compiler/compiler.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_validation_result.h" + +namespace cel { + +struct CompilePolicyOptions { + // If greater than 0, the compiler will attempt to unnest rule branches + // at the specified height. The overall height of the final AST may exceed + // this by a small, fixed margin. + // + // To avoid slicing comprehensions, subexpressions within comprehensions + // are not eligible for unnesting. If the height limit cannot be accommodated, + // an error with code InvalidArgument is returned. + // + // If the AST is converted to proto, even relatively low levels of nesting + // can cause problems in serialization/deserialization. This does not apply + // if the AST is used directly by the runtime. + int unnesting_height_limit = 0; +}; + +// Compiles a CEL policy using the provided CEL compiler as a base environment. +// +// TODO(b/506179116): Implementation in progress. Functionally complete, +// but errors are not consistent with other implementations. +absl::StatusOr CompilePolicy( + const Compiler& compiler, const CelPolicy& policy, + const CompilePolicyOptions& options = {}); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_COMPILER_H_ diff --git a/policy/compiler_test.cc b/policy/compiler_test.cc new file mode 100644 index 000000000..8db494b45 --- /dev/null +++ b/policy/compiler_test.cc @@ -0,0 +1,946 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/compiler.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "common/decl.h" +#include "common/navigable_ast.h" +#include "common/source.h" +#include "common/type.h" +#include "common/types/message_type.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/optional.h" +#include "compiler/standard_library.h" +#include "extensions/bindings_ext.h" +#include "internal/runfiles.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" +#include "policy/cel_policy_validation_result.h" +#include "policy/yaml_policy_parser.h" +#include "runtime/activation.h" +#include "runtime/optional_types.h" +#include "runtime/runtime.h" +#include "runtime/runtime_builder.h" +#include "runtime/runtime_options.h" +#include "runtime/standard_runtime_builder_factory.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" + +namespace cel { +namespace { + +using ::absl_testing::IsOk; +using ::absl_testing::StatusIs; +using ::cel::test::IntValueIs; +using ::cel::test::OptionalValueIs; +using ::cel::test::OptionalValueIsEmpty; +using ::cel::test::StringValueIs; +using ::cel::test::ValueMatcher; + +constexpr absl::string_view kTestPolicyFilePath = +"_main/policy/testdata/cel_policy.yaml"; + +absl::StatusOr> BuildTestCompiler() { + CompilerOptions opts; + opts.adapt_parser_errors = true; + opts.parser_options.enable_optional_syntax = true; + CEL_ASSIGN_OR_RETURN( + auto builder, + NewCompilerBuilder(internal::GetSharedTestingDescriptorPool(), opts)); + + CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCompilerLibrary())); + CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::OptionalCompilerLibrary())); + CEL_RETURN_IF_ERROR( + builder->AddLibrary(cel::extensions::BindingsCompilerLibrary())); + + CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddVariable( + cel::MakeVariableDecl("x", IntType()))); + CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddVariable( + cel::MakeVariableDecl("y", IntType()))); + + const google::protobuf::Descriptor* descriptor = + cel::internal::GetSharedTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes"); + if (descriptor == nullptr) { + return absl::InternalError("Failed to find TestAllTypes descriptor"); + } + CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddVariable( + cel::MakeVariableDecl("spec", cel::MessageType(descriptor)))); + + return builder->Build(); +} + +absl::StatusOr> ParsePolicyFromYaml( + absl::string_view yaml_content) { + CEL_ASSIGN_OR_RETURN(auto source, cel::NewSource(yaml_content, "test.yaml")); + + std::shared_ptr policy_source = + std::make_shared(std::move(source)); + CEL_ASSIGN_OR_RETURN(auto parse_result, + cel::ParseYamlCelPolicy(policy_source)); + + if (!parse_result.IsValid()) { + return absl::InvalidArgumentError("Invalid policy YAML structure"); + } + return parse_result.ReleasePolicy(); +} + +TEST(CompilerTest, SmokeTest) { + std::string contents; + std::string test_file = + cel::internal::ResolveRunfilesPath(kTestPolicyFilePath); + auto read_status = cel::internal::GetFileContents(test_file, &contents); + ASSERT_THAT(read_status, IsOk()); + + auto source_or = cel::NewSource(contents, "cel_policy.yaml"); + ASSERT_THAT(source_or.status(), IsOk()); + auto source = *std::move(source_or); + + std::shared_ptr policy_source = + std::make_shared(std::move(source)); + auto parse_result_or = cel::ParseYamlCelPolicy(policy_source); + ASSERT_THAT(parse_result_or.status(), IsOk()); + auto parse_result = *std::move(parse_result_or); + + ASSERT_TRUE(parse_result.IsValid()); + const CelPolicy* policy = parse_result.GetPolicy(); + + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + ASSERT_TRUE(result.IsValid()); +} + +TEST(CompilerTest, VariableOutOfScopeReportsError) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: variables.non_existent == 10 + output: '"error"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("undeclared reference")); +} + +TEST(CompilerTest, ConditionNotBoolReportsError) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: 10 + output: '"error"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("condition must evaluate to bool")); +} + +TEST(CompilerTest, InvalidOutputExpressionReportsError) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: true + output: undeclared_var +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("undeclared reference")); +} + +TEST(CompilerTest, UnreachableMatchAfterTriviallyTrueCondition) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: true + output: '"first"' + - condition: true + output: '"second"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("match creates unreachable outputs")); +} + +TEST(CompilerTest, UnreachableMatchAfterUnconditionalExhaustiveSubRule) { + absl::string_view yaml = R"yaml( +name: dead_branch +rule: + match: + - rule: + match: + - output: 1 + - output: 2 +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("match creates unreachable outputs")); +} + +TEST(CompilerTest, RuleWithoutMatchesReportsError) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("rule does not specify match conditions")); +} + +TEST(CompilerTest, ExhaustivePolicyCompiles) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + variables: + - name: test_var + expression: 10 + match: + - condition: variables.test_var > 15 + output: '"greater than 15"' + - condition: variables.test_var > 5 + output: '"greater than 5"' + - output: '"default"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + ASSERT_TRUE(result.IsValid()); + EXPECT_TRUE(result.GetAst()->is_checked()); +} + +TEST(CompilerTest, NonExhaustivePolicyCompiles) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + variables: + - name: test_var + expression: 10 + match: + - condition: variables.test_var > 5 + output: '"greater than 5"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + ASSERT_TRUE(result.IsValid()); +} + +TEST(CompilerTest, PolicyReferencesEnvInput) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: spec.single_int32 > 10 + output: '"greater than 10"' + - output: '"default"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + ASSERT_TRUE(result.IsValid()); + EXPECT_TRUE(result.GetAst()->is_checked()); +} + +struct EvaluationTestCase { + std::string name; + std::string yaml_policy; + struct Input { + int64_t x; + int64_t y; + } input; + ValueMatcher expected_result_matcher; +}; + +class PolicyEvaluationTest : public testing::TestWithParam { +}; + +TEST_P(PolicyEvaluationTest, Evaluate) { + const auto& test_case = GetParam(); + + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(test_case.yaml_policy)); + ASSERT_OK_AND_ASSIGN(auto validation_result, + CompilePolicy(*compiler, *policy)); + ASSERT_TRUE(validation_result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast, validation_result.ReleaseAst()); + + // Set up runtime + cel::RuntimeOptions opts; + opts.enable_qualified_type_identifiers = true; + ASSERT_OK_AND_ASSIGN( + cel::RuntimeBuilder rt_builder, + cel::CreateStandardRuntimeBuilder( + cel::internal::GetSharedTestingDescriptorPool(), opts)); + ASSERT_THAT(cel::extensions::EnableOptionalTypes(rt_builder), IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + std::move(rt_builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto program, runtime->CreateProgram(std::move(ast))); + + // Set up activation + cel::Activation activation; + activation.InsertOrAssignValue("x", cel::IntValue(test_case.input.x)); + activation.InsertOrAssignValue("y", cel::IntValue(test_case.input.y)); + + google::protobuf::Arena arena; + ASSERT_OK_AND_ASSIGN(cel::Value result, + program->Evaluate(&arena, activation)); + + EXPECT_THAT(result, test_case.expected_result_matcher); +} + +constexpr absl::string_view kEvalPolicyYaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: x > 10 && y > 10 + output: '"both greater than 10"' + - condition: x > 10 + output: '"x greater than 10"' + - condition: y > 10 + output: '"y greater than 10"' + - output: '"default"' +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + PolicyEvaluationTest, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "BothGreaterThan10", + .yaml_policy = std::string(kEvalPolicyYaml), + .input = {.x = 15, .y = 15}, + .expected_result_matcher = StringValueIs("both greater than 10"), + }, + EvaluationTestCase{ + .name = "XGreaterThan10", + .yaml_policy = std::string(kEvalPolicyYaml), + .input = {.x = 15, .y = 5}, + .expected_result_matcher = StringValueIs("x greater than 10"), + }, + EvaluationTestCase{ + .name = "YGreaterThan10", + .yaml_policy = std::string(kEvalPolicyYaml), + .input = {.x = 5, .y = 15}, + .expected_result_matcher = StringValueIs("y greater than 10"), + }, + EvaluationTestCase{ + .name = "Default", + .yaml_policy = std::string(kEvalPolicyYaml), + .input = {.x = 5, .y = 5}, + .expected_result_matcher = StringValueIs("default"), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view kNonExhaustivePolicyYaml = R"yaml( +name: nested_rule4 +rule: + match: + - condition: x > 0 + rule: + match: + - condition: x < 3 + output: 1 + - condition: x < 5 + output: 2 + - condition: x < 0 + rule: + match: + - condition: x > -2 + output: 3 + - condition: x > -4 + output: 4 + - output: 5 +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + NonExhaustivePolicyEvaluation, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XEquals0_FallthroughTopLevel", + .yaml_policy = std::string(kNonExhaustivePolicyYaml), + .input = {.x = 0, .y = 0}, + .expected_result_matcher = OptionalValueIsEmpty(), + }, + EvaluationTestCase{ + .name = "XEquals2_MatchesFirstNested", + .yaml_policy = std::string(kNonExhaustivePolicyYaml), + .input = {.x = 2, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(1)), + }, + EvaluationTestCase{ + .name = "XEquals6_FallthroughNested", + .yaml_policy = std::string(kNonExhaustivePolicyYaml), + .input = {.x = 6, .y = 0}, + .expected_result_matcher = OptionalValueIsEmpty(), + }, + EvaluationTestCase{ + .name = "XEqualsMinus1_MatchesMinus2", + .yaml_policy = std::string(kNonExhaustivePolicyYaml), + .input = {.x = -1, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(3)), + }, + EvaluationTestCase{ + .name = "XEqualsMinus3_MatchesMinus4", + .yaml_policy = std::string(kNonExhaustivePolicyYaml), + .input = {.x = -3, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(4)), + }, + EvaluationTestCase{ + .name = "XEqualsMinus5_MatchesDefault", + .yaml_policy = std::string(kNonExhaustivePolicyYaml), + .input = {.x = -5, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(5)), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view kNestedVariablePolicyYaml = R"yaml( +name: nested_rule4 +rule: + variables: + - name: i + expression: "1" + - name: j + expression: "2" + match: + - condition: x > 0 + rule: + variables: + - name: k + expression: "3" + match: + - output: "variables.i + variables.j + variables.k" + - condition: x < 0 + rule: + variables: + - name: j + expression: "5" + - name: k + expression: "4" + match: + - output: "variables.i + variables.j + variables.k" + - output: "variables.i + variables.j" +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + NestedVariablePolicyEvaluation, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XGreaterThan0", + .yaml_policy = std::string(kNestedVariablePolicyYaml), + .input = {.x = 1, .y = 0}, + .expected_result_matcher = IntValueIs(6), + }, + EvaluationTestCase{ + .name = "XLessThan0", + .yaml_policy = std::string(kNestedVariablePolicyYaml), + .input = {.x = -1, .y = 0}, + .expected_result_matcher = IntValueIs(10), + }, + EvaluationTestCase{ + .name = "XEquals0", + .yaml_policy = std::string(kNestedVariablePolicyYaml), + .input = {.x = 0, .y = 0}, + .expected_result_matcher = IntValueIs(3), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view + kOptionalChainingUnconditionalSubRuleOptionalParentYaml = R"yaml( +name: optional_chaining +rule: + match: + - rule: + id: r2 + match: + - condition: x > 0 + output: 1 + - output: 2 + condition: x < 0 +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + OptionalChainingUnconditionalSubRuleOptionalParent, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XEquals1", + .yaml_policy = std::string( + kOptionalChainingUnconditionalSubRuleOptionalParentYaml), + .input = {.x = 1, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(1)), + }, + EvaluationTestCase{ + .name = "XEqualsMinus1", + .yaml_policy = std::string( + kOptionalChainingUnconditionalSubRuleOptionalParentYaml), + .input = {.x = -1, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(2)), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view kOptionalChainingUnconditionalSubRuleYaml = R"yaml( +name: optional_chaining +rule: + id: r1 + match: + - rule: + id: r2 + match: + - condition: x > 0 + output: 1 + - output: 2 +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + OptionalChainingUnconditionalSubRule, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XEquals1", + .yaml_policy = + std::string(kOptionalChainingUnconditionalSubRuleYaml), + .input = {.x = 1, .y = 0}, + .expected_result_matcher = IntValueIs(1), + }, + EvaluationTestCase{ + .name = "XEqualsMinus1", + .yaml_policy = + std::string(kOptionalChainingUnconditionalSubRuleYaml), + .input = {.x = -1, .y = 0}, + .expected_result_matcher = IntValueIs(2), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view kOptionalChainingUnconditionalComplexYaml = R"yaml( +name: optional_chaining +rule: + match: + - condition: x > 0 + rule: + match: + - rule: + match: + - condition: x == 1 + output: 1 + - output: 2 + - rule: + match: + - condition: x == -1 + output: 3 + - condition: x == -2 + output: 4 +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + OptionalChainingUnconditionalComplex, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XEquals1", + .yaml_policy = + std::string(kOptionalChainingUnconditionalComplexYaml), + .input = {.x = 1, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(1)), + }, + EvaluationTestCase{ + .name = "XEqualsMinus1", + .yaml_policy = + std::string(kOptionalChainingUnconditionalComplexYaml), + .input = {.x = -1, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(3)), + }, + EvaluationTestCase{ + .name = "XEqualsMinus2", + .yaml_policy = + std::string(kOptionalChainingUnconditionalComplexYaml), + .input = {.x = -2, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(4)), + }, + EvaluationTestCase{ + .name = "XEqualsMinus3", + .yaml_policy = + std::string(kOptionalChainingUnconditionalComplexYaml), + .input = {.x = -3, .y = 0}, + .expected_result_matcher = OptionalValueIsEmpty(), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view kUnconditionalExhaustiveSubRuleAsLastMatchYaml = + R"yaml( +name: exhaustive_unconditional_subrule +rule: + match: + - condition: x > 0 + output: 1 + - rule: + match: + - condition: y > 0 + output: 2 + - output: 3 +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + UnconditionalExhaustiveSubRuleAsLastMatch, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XEquals1", + .yaml_policy = + std::string(kUnconditionalExhaustiveSubRuleAsLastMatchYaml), + .input = {.x = 1, .y = 0}, + .expected_result_matcher = IntValueIs(1), + }, + EvaluationTestCase{ + .name = "XEquals0_YEquals1", + .yaml_policy = + std::string(kUnconditionalExhaustiveSubRuleAsLastMatchYaml), + .input = {.x = 0, .y = 1}, + .expected_result_matcher = IntValueIs(2), + }, + EvaluationTestCase{ + .name = "XEquals0_YEquals0", + .yaml_policy = + std::string(kUnconditionalExhaustiveSubRuleAsLastMatchYaml), + .input = {.x = 0, .y = 0}, + .expected_result_matcher = IntValueIs(3), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +constexpr absl::string_view kUnconditionalNonExhaustiveSubRuleAsLastMatchYaml = + R"yaml( +name: non_exhaustive_unconditional_subrule +rule: + match: + - condition: x > 0 + output: 1 + - rule: + match: + - condition: y > 0 + output: 2 +)yaml"; + +INSTANTIATE_TEST_SUITE_P( + UnconditionalNonExhaustiveSubRuleAsLastMatch, PolicyEvaluationTest, + testing::Values( + EvaluationTestCase{ + .name = "XEquals1", + .yaml_policy = + std::string(kUnconditionalNonExhaustiveSubRuleAsLastMatchYaml), + .input = {.x = 1, .y = 0}, + .expected_result_matcher = OptionalValueIs(IntValueIs(1)), + }, + EvaluationTestCase{ + .name = "XEquals0_YEquals1", + .yaml_policy = + std::string(kUnconditionalNonExhaustiveSubRuleAsLastMatchYaml), + .input = {.x = 0, .y = 1}, + .expected_result_matcher = OptionalValueIs(IntValueIs(2)), + }, + EvaluationTestCase{ + .name = "XEquals0_YEquals0", + .yaml_policy = + std::string(kUnconditionalNonExhaustiveSubRuleAsLastMatchYaml), + .input = {.x = 0, .y = 0}, + .expected_result_matcher = OptionalValueIsEmpty(), + }), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +TEST(CompilerTest, ImportsAndAbbreviations) { + absl::string_view yaml = R"yaml( +name: imports_test +imports: + - name: cel.expr.conformance.proto3.TestAllTypes +rule: + match: + - condition: 'spec == TestAllTypes{single_int32: 10}' + output: '"matched"' + - output: '"default"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + auto ast_or = CompilePolicy(*compiler, *policy); + ASSERT_THAT(ast_or, IsOk()); +} + +TEST(CompilerTest, MatchWithoutProductionReportsError) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: true +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("match must specify an output or rule")); +} + +int GetAstHeight(const cel::Ast& ast) { + auto nav_ast = cel::NavigableAst::Build(ast.root_expr()); + return nav_ast.Root().height(); +} + +TEST(CompilerTest, UnnestHeightValidation) { + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: true + output: '"ok"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + CompilePolicyOptions options; + options.unnesting_height_limit = 1; + auto status_or = CompilePolicy(*compiler, *policy, options); + EXPECT_THAT(status_or.status(), + StatusIs(absl::StatusCode::kInvalidArgument, + testing::HasSubstr( + "unnesting_height_limit must be at least 2"))); + + options.unnesting_height_limit = 2; + EXPECT_THAT(CompilePolicy(*compiler, *policy, options), IsOk()); +} + +constexpr absl::string_view kDeepPolicyYaml = R"yaml( +name: deep_policy +rule: + match: + - condition: x > 0 + rule: + match: + - condition: x > 1 + rule: + match: + - condition: x > 2 + rule: + match: + - condition: x > 3 + rule: + match: + - condition: x > 4 + rule: + match: + - condition: x > 5 + output: 6 + - output: 5 + - output: 4 + - output: 3 + - output: 2 + - output: 1 + - output: 0 +)yaml"; + +TEST(CompilerTest, UnnestHeightReduction) { + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(kDeepPolicyYaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + // Compile without unnesting + CompilePolicyOptions options_no_unnest; + options_no_unnest.unnesting_height_limit = 0; + ASSERT_OK_AND_ASSIGN(auto result_no_unnest, + CompilePolicy(*compiler, *policy, options_no_unnest)); + ASSERT_TRUE(result_no_unnest.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast_no_unnest, result_no_unnest.ReleaseAst()); + int height_no_unnest = GetAstHeight(*ast_no_unnest); + + CompilePolicyOptions options_unnest; + options_unnest.unnesting_height_limit = 2; + ASSERT_OK_AND_ASSIGN(auto result_unnest, + CompilePolicy(*compiler, *policy, options_unnest)); + ASSERT_TRUE(result_unnest.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast_unnest, result_unnest.ReleaseAst()); + int height_unnest = GetAstHeight(*ast_unnest); + + EXPECT_EQ(height_no_unnest, 8); + EXPECT_EQ(height_unnest, 5); + EXPECT_LT(height_unnest, height_no_unnest); +} + +TEST(CompilerTest, UnnestComprehensionFailure) { + absl::string_view yaml = R"yaml( +name: comprehension_policy +rule: + match: + - condition: x > 0 + rule: + match: + - condition: "[1, 2].all(i, i > x)" + output: 1 + - output: 2 + - output: 0 +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + CompilePolicyOptions options; + options.unnesting_height_limit = 2; + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy, options)); + EXPECT_FALSE(result.IsValid()); + EXPECT_THAT(result.FormatIssues(), + testing::HasSubstr("cannot unnest AST due to comprehension")); +} + +struct UnnestEvaluationTestCase { + std::string name; + int64_t x; + ValueMatcher expected; +}; + +class UnnestedDeepPolicyEvaluationTest + : public testing::TestWithParam {}; + +TEST_P(UnnestedDeepPolicyEvaluationTest, Evaluate) { + const auto& tc = GetParam(); + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(kDeepPolicyYaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + CompilePolicyOptions options; + options.unnesting_height_limit = 2; + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy, options)); + ASSERT_TRUE(result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast, result.ReleaseAst()); + + // Set up runtime + cel::RuntimeOptions opts; + opts.enable_qualified_type_identifiers = true; + ASSERT_OK_AND_ASSIGN( + cel::RuntimeBuilder rt_builder, + cel::CreateStandardRuntimeBuilder( + cel::internal::GetSharedTestingDescriptorPool(), opts)); + ASSERT_THAT(cel::extensions::EnableOptionalTypes(rt_builder), IsOk()); + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + std::move(rt_builder).Build()); + + ASSERT_OK_AND_ASSIGN(auto program, runtime->CreateProgram(std::move(ast))); + + cel::Activation activation; + activation.InsertOrAssignValue("x", cel::IntValue(tc.x)); + google::protobuf::Arena arena; + + ASSERT_OK_AND_ASSIGN(cel::Value res, program->Evaluate(&arena, activation)); + + EXPECT_THAT(res, tc.expected); +} + +INSTANTIATE_TEST_SUITE_P( + UnnestedDeepPolicyEvaluation, UnnestedDeepPolicyEvaluationTest, + testing::Values(UnnestEvaluationTestCase{"XEquals6", 6, IntValueIs(6)}, + UnnestEvaluationTestCase{"XEquals5", 5, IntValueIs(5)}, + UnnestEvaluationTestCase{"XEquals4", 4, IntValueIs(4)}, + UnnestEvaluationTestCase{"XEquals3", 3, IntValueIs(3)}, + UnnestEvaluationTestCase{"XEquals2", 2, IntValueIs(2)}, + UnnestEvaluationTestCase{"XEquals1", 1, IntValueIs(1)}, + UnnestEvaluationTestCase{"XEquals0", 0, IntValueIs(0)}, + UnnestEvaluationTestCase{"XEqualsMinus1", -1, + IntValueIs(0)}), + [](const testing::TestParamInfo< + UnnestedDeepPolicyEvaluationTest::ParamType>& info) { + return info.param.name; + }); + +TEST(CompilerTest, UnnestCleanupRunsWhenDisabled) { + // A policy without variables and without nesting. + absl::string_view yaml = R"yaml( +name: cel_policy +rule: + id: test_rule + match: + - condition: true + output: '"ok"' +)yaml"; + ASSERT_OK_AND_ASSIGN(auto policy, ParsePolicyFromYaml(yaml)); + ASSERT_OK_AND_ASSIGN(auto compiler, BuildTestCompiler()); + + CompilePolicyOptions options; + options.unnesting_height_limit = 0; // Disabled + ASSERT_OK_AND_ASSIGN(auto result, CompilePolicy(*compiler, *policy, options)); + ASSERT_TRUE(result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast, result.ReleaseAst()); + + // If cleanup ran, it should have optimized away the trivial `cel.@block`. + // So the root expression should NOT be a call to `cel.@block`. + // It should be just the constant `"ok"`. + auto nav_ast = cel::NavigableAst::Build(ast->root_expr()); + EXPECT_FALSE(nav_ast.Root().expr()->has_call_expr() && + nav_ast.Root().expr()->call_expr().function() == "cel.@block"); + EXPECT_TRUE(nav_ast.Root().expr()->has_const_expr()); +} +} // namespace +} // namespace cel diff --git a/policy/internal/BUILD b/policy/internal/BUILD new file mode 100644 index 000000000..30f43d431 --- /dev/null +++ b/policy/internal/BUILD @@ -0,0 +1,68 @@ +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "issue_reporter", + srcs = ["issue_reporter.cc"], + hdrs = ["issue_reporter.h"], + deps = [ + "//common:source", + "//policy:cel_policy", + "//policy:cel_policy_parser", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "optimizer_expr_factory", + srcs = ["optimizer_expr_factory.cc"], + hdrs = ["optimizer_expr_factory.h"], + deps = [ + "//common:ast", + "//common:ast_rewrite", + "//common:ast_traverse", + "//common:ast_visitor_base", + "//common:constant", + "//common:expr", + "//common:expr_factory", + "//common:source", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + ], +) + +cc_test( + name = "optimizer_expr_factory_test", + srcs = ["optimizer_expr_factory_test.cc"], + deps = [ + ":optimizer_expr_factory", + "//common:ast", + "//common:ast_proto", + "//common:ast_rewrite", + "//common:decl", + "//common:expr", + "//common:expr_factory", + "//common:source", + "//common:type", + "//compiler", + "//compiler:compiler_factory", + "//compiler:standard_library", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//testutil:expr_printer", + "//tools:cel_unparser", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) diff --git a/policy/internal/issue_reporter.cc b/policy/internal/issue_reporter.cc new file mode 100644 index 000000000..944e687d6 --- /dev/null +++ b/policy/internal/issue_reporter.cc @@ -0,0 +1,45 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/internal/issue_reporter.h" + +#include "absl/strings/string_view.h" +#include "common/source.h" +#include "policy/cel_policy.h" + +namespace cel::policy_internal { + +void IssueReporter::ReportIssue(CelPolicyElementId element, Severity severity, + absl::string_view message) { + issues_.push_back({element, severity, message}); +} + +void IssueReporter::ReportOffsetIssue(CelPolicyElementId element, + cel::SourcePosition relative_position, + Severity severity, + absl::string_view message) { + issues_.push_back({element, relative_position, severity, message}); +} + +void IssueReporter::ReportError(CelPolicyElementId element, + absl::string_view message) { + ReportIssue(element, Severity::kError, message); +} + +void IssueReporter::ReportError(CelPolicyElementId element, SourcePosition pos, + absl::string_view message) { + ReportOffsetIssue(element, pos, Severity::kError, message); +} + +} // namespace cel::policy_internal diff --git a/policy/internal/issue_reporter.h b/policy/internal/issue_reporter.h new file mode 100644 index 000000000..3f88806ef --- /dev/null +++ b/policy/internal/issue_reporter.h @@ -0,0 +1,57 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_INTERNAL_ISSUE_REPORTER_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_INTERNAL_ISSUE_REPORTER_H_ + +#include + +#include "absl/strings/string_view.h" +#include "common/source.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" + +namespace cel::policy_internal { + +class IssueReporter { + private: + using Severity = CelPolicyIssue::Severity; + + public: + void ReportIssue(CelPolicyElementId element, Severity severity, + absl::string_view message); + + void ReportOffsetIssue(CelPolicyElementId element, + cel::SourcePosition relative_position, + Severity severity, absl::string_view message); + + void ReportError(CelPolicyElementId element, absl::string_view message); + void ReportError(CelPolicyElementId element, SourcePosition relative_pos, + absl::string_view message); + + std::vector ReleaseIssues() { + using std::swap; + std::vector out; + swap(out, issues_); + return out; + } + const std::vector& issues() const { return issues_; } + + private: + std::vector issues_; +}; + +} // namespace cel::policy_internal + +#endif // THIRD_PARTY_CEL_CPP_POLICY_INTERNAL_ISSUE_REPORTER_H_ diff --git a/policy/internal/optimizer_expr_factory.cc b/policy/internal/optimizer_expr_factory.cc new file mode 100644 index 000000000..6c89ae958 --- /dev/null +++ b/policy/internal/optimizer_expr_factory.cc @@ -0,0 +1,373 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/internal/optimizer_expr_factory.h" + +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/functional/any_invocable.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/ast.h" +#include "common/ast_rewrite.h" +#include "common/ast_traverse.h" +#include "common/ast_visitor_base.h" +#include "common/constant.h" +#include "common/expr.h" +#include "common/expr_factory.h" +#include "common/source.h" + +namespace cel { + +namespace { + +class MaxIdVisitor final : public AstVisitorBase { + public: + ExprId max_id() const { return max_id_; } + + void PreVisitExpr(const Expr& expr) override { + max_id_ = std::max(max_id_, expr.id()); + } + + void PostVisitExpr(const Expr&) override {} + + void PostVisitStruct(const Expr&, const StructExpr& struct_expr) override { + for (const auto& field : struct_expr.fields()) { + max_id_ = std::max(max_id_, field.id()); + } + } + + void PostVisitMap(const Expr&, const MapExpr& map_expr) override { + for (const auto& entry : map_expr.entries()) { + max_id_ = std::max(max_id_, entry.id()); + } + } + + private: + ExprId max_id_ = 0; +}; + +ExprId GetMaxId(const Expr& expr) { + MaxIdVisitor visitor; + AstTraverse(expr, visitor); + return visitor.max_id(); +} + +ExprId GetMaxId(const Ast& ast) { + ExprId max_id = GetMaxId(ast.root_expr()); + for (const auto& [id, _] : ast.source_info().positions()) { + max_id = std::max(max_id, id); + } + for (const auto& [id, expr] : ast.source_info().macro_calls()) { + max_id = std::max(max_id, id); + max_id = std::max(max_id, GetMaxId(expr)); + } + return max_id; +} + +// Replaces nested macros in a macro_calls expr with reference nodes. +// +// The macro_calls map is used for retaining the original structure of the +// parsed expression before macro expansion. When a macro appears inside another +// macro, the parser will replace the inner macro expr node with an unspecified +// expr with the inner macro's ID in the macro_calls map to save space. +class MakeMacroCallRewrite final : public AstRewriterBase { + public: + explicit MakeMacroCallRewrite(const SourceInfo& source_info) + : source_info_(source_info) {} + + bool PreVisitRewrite(Expr& expr) override { + if (source_info_.macro_calls().find(expr.id()) != + source_info_.macro_calls().end()) { + ExprId id = expr.id(); + expr.mutable_kind() = UnspecifiedExpr(); + expr.set_id(id); + return true; + } + return false; + } + + private: + const SourceInfo& source_info_; +}; + +// Updates macro_calls map entries to reflect a replaced expression in the +// main AST. +class ReplaceMacroCallRewrite final : public AstRewriterBase { + public: + ReplaceMacroCallRewrite(ExprId old_id, const Expr& replacement, + const SourceInfo& source_info) + : old_id_(old_id), replacement_(replacement), source_info_(source_info) {} + + bool PreVisitRewrite(Expr& expr) override { + if (expr.id() == old_id_) { + expr = macro_replacement(); + return true; + } + return false; + } + + Expr macro_replacement() { + if (!macro_replacement_) { + macro_replacement_.emplace(replacement_); + MakeMacroCallRewrite hole_creator(source_info_); + AstRewrite(*macro_replacement_, hole_creator); + } + return *macro_replacement_; + } + + private: + ExprId old_id_; + const Expr& replacement_; + absl::optional macro_replacement_; + const SourceInfo& source_info_; +}; + +void ReplaceSubExpr(Expr& expr, ExprId old_id, const Expr& replacement, + const SourceInfo& source_info) { + ReplaceMacroCallRewrite rewriter(old_id, replacement, source_info); + AstRewrite(expr, rewriter); +} + +class IdRewriter : public AstRewriterBase { + using CopyIdFn = absl::AnyInvocable; + + public: + explicit IdRewriter(CopyIdFn copy_id) : copy_id_(std::move(copy_id)) {} + + // No structure changes just ids. + bool PreVisitRewrite(Expr& expr) override { + expr.set_id(copy_id_(expr.id())); + if (expr.has_struct_expr()) { + for (auto& field : expr.mutable_struct_expr().mutable_fields()) { + field.set_id(copy_id_(field.id())); + } + } else if (expr.has_map_expr()) { + for (auto& entry : expr.mutable_map_expr().mutable_entries()) { + entry.set_id(copy_id_(entry.id())); + } + } + return false; + } + + private: + CopyIdFn copy_id_; +}; + +} // namespace + +OptimizerExprFactory::OptimizerExprFactory(Ast basis) + : ast_(std::move(basis)), next_id_(GetMaxId(ast_) + 1) {} + +OptimizerExprFactory::OptimizerExprFactory() : next_id_(1) {} + +Expr OptimizerExprFactory::Copy(const Expr& expr) { + Expr copied = expr; + IdRewriter rewriter([this](ExprId id) { return CopyId(id); }); + AstRewrite(copied, rewriter); + return copied; +} + +ListExprElement OptimizerExprFactory::Copy(const ListExprElement& element) { + return NewListElement(Copy(element.expr()), element.optional()); +} + +StructExprField OptimizerExprFactory::Copy(const StructExprField& field) { + auto field_id = CopyId(field.id()); + auto field_value = Copy(field.value()); + return NewStructField(field_id, field.name(), std::move(field_value), + field.optional()); +} + +MapExprEntry OptimizerExprFactory::Copy(const MapExprEntry& entry) { + auto entry_id = CopyId(entry.id()); + auto entry_key = Copy(entry.key()); + auto entry_value = Copy(entry.value()); + return NewMapEntry(entry_id, std::move(entry_key), std::move(entry_value), + entry.optional()); +} + +ExprId OptimizerExprFactory::NextId() { return next_id_++; } + +ExprId OptimizerExprFactory::CopyId(ExprId id) { + if (id == 0) { + return 0; + } + auto it = renumbers_.find(id); + if (it != renumbers_.end()) { + return it->second; + } + ExprId new_id = NextId(); + renumbers_[id] = new_id; + return new_id; +} + +SourceInfo OptimizerExprFactory::RemapSourceInfo(const SourceInfo& info, + SourcePosition offset) { + SourceInfo out; + + for (const auto& [old_id, macro_expr] : info.macro_calls()) { + if (auto it = renumbers_.find(old_id); it != renumbers_.end()) { + ExprId new_id = it->second; + out.mutable_macro_calls()[new_id] = Copy(macro_expr); + } + } + + for (const auto& [old_id, new_id] : renumbers_) { + if (auto it = info.positions().find(old_id); it != info.positions().end()) { + out.mutable_positions()[new_id] = it->second + offset; + } + } + + return out; +} + +void OptimizerExprFactory::MergeSourceInfo(const SourceInfo& info) { + auto& target_info = ast_.mutable_source_info(); + + for (const auto& [id, pos] : info.positions()) { + auto [it, inserted] = target_info.mutable_positions().insert({id, pos}); + if (!inserted) { + issues_.push_back(Issue{id, "conflicting ID in positions merge"}); + } + } + + for (const auto& [id, expr] : info.macro_calls()) { + auto [it, inserted] = target_info.mutable_macro_calls().insert({id, expr}); + if (!inserted) { + issues_.push_back(Issue{id, "conflicting ID in macro calls merge"}); + } + } + + // TODO(b/506179116): need to add some check that we aren't + // introducing incompatible tags. Not possible in the policy compiler right + // now. + for (const auto& ext : info.extensions()) { + auto& target_exts = target_info.mutable_extensions(); + if (!absl::c_linear_search(target_exts, ext)) { + target_exts.push_back(ext); + } + } +} + +void OptimizerExprFactory::RecordReplacement(ExprId id, const Expr& replacement, + bool keep_metadata) { + auto& source_info = ast_.mutable_source_info(); + if (!keep_metadata) { + source_info.mutable_positions().erase(id); + source_info.mutable_macro_calls().erase(id); + } + + for (auto& [macro_id, macro_expr] : source_info.mutable_macro_calls()) { + ReplaceSubExpr(macro_expr, id, replacement, source_info); + } +} + +Expr OptimizerExprFactory::ReportError(absl::string_view message) { + ExprId id = NextId(); + issues_.push_back(Issue{id, std::string(message)}); + return NewUnspecified(id); +} + +Expr OptimizerExprFactory::ReportErrorAt(const Expr& expr, + absl::string_view message) { + issues_.push_back(Issue{expr.id(), std::string(message)}); + return NewUnspecified(NextId()); +} + +Expr OptimizerExprFactory::ReportErrorAtCopy(const Expr& expr, + absl::string_view message) { + issues_.push_back(Issue{CopyId(expr.id()), std::string(message)}); + return NewUnspecified(NextId()); +} + +Expr OptimizerExprFactory::NewUnspecified() { return NewUnspecified(NextId()); } + +Expr OptimizerExprFactory::NewNullConst() { return NewNullConst(NextId()); } + +Expr OptimizerExprFactory::NewBoolConst(bool value) { + return NewBoolConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewIntConst(int64_t value) { + return NewIntConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewUintConst(uint64_t value) { + return NewUintConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewDoubleConst(double value) { + return NewDoubleConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewBytesConst(std::string value) { + return NewBytesConst(NextId(), std::move(value)); +} + +Expr OptimizerExprFactory::NewBytesConst(absl::string_view value) { + return NewBytesConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewBytesConst(const char* value) { + return NewBytesConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewStringConst(std::string value) { + return NewStringConst(NextId(), std::move(value)); +} + +Expr OptimizerExprFactory::NewStringConst(absl::string_view value) { + return NewStringConst(NextId(), value); +} + +Expr OptimizerExprFactory::NewStringConst(const char* value) { + return NewStringConst(NextId(), value); +} + +absl::flat_hash_map OptimizerExprFactory::ConsumeRenumbers() { + using std::swap; + absl::flat_hash_map out; + swap(out, renumbers_); + return out; +} + +void OptimizerExprFactory::StartCopyContext() { renumbers_.clear(); } + +const std::vector& OptimizerExprFactory::issues() + const { + return issues_; +} + +const Ast& OptimizerExprFactory::ast() const { return ast_; } + +Ast& OptimizerExprFactory::mutable_ast() { return ast_; } + +absl::string_view OptimizerExprFactory::AccuVarName() { + return ExprFactory::AccuVarName(); +} + +Expr OptimizerExprFactory::NewAccuIdent() { return NewAccuIdent(NextId()); } + +ExprId OptimizerExprFactory::CopyId(const Expr& expr) { + return CopyId(expr.id()); +} + +} // namespace cel diff --git a/policy/internal/optimizer_expr_factory.h b/policy/internal/optimizer_expr_factory.h new file mode 100644 index 000000000..6f63f1485 --- /dev/null +++ b/policy/internal/optimizer_expr_factory.h @@ -0,0 +1,419 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_INTERNAL_OPTIMIZER_EXPR_FACTORY_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_INTERNAL_OPTIMIZER_EXPR_FACTORY_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "common/ast.h" +#include "common/expr.h" +#include "common/expr_factory.h" +#include "common/source.h" + +namespace cel { + +class ParserMacroExprFactory; +class TestOptimizerExprFactory; + +// `OptimizerExprFactory` is a specialization of `ExprFactory` used for AST +// optimization. It provides utilities for correcting metadata for modified +// ASTs. +class OptimizerExprFactory : protected ExprFactory { + public: + struct Issue { + ExprId location = 0; + std::string message; + }; + + explicit OptimizerExprFactory(Ast basis); + OptimizerExprFactory(); + + protected: + using ExprFactory::IsArrayLike; + using ExprFactory::IsExprLike; + using ExprFactory::IsStringLike; + + template + struct IsRValue + : std::bool_constant< + std::disjunction_v, std::is_same>> {}; + + public: + // Consume the current set of renumberings. + absl::flat_hash_map ConsumeRenumbers(); + + // Starts a new copy context. The current set of renumberings are cleared. + void StartCopyContext(); + + const std::vector& issues() const; + + // Record that a node in the working AST was replaced. This is used to correct + // metadata referencing the old ID. + void RecordReplacement(ExprId id, const Expr& replacement, + bool keep_metadata = false); + + // Makes a copy of source metadata that is remapped to new expr Ids using + // current renumberings. This is suitable for merging into the main source + // info. + SourceInfo RemapSourceInfo(const SourceInfo& info, SourcePosition offset = 0); + + // Merge a remapped SourceInfo into the current one. + void MergeSourceInfo(const SourceInfo& info); + + const Ast& ast() const; + Ast& mutable_ast(); + + absl::string_view AccuVarName(); + + ABSL_MUST_USE_RESULT Expr Copy(const Expr& expr); + + ABSL_MUST_USE_RESULT ListExprElement Copy(const ListExprElement& element); + + ABSL_MUST_USE_RESULT StructExprField Copy(const StructExprField& field); + + ABSL_MUST_USE_RESULT MapExprEntry Copy(const MapExprEntry& entry); + + ABSL_MUST_USE_RESULT Expr NewUnspecified(); + + ABSL_MUST_USE_RESULT Expr NewNullConst(); + + ABSL_MUST_USE_RESULT Expr NewBoolConst(bool value); + + ABSL_MUST_USE_RESULT Expr NewIntConst(int64_t value); + + ABSL_MUST_USE_RESULT Expr NewUintConst(uint64_t value); + + ABSL_MUST_USE_RESULT Expr NewDoubleConst(double value); + + ABSL_MUST_USE_RESULT Expr NewBytesConst(std::string value); + + ABSL_MUST_USE_RESULT Expr NewBytesConst(absl::string_view value); + + ABSL_MUST_USE_RESULT Expr NewBytesConst(const char* absl_nullable value); + + ABSL_MUST_USE_RESULT Expr NewStringConst(std::string value); + + ABSL_MUST_USE_RESULT Expr NewStringConst(absl::string_view value); + + ABSL_MUST_USE_RESULT Expr NewStringConst(const char* absl_nullable value); + + template ::value>> + ABSL_MUST_USE_RESULT Expr NewIdent(Name name); + + ABSL_MUST_USE_RESULT Expr NewAccuIdent(); + + template ::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewSelect(Operand operand, Field field); + + template ::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewPresenceTest(Operand operand, Field field); + + template < + typename Function, typename... Args, + typename = std::enable_if_t::value>, + typename = std::enable_if_t...>>> + ABSL_MUST_USE_RESULT Expr NewCall(Function function, Args&&... args); + + template ::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewCall(Function function, Args args); + + template < + typename Function, typename Target, typename... Args, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t...>>> + ABSL_MUST_USE_RESULT Expr NewMemberCall(Function function, Target target, + Args&&... args); + + template ::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewMemberCall(Function function, Target target, + Args args); + + using ExprFactory::NewListElement; + + template ...>>> + ABSL_MUST_USE_RESULT Expr NewList(Elements&&... elements); + + template ::value>> + ABSL_MUST_USE_RESULT Expr NewList(Elements elements); + + template ::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT StructExprField NewStructField(Name name, Value value, + bool optional = false); + + template ::value>, + typename = std::enable_if_t< + std::conjunction_v...>>> + ABSL_MUST_USE_RESULT Expr NewStruct(Name name, Fields&&... fields); + + template < + typename Name, typename Fields, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewStruct(Name name, Fields fields); + + template ::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT MapExprEntry NewMapEntry(Key key, Value value, + bool optional = false); + + template ...>>> + ABSL_MUST_USE_RESULT Expr NewMap(Entries&&... entries); + + template ::value>> + ABSL_MUST_USE_RESULT Expr NewMap(Entries entries); + + template ::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewComprehension(IterVar iter_var, + IterRange iter_range, + AccuVar accu_var, + AccuInit accu_init, + LoopCondition loop_condition, + LoopStep loop_step, Result result); + + template ::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + ABSL_MUST_USE_RESULT Expr NewComprehension( + IterVar iter_var, IterVar2 iter_var2, IterRange iter_range, + AccuVar accu_var, AccuInit accu_init, LoopCondition loop_condition, + LoopStep loop_step, Result result); + + ABSL_MUST_USE_RESULT Expr ReportError(absl::string_view message); + + // Reports an error at the id in the optimized AST. + ABSL_MUST_USE_RESULT Expr ReportErrorAt(const Expr& expr, + absl::string_view message); + // Reports an error at the mapped id of the copy of expr in the optimized AST. + ABSL_MUST_USE_RESULT Expr ReportErrorAtCopy(const Expr& expr, + absl::string_view message); + + protected: + ABSL_MUST_USE_RESULT ExprId NextId(); + + ABSL_MUST_USE_RESULT ExprId CopyId(ExprId id); + + ABSL_MUST_USE_RESULT ExprId CopyId(const Expr& expr); + + using ExprFactory::AccuVarName; + using ExprFactory::NewAccuIdent; + using ExprFactory::NewBoolConst; + using ExprFactory::NewBytesConst; + using ExprFactory::NewCall; + using ExprFactory::NewComprehension; + using ExprFactory::NewConst; + using ExprFactory::NewDoubleConst; + using ExprFactory::NewIdent; + using ExprFactory::NewIntConst; + using ExprFactory::NewList; + using ExprFactory::NewMap; + using ExprFactory::NewMapEntry; + using ExprFactory::NewMemberCall; + using ExprFactory::NewNullConst; + using ExprFactory::NewPresenceTest; + using ExprFactory::NewSelect; + using ExprFactory::NewStringConst; + using ExprFactory::NewStruct; + using ExprFactory::NewStructField; + using ExprFactory::NewUintConst; + using ExprFactory::NewUnspecified; + + private: + Ast ast_; + absl::flat_hash_map renumbers_; + std::vector issues_; + + ExprId next_id_ = 1; +}; + +// Implementation details. + +template +Expr OptimizerExprFactory::NewIdent(Name name) { + return NewIdent(NextId(), std::move(name)); +} + +template +Expr OptimizerExprFactory::NewSelect(Operand operand, Field field) { + return NewSelect(NextId(), std::move(operand), std::move(field)); +} + +template +Expr OptimizerExprFactory::NewPresenceTest(Operand operand, Field field) { + return NewPresenceTest(NextId(), std::move(operand), std::move(field)); +} + +template +Expr OptimizerExprFactory::NewCall(Function function, Args&&... args) { + std::vector array; + array.reserve(sizeof...(Args)); + (array.push_back(std::forward(args)), ...); + return NewCall(NextId(), std::move(function), std::move(array)); +} + +template +Expr OptimizerExprFactory::NewCall(Function function, Args args) { + return NewCall(NextId(), std::move(function), std::move(args)); +} + +template +Expr OptimizerExprFactory::NewMemberCall(Function function, Target target, + Args&&... args) { + std::vector array; + array.reserve(sizeof...(Args)); + (array.push_back(std::forward(args)), ...); + return NewMemberCall(NextId(), std::move(function), std::move(target), + std::move(array)); +} + +template +Expr OptimizerExprFactory::NewMemberCall(Function function, Target target, + Args args) { + return NewMemberCall(NextId(), std::move(function), std::move(target), + std::move(args)); +} + +template +Expr OptimizerExprFactory::NewList(Elements&&... elements) { + std::vector array; + array.reserve(sizeof...(Elements)); + (array.push_back(std::forward(elements)), ...); + return NewList(NextId(), std::move(array)); +} + +template +Expr OptimizerExprFactory::NewList(Elements elements) { + return NewList(NextId(), std::move(elements)); +} + +template +StructExprField OptimizerExprFactory::NewStructField(Name name, Value value, + bool optional) { + return NewStructField(NextId(), std::move(name), std::move(value), optional); +} + +template +Expr OptimizerExprFactory::NewStruct(Name name, Fields&&... fields) { + std::vector array; + array.reserve(sizeof...(Fields)); + (array.push_back(std::forward(fields)), ...); + return NewStruct(NextId(), std::move(name), std::move(array)); +} + +template +Expr OptimizerExprFactory::NewStruct(Name name, Fields fields) { + return NewStruct(NextId(), std::move(name), std::move(fields)); +} + +template +MapExprEntry OptimizerExprFactory::NewMapEntry(Key key, Value value, + bool optional) { + return NewMapEntry(NextId(), std::move(key), std::move(value), optional); +} + +template +Expr OptimizerExprFactory::NewMap(Entries&&... entries) { + std::vector array; + array.reserve(sizeof...(Entries)); + (array.push_back(std::forward(entries)), ...); + return NewMap(NextId(), std::move(array)); +} + +template +Expr OptimizerExprFactory::NewMap(Entries entries) { + return NewMap(NextId(), std::move(entries)); +} + +template +Expr OptimizerExprFactory::NewComprehension(IterVar iter_var, + IterRange iter_range, + AccuVar accu_var, + AccuInit accu_init, + LoopCondition loop_condition, + LoopStep loop_step, Result result) { + return NewComprehension(NextId(), std::move(iter_var), std::move(iter_range), + std::move(accu_var), std::move(accu_init), + std::move(loop_condition), std::move(loop_step), + std::move(result)); +} + +template +Expr OptimizerExprFactory::NewComprehension( + IterVar iter_var, IterVar2 iter_var2, IterRange iter_range, + AccuVar accu_var, AccuInit accu_init, LoopCondition loop_condition, + LoopStep loop_step, Result result) { + return NewComprehension(NextId(), std::move(iter_var), std::move(iter_var2), + std::move(iter_range), std::move(accu_var), + std::move(accu_init), std::move(loop_condition), + std::move(loop_step), std::move(result)); +} + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_INTERNAL_OPTIMIZER_EXPR_FACTORY_H_ diff --git a/policy/internal/optimizer_expr_factory_test.cc b/policy/internal/optimizer_expr_factory_test.cc new file mode 100644 index 000000000..1b14b5628 --- /dev/null +++ b/policy/internal/optimizer_expr_factory_test.cc @@ -0,0 +1,570 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/internal/optimizer_expr_factory.h" + +#include +#include +#include + +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/ast.h" +#include "common/ast_proto.h" +#include "common/ast_rewrite.h" +#include "common/decl.h" +#include "common/expr.h" +#include "common/expr_factory.h" +#include "common/source.h" +#include "common/type.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" +#include "compiler/standard_library.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "testutil/expr_printer.h" +#include "tools/cel_unparser.h" + +namespace cel { + +using ::testing::SizeIs; + +// Expose protected members of OptimizerExprFactory for use in tests +// +// These allow setting explicit IDs which is not safe for the optimizing +// factory. +class TestOptimizerExprFactory final : public OptimizerExprFactory { + public: + using OptimizerExprFactory::OptimizerExprFactory; + + using OptimizerExprFactory::NewBoolConst; + using OptimizerExprFactory::NewCall; + using OptimizerExprFactory::NewComprehension; + using OptimizerExprFactory::NewIdent; + using OptimizerExprFactory::NewList; + using OptimizerExprFactory::NewListElement; + using OptimizerExprFactory::NewMap; + using OptimizerExprFactory::NewMapEntry; + using OptimizerExprFactory::NewMemberCall; + using OptimizerExprFactory::NewSelect; + using OptimizerExprFactory::NewStruct; + using OptimizerExprFactory::NewStructField; + using OptimizerExprFactory::NewUnspecified; + using OptimizerExprFactory::NextId; +}; + +namespace { + +class ReplaceExprRewriter final : public AstRewriterBase { + public: + ReplaceExprRewriter(ExprId old_id, const Expr& replacement) + : old_id_(old_id), replacement_(replacement) {} + + bool PreVisitRewrite(Expr& expr) override { + if (expr.id() == old_id_) { + expr = replacement_; + return true; + } + return false; + } + + private: + ExprId old_id_; + const Expr& replacement_; +}; + +void ReplaceExprInTree(Expr& expr, ExprId old_id, const Expr& replacement) { + ReplaceExprRewriter rewriter(old_id, replacement); + AstRewrite(expr, rewriter); +} + +absl::StatusOr> CreateTestCompiler() { + CompilerOptions opts; + opts.parser_options.add_macro_calls = true; + CEL_ASSIGN_OR_RETURN( + auto builder, cel::NewCompilerBuilder( + cel::internal::GetSharedTestingDescriptorPool(), opts)); + CEL_RETURN_IF_ERROR(builder->AddLibrary(cel::StandardCompilerLibrary())); + CEL_RETURN_IF_ERROR(builder->GetCheckerBuilder().AddVariable( + cel::MakeVariableDecl("to_replace", cel::DynType()))); + return builder->Build(); +} + +TEST(OptimizerExprFactory, CopyUnspecified) { + TestOptimizerExprFactory factory{Ast()}; + EXPECT_EQ(factory.Copy(factory.NewUnspecified()), factory.NewUnspecified(2)); +} + +TEST(OptimizerExprFactory, CopyIdent) { + TestOptimizerExprFactory factory{Ast()}; + EXPECT_EQ(factory.Copy(factory.NewIdent("foo")), factory.NewIdent(2, "foo")); +} + +TEST(OptimizerExprFactory, CopyConst) { + TestOptimizerExprFactory factory{Ast()}; + EXPECT_EQ(factory.Copy(factory.NewBoolConst(true)), + factory.NewBoolConst(2, true)); +} + +TEST(OptimizerExprFactory, CopySelect) { + TestOptimizerExprFactory factory{Ast()}; + EXPECT_EQ(factory.Copy(factory.NewSelect(factory.NewIdent("foo"), "bar")), + factory.NewSelect(3, factory.NewIdent(4, "foo"), "bar")); +} + +TEST(OptimizerExprFactory, CopyCall) { + TestOptimizerExprFactory factory{Ast()}; + std::vector copied_args; + copied_args.reserve(1); + copied_args.push_back(factory.NewIdent(6, "baz")); + EXPECT_EQ(factory.Copy(factory.NewMemberCall("bar", factory.NewIdent("foo"), + factory.NewIdent("baz"))), + factory.NewMemberCall(4, "bar", factory.NewIdent(5, "foo"), + absl::MakeSpan(copied_args))); +} + +TEST(OptimizerExprFactory, CopyList) { + TestOptimizerExprFactory factory{Ast()}; + std::vector copied_elements; + copied_elements.reserve(1); + copied_elements.push_back(factory.NewListElement(factory.NewIdent(4, "foo"))); + EXPECT_EQ(factory.Copy(factory.NewList( + factory.NewListElement(factory.NewIdent("foo")))), + factory.NewList(3, absl::MakeSpan(copied_elements))); +} + +TEST(OptimizerExprFactory, CopyStruct) { + TestOptimizerExprFactory factory{Ast()}; + std::vector copied_fields; + copied_fields.reserve(1); + copied_fields.push_back( + factory.NewStructField(5, "bar", factory.NewIdent(6, "baz"))); + EXPECT_EQ(factory.Copy(factory.NewStruct( + "foo", factory.NewStructField("bar", factory.NewIdent("baz")))), + factory.NewStruct(4, "foo", absl::MakeSpan(copied_fields))); +} + +TEST(OptimizerExprFactory, CopyMap) { + TestOptimizerExprFactory factory{Ast()}; + std::vector copied_entries; + copied_entries.reserve(1); + copied_entries.push_back(factory.NewMapEntry(6, factory.NewIdent(7, "bar"), + factory.NewIdent(8, "baz"))); + EXPECT_EQ(factory.Copy(factory.NewMap(factory.NewMapEntry( + factory.NewIdent("bar"), factory.NewIdent("baz")))), + factory.NewMap(5, absl::MakeSpan(copied_entries))); +} + +TEST(OptimizerExprFactory, CopyComprehension) { + TestOptimizerExprFactory factory{Ast()}; + EXPECT_EQ( + factory.Copy(factory.NewComprehension( + "foo", factory.NewList(), "bar", factory.NewBoolConst(true), + factory.NewIdent("baz"), factory.NewIdent("foo"), + factory.NewIdent("bar"))), + factory.NewComprehension( + 7, "foo", factory.NewList(8, std::vector()), "bar", + factory.NewBoolConst(9, true), factory.NewIdent(10, "baz"), + factory.NewIdent(11, "foo"), factory.NewIdent(12, "bar"))); +} + +TEST(OptimizerExprFactory, RemapSourceInfo) { + TestOptimizerExprFactory factory{Ast()}; + Expr orig = factory.NewIdent("foo"); // allocates ID 1 + Expr copied = factory.Copy(orig); // copies ID 1 to mapped ID 2 + + SourceInfo info; + info.mutable_positions()[1] = 42; // old ID 1 has position 42 + + SourceInfo remapped = factory.RemapSourceInfo(info, 10); + + // remapped should have ID 2 mapped to position 42 + 10 = 52 + auto it = remapped.positions().find(2); + ASSERT_NE(it, remapped.positions().end()); + EXPECT_EQ(it->second, 52); +} + +TEST(OptimizerExprFactory, RemapSourceInfoWithMacroCalls) { + TestOptimizerExprFactory factory{Ast()}; + Expr orig = factory.NewIdent("foo"); // allocates ID 1 + Expr copied = factory.Copy(orig); // copies ID 1 to mapped ID 2 + + SourceInfo info; + // old ID 1 has macro call with ID 3 + info.mutable_macro_calls()[1] = factory.NewIdent("bar"); + + SourceInfo remapped = factory.RemapSourceInfo(info, 10); + + // remapped should have ID 2 mapped to the copied macro call + // since "bar" has ID 3, Copy(bar) should map ID 3 to ID 4 + + auto it = remapped.macro_calls().find(2); + ASSERT_NE(it, remapped.macro_calls().end()); + + // The macro call should be an Ident with new ID 4 + EXPECT_EQ(it->second.id(), 4); + EXPECT_TRUE(it->second.has_ident_expr()); + EXPECT_EQ(it->second.ident_expr().name(), "bar"); +} + +TEST(OptimizerExprFactory, ReportError) { + TestOptimizerExprFactory factory{Ast()}; + Expr err_expr = factory.ReportError("something went wrong"); + + // err_expr should be unspecified with ID 1 + EXPECT_EQ(err_expr.id(), 1); + EXPECT_EQ(err_expr.kind_case(), ExprKindCase::kUnspecifiedExpr); + + // issues_ should have 1 entry with ID 1 and correct message + ASSERT_EQ(factory.issues().size(), 1); + EXPECT_EQ(factory.issues()[0].location, 1); + EXPECT_EQ(factory.issues()[0].message, "something went wrong"); +} + +TEST(OptimizerExprFactory, ReportErrorAt) { + TestOptimizerExprFactory factory{Ast()}; + Expr orig = factory.NewIdent("foo"); // allocates ID 1 + Expr copied = factory.Copy(orig); // copies ID 1 to mapped ID 2 + + Expr err_expr = factory.ReportErrorAtCopy(orig, "error on foo"); + + // err_expr should be unspecified with ID 3 (NextId) + EXPECT_EQ(err_expr.id(), 3); + EXPECT_EQ(err_expr.kind_case(), ExprKindCase::kUnspecifiedExpr); + + // issues_ should have 1 entry with mapped ID 2 and correct message + ASSERT_EQ(factory.issues().size(), 1); + EXPECT_EQ(factory.issues()[0].location, 2); + EXPECT_EQ(factory.issues()[0].message, "error on foo"); +} + +TEST(OptimizerExprFactory, MergeSourceInfo) { + // Create a base AST with some source info + SourceInfo base_info; + base_info.set_syntax_version("cel1"); + base_info.set_location("test.cel"); + base_info.mutable_positions()[1] = 10; + + Ast base_ast(Expr(), std::move(base_info)); + + TestOptimizerExprFactory factory{std::move(base_ast)}; + + // Create a new source info to merge + SourceInfo new_info; + new_info.mutable_positions()[2] = 20; + + factory.MergeSourceInfo(new_info); + + // The merged source info should have both positions + const auto& merged_info = factory.ast().source_info(); + EXPECT_EQ(merged_info.syntax_version(), "cel1"); + EXPECT_EQ(merged_info.location(), "test.cel"); + + auto it1 = merged_info.positions().find(1); + ASSERT_NE(it1, merged_info.positions().end()); + EXPECT_EQ(it1->second, 10); + + auto it2 = merged_info.positions().find(2); + ASSERT_NE(it2, merged_info.positions().end()); + EXPECT_EQ(it2->second, 20); +} + +TEST(OptimizerExprFactory, MergeSourceInfoConflict) { + SourceInfo base_info; + base_info.mutable_positions()[1] = 10; + + Ast base_ast(Expr(), std::move(base_info)); + TestOptimizerExprFactory factory{std::move(base_ast)}; + + SourceInfo new_info; + new_info.mutable_positions()[1] = 20; // conflicting ID 1 + + factory.MergeSourceInfo(new_info); + + // Should report an error for the conflict + ASSERT_EQ(factory.issues().size(), 1); + EXPECT_EQ(factory.issues()[0].location, 1); + EXPECT_EQ(factory.issues()[0].message, "conflicting ID in positions merge"); +} + +TEST(OptimizerExprFactory, RecordReplacement) { + SourceInfo base_info; + base_info.mutable_positions()[1] = 10; + base_info.mutable_positions()[2] = 20; + + TestOptimizerExprFactory factory{Ast()}; + + // macro_calls[1] maps ID 1 to macro call "bar(foo)" (where "foo" has ID 1) + base_info.mutable_macro_calls()[1] = + factory.NewCall("bar", factory.NewIdent(1, "foo")); + + // macro_calls[2] maps ID 2 to macro call "baz(foo)" (where "foo" has ID 1) + base_info.mutable_macro_calls()[2] = + factory.NewCall("baz", factory.NewIdent(1, "foo")); + + Ast base_ast(Expr(), std::move(base_info)); + TestOptimizerExprFactory optimizer{std::move(base_ast)}; + + // Record the replacement of ID 1 by a new Ident "replacement" with ID 3 + optimizer.RecordReplacement(1, factory.NewIdent(3, "replacement")); + + const auto& result_info = optimizer.ast().source_info(); + + // 1. ID 1 should be erased from positions + EXPECT_EQ(result_info.positions().find(1), result_info.positions().end()); + EXPECT_NE(result_info.positions().find(2), result_info.positions().end()); + + // 2. ID 1 should be erased from macro_calls keys + EXPECT_EQ(result_info.macro_calls().find(1), result_info.macro_calls().end()); + + // 3. macro_calls[2] should still exist, but its argument referencing ID 1 + // should be replaced with the Ident "replacement" with ID 3 inline + auto it = result_info.macro_calls().find(2); + ASSERT_NE(it, result_info.macro_calls().end()); + + const Expr& macro_expr = it->second; + ASSERT_TRUE(macro_expr.has_call_expr()); + ASSERT_EQ(macro_expr.call_expr().args().size(), 1); + + const Expr& arg = macro_expr.call_expr().args()[0]; + EXPECT_EQ(arg.id(), 3); + EXPECT_TRUE(arg.has_ident_expr()); + EXPECT_EQ(arg.ident_expr().name(), "replacement"); +} + +class IdAdorner : public cel::test::ExpressionAdorner { + public: + std::string Adorn(const cel::Expr& e) const override { + return absl::StrCat("#", e.id()); + } + + std::string AdornStructField(const cel::StructExprField& e) const override { + return absl::StrCat("#", e.id()); + } + + std::string AdornMapEntry(const cel::MapExprEntry& e) const override { + return absl::StrCat("#", e.id()); + } +}; + +TEST(OptimizerExprFactory, UnparseCopiedMacroCall) { + // Arrange: create an template expression and one to inline. + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, + CreateTestCompiler()); + + ASSERT_OK_AND_ASSIGN(auto basis_result, + compiler->Compile("[1].map(x, x + to_replace)")); + ASSERT_TRUE(basis_result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto basis_ast, basis_result.ReleaseAst()); + + ASSERT_OK_AND_ASSIGN(auto copy_result, + compiler->Compile("[1].filter(x, x > 2).size()")); + ASSERT_TRUE(copy_result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto copy_ast, copy_result.ReleaseAst()); + + // Locate the "to_replace" IdentExpr node in reference_map + ExprId to_replace_id = 0; + for (const auto& [id, ref] : basis_ast->reference_map()) { + if (ref.name() == "to_replace") { + to_replace_id = id; + break; + } + } + ASSERT_NE(to_replace_id, 0); + + // Act: implement the optimization. + TestOptimizerExprFactory factory{std::move(*basis_ast)}; + Expr copied_expr = factory.Copy(copy_ast->root_expr()); + SourceInfo remapped_info = factory.RemapSourceInfo(copy_ast->source_info()); + factory.MergeSourceInfo(remapped_info); + + ReplaceExprInTree(factory.mutable_ast().mutable_root_expr(), to_replace_id, + copied_expr); + factory.RecordReplacement(to_replace_id, copied_expr); + + // Test AST structure. + EXPECT_EQ( + cel::test::ExprPrinter(IdAdorner()).Print(factory.ast().root_expr()), + R"(__comprehension__( + // Variable + x, + // Target + [ + 1#2 + ]#1, + // Accumulator + @result, + // Init + []#8, + // LoopCondition + true#9, + // LoopStep + _+_( + @result#10, + [ + _+_( + x#5, + __comprehension__( + // Variable + x, + // Target + [ + 1#18 + ]#17, + // Accumulator + @result, + // Init + []#19, + // LoopCondition + true#20, + // LoopStep + _?_:_( + _>_( + x#23, + 2#24 + )#22, + _+_( + @result#26, + [ + x#28 + ]#27 + )#25, + @result#29 + )#21, + // Result + @result#30)#16.size()#15 + )#6 + ]#11 + )#12, + // Result + @result#13)#14)"); + + // Check that the structure is compatible with unparser. + cel::expr::ParsedExpr optimized_parsed; + auto status = AstToParsedExpr(factory.ast(), &optimized_parsed); + ASSERT_THAT(status, absl_testing::IsOk()); + ASSERT_OK_AND_ASSIGN(std::string unparsed, + google::api::expr::Unparse(optimized_parsed)); + + EXPECT_EQ(unparsed, "[1].map(x, x + [1].filter(x, x > 2).size())"); + + const CallExpr& call_expr = factory.mutable_ast() + .mutable_source_info() + .mutable_macro_calls()[14] + .mutable_call_expr(); + ASSERT_THAT(call_expr.args(), SizeIs(2)); + ASSERT_THAT(call_expr.args()[1].call_expr().args(), SizeIs(2)); + EXPECT_EQ(call_expr.args()[1].call_expr().args()[1].id(), 15); + + EXPECT_EQ(call_expr.args()[1].call_expr().args()[1].call_expr().target().id(), + 16); + EXPECT_EQ(call_expr.args()[1] + .call_expr() + .args()[1] + .call_expr() + .target() + .kind_case(), + ExprKindCase::kUnspecifiedExpr); +} + +TEST(OptimizerExprFactory, CopyMultipleAstsWithConsumeRenumbers) { + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler, + CreateTestCompiler()); + + ASSERT_OK_AND_ASSIGN(auto ast1_result, compiler->Compile("[1]")); + ASSERT_TRUE(ast1_result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast1, ast1_result.ReleaseAst()); + + ASSERT_OK_AND_ASSIGN(auto ast2_result, compiler->Compile("2")); + ASSERT_TRUE(ast2_result.IsValid()); + ASSERT_OK_AND_ASSIGN(auto ast2, ast2_result.ReleaseAst()); + + TestOptimizerExprFactory factory{Ast()}; + + Expr copied1 = factory.Copy(ast1->root_expr()); + auto renumbers1 = factory.ConsumeRenumbers(); + + Expr copied2 = factory.Copy(ast2->root_expr()); + auto renumbers2 = factory.ConsumeRenumbers(); + + EXPECT_EQ(renumbers1.size(), 2); + EXPECT_EQ(renumbers2.size(), 1); + + EXPECT_NE(copied1.id(), copied2.id()); + EXPECT_GT(copied2.id(), copied1.id()); +} + +TEST(OptimizerExprFactory, MaxIdVisitorExprKinds) { + ASSERT_OK_AND_ASSIGN(auto compiler, CreateTestCompiler()); + + // Expression that covers all the kinds. + ASSERT_OK_AND_ASSIGN(auto source, NewSource(R"cel( + Struct{field : 1} || + {'key' : 'value'} || [1].exists(x, x) || foo(bar))cel")); + ASSERT_OK_AND_ASSIGN(auto ast, compiler->GetParser().Parse(*source)); + + TestOptimizerExprFactory factory{std::move(*ast)}; + + EXPECT_EQ(factory.NextId(), 26); +} + +TEST(OptimizerExprFactory, CopyListElement) { + TestOptimizerExprFactory factory{Ast()}; + ListExprElement orig = factory.NewListElement(factory.NewIdent("foo")); + ListExprElement copied = factory.Copy(orig); + EXPECT_EQ(copied.expr(), factory.NewIdent(2, "foo")); +} + +TEST(OptimizerExprFactory, CopyStructField) { + TestOptimizerExprFactory factory{Ast()}; + StructExprField orig = factory.NewStructField("bar", factory.NewIdent("baz")); + StructExprField copied = factory.Copy(orig); + EXPECT_EQ(copied.id(), 3); + EXPECT_EQ(copied.name(), "bar"); + EXPECT_EQ(copied.value(), factory.NewIdent(4, "baz")); +} + +TEST(OptimizerExprFactory, CopyMapEntry) { + TestOptimizerExprFactory factory{Ast()}; + MapExprEntry orig = + factory.NewMapEntry(factory.NewIdent("bar"), factory.NewIdent("baz")); + MapExprEntry copied = factory.Copy(orig); + EXPECT_EQ(copied.id(), 4); + EXPECT_EQ(copied.key(), factory.NewIdent(5, "bar")); + EXPECT_EQ(copied.value(), factory.NewIdent(6, "baz")); +} + +TEST(OptimizerExprFactory, MergeSourceInfoMacroConflict) { + SourceInfo base_info; + base_info.mutable_macro_calls()[1] = Expr(); + + Ast base_ast(Expr(), std::move(base_info)); + TestOptimizerExprFactory factory{std::move(base_ast)}; + + SourceInfo new_info; + new_info.mutable_macro_calls()[1] = Expr(); + + factory.MergeSourceInfo(new_info); + + ASSERT_EQ(factory.issues().size(), 1); + EXPECT_EQ(factory.issues()[0].location, 1); + EXPECT_EQ(factory.issues()[0].message, "conflicting ID in macro calls merge"); +} + +} // namespace +} // namespace cel diff --git a/policy/test_custom_yaml_policy_parser.cc b/policy/test_custom_yaml_policy_parser.cc new file mode 100644 index 000000000..faced6952 --- /dev/null +++ b/policy/test_custom_yaml_policy_parser.cc @@ -0,0 +1,188 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "internal/status_macros.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_context.h" +#include "policy/cel_policy_parser.h" +#include "policy/yaml_policy_parser.h" +#include "yaml-cpp/node/node.h" +#include "yaml-cpp/yaml.h" // IWYU pragma: keep + +namespace cel::internal { + +// TestCustomYamlPolicyParser is used to support unit tests for custom tags +// and custom policy structures. It demonstrates the versatility of the +// cel::YamlPolicyParser framework API by implementing custom tag and block +// parsing without needing to modify the core parser. +class TestCustomYamlPolicyParser : public cel::YamlPolicyParser { + absl::StatusOr ParsePolicyTag(CelPolicyParseContext& ctx, + const ValueString& tag_name, + const YAML::Node& node) const override { + if (tag_name.value() == "name" || tag_name.value() == "description" || + tag_name.value() == "imports") { + return cel::YamlPolicyParser::ParsePolicyTag(ctx, tag_name, node); + } + if (tag_name.value() == "purpose") { + std::optional purpose = + GetValueString(ctx, node, "Policy purpose is not a string"); + if (purpose.has_value()) { + ctx.policy().mutable_metadata()["purpose"] = *purpose; + } + return true; + } + if (tag_name.value() == "version") { + std::optional version = + GetValueString(ctx, node, "Policy version is not a string"); + if (!version.has_value()) { + return true; + } + int version_int; + if (!absl::SimpleAtoi(version->value(), &version_int)) { + ctx.ReportError(version->id(), + absl::StrCat("Policy version is not an integer: ", + version->value())); + return true; + } + ctx.policy().mutable_metadata()["version"] = version_int; + return true; + } + + if (tag_name.value() == "conditions") { + if (!node.IsSequence()) { + ctx.ReportError(tag_name.id(), "Policy 'conditions' is not a sequence"); + return true; + } + for (const YAML::Node& condition : node) { + // Track the number of existing matches before parsing. When ParseMatch + // evaluates an 'else' block, it recursively triggers parsing and adds + // internal inner matches directly to the rule's match vector. + // Inserting the outer match at begin() + size_before ensures that the + // primary outer 'if' condition is always evaluated before its nested + // 'else' fallbacks. + // + // Example: + // if: x > 0 + // then: "positive" + // else: "negative" + // + // The inner "negative" match is parsed and appended to rule.matches() + // by the inner recursive call, before the outer "x > 0" match finishes. + // Inserting at size_before places the "x > 0" match ahead of the inner + // one. + size_t size_before = ctx.policy().rule().matches().size(); + CEL_ASSIGN_OR_RETURN(Match match, + cel::YamlPolicyParser::ParseMatch( + ctx, condition, ctx.policy().mutable_rule())); + ctx.policy().mutable_rule().mutable_matches().insert( + ctx.policy().mutable_rule().mutable_matches().begin() + size_before, + std::move(match)); + } + + return true; + } + return false; + } + + absl::Status ParseThenBlock(CelPolicyParseContext& ctx, + const YAML::Node& value_node, + Match& match) const { + if (value_node.IsScalar()) { + std::optional val = GetValueString( + ctx, value_node, "Policy condition 'then' is not a string"); + if (val.has_value()) { + OutputBlock output; + output.set_output(*val); + match.set_result(output); + } + } else if (value_node.IsMap()) { + auto nested_rule = std::make_unique(); + CEL_ASSIGN_OR_RETURN( + Match nested_match, + cel::YamlPolicyParser::ParseMatch(ctx, value_node, *nested_rule)); + nested_rule->mutable_matches().insert( + nested_rule->mutable_matches().begin(), std::move(nested_match)); + match.set_result(std::move(nested_rule)); + } else { + ctx.ReportError(CollectMetadata(ctx, value_node), + "Bad syntax in 'if/then' block"); + } + return absl::OkStatus(); + } + + absl::Status ParseElseBlock(CelPolicyParseContext& ctx, + const YAML::Node& value_node, Rule& rule) const { + if (value_node.IsScalar()) { + std::optional val = GetValueString( + ctx, value_node, "Policy condition 'else' is not a string"); + if (val.has_value()) { + Match else_match; + else_match.set_id(CollectMetadata(ctx, value_node)); + OutputBlock output; + output.set_output(*val); + else_match.set_result(output); + rule.mutable_matches().push_back(std::move(else_match)); + } + } else if (value_node.IsMap()) { + size_t size_before = rule.matches().size(); + CEL_ASSIGN_OR_RETURN(Match match, cel::YamlPolicyParser::ParseMatch( + ctx, value_node, rule)); + rule.mutable_matches().insert( + rule.mutable_matches().begin() + size_before, std::move(match)); + } else { + ctx.ReportError(CollectMetadata(ctx, value_node), + "Bad syntax in 'if/then' block"); + } + return absl::OkStatus(); + } + + absl::StatusOr ParseMatchTag(CelPolicyParseContext& ctx, + const ValueString& tag_name, + const YAML::Node& node, Match& match, + Rule& rule) const override { + if (tag_name.value() == "if") { + std::optional condition = + GetValueString(ctx, node, "Policy 'if' condition is not a string"); + if (condition.has_value()) { + match.set_condition(*condition); + } + return true; + } + if (tag_name.value() == "then") { + CEL_RETURN_IF_ERROR(ParseThenBlock(ctx, node, match)); + return true; + } + if (tag_name.value() == "else") { + CEL_RETURN_IF_ERROR(ParseElseBlock(ctx, node, rule)); + return true; + } + return false; + } +}; + +const CelPolicyParser& GetTestCustomYamlPolicyParser() { + static const auto* const parser = new TestCustomYamlPolicyParser(); + return *parser; +} + +} // namespace cel::internal diff --git a/policy/test_util.cc b/policy/test_util.cc new file mode 100644 index 000000000..9fe1e43d1 --- /dev/null +++ b/policy/test_util.cc @@ -0,0 +1,221 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +#include "policy/test_util.h" + +#include +#include +#include +#include + +#include "cel/expr/eval.pb.h" +#include "cel/expr/value.pb.h" +#include "google/protobuf/struct.pb.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "internal/status_macros.h" +#include "yaml-cpp/yaml.h" + +namespace cel::test { + +namespace { + +absl::Status YamlToExprValue(const YAML::Node& node, + cel::expr::Value* proto) { + if (node.IsNull()) { + proto->set_null_value(google::protobuf::NULL_VALUE); + return absl::OkStatus(); + } + if (node.IsScalar()) { + // Try bool + try { + proto->set_bool_value(node.as()); + return absl::OkStatus(); + } catch (...) { + } + // Try int64 + try { + int64_t val; + if (YAML::convert::decode(node, val)) { + proto->set_int64_value(val); + return absl::OkStatus(); + } + } catch (...) { + } + // Try double + try { + double val; + if (YAML::convert::decode(node, val)) { + proto->set_double_value(val); + return absl::OkStatus(); + } + } catch (...) { + } + // Fallback to string + proto->set_string_value(node.as()); + return absl::OkStatus(); + } + if (node.IsSequence()) { + auto* list = proto->mutable_list_value(); + for (const auto& elem : node) { + CEL_RETURN_IF_ERROR(YamlToExprValue(elem, list->add_values())); + } + return absl::OkStatus(); + } + if (node.IsMap()) { + auto* map_val = proto->mutable_map_value(); + for (auto it = node.begin(); it != node.end(); ++it) { + auto* entry = map_val->add_entries(); + CEL_RETURN_IF_ERROR(YamlToExprValue(it->first, entry->mutable_key())); + CEL_RETURN_IF_ERROR(YamlToExprValue(it->second, entry->mutable_value())); + } + return absl::OkStatus(); + } + return absl::InvalidArgumentError("Unknown YAML node type"); +} + +absl::Status ParseInputValue( + const YAML::Node& node, + cel::expr::conformance::test::InputValue* input_val) { + if (node.IsMap() && node["expr"].IsDefined()) { + input_val->set_expr(node["expr"].as()); + return absl::OkStatus(); + } + if (node.IsMap() && node["value"].IsDefined()) { + return YamlToExprValue(node["value"], input_val->mutable_value()); + } + return YamlToExprValue(node, input_val->mutable_value()); +} + +absl::Status ParseTestOutput(const YAML::Node& node, + cel::expr::conformance::test::TestOutput* output) { + if (!node.IsDefined()) { + return absl::InvalidArgumentError("Missing output node"); + } + if (node.IsMap()) { + if (node["expr"].IsDefined()) { + output->set_result_expr(node["expr"].as()); + return absl::OkStatus(); + } + if (node["value"].IsDefined()) { + return YamlToExprValue(node["value"], output->mutable_result_value()); + } + if (node["error"].IsDefined()) { + auto* eval_error = output->mutable_eval_error(); + eval_error->add_errors()->set_message(node["error"].as()); + return absl::OkStatus(); + } + if (node["error_set"].IsDefined()) { + auto* eval_error = output->mutable_eval_error(); + for (const auto& err : node["error_set"]) { + eval_error->add_errors()->set_message(err.as()); + } + return absl::OkStatus(); + } + if (node["unknown"].IsDefined()) { + auto* unknown = output->mutable_unknown(); + for (const auto& expr_id_node : node["unknown"]) { + unknown->add_exprs(expr_id_node.as()); + } + return absl::OkStatus(); + } + } + return YamlToExprValue(node, output->mutable_result_value()); +} + +absl::StatusOr +ParsePolicyTestSuiteYamlImpl(absl::string_view yaml_content) { + YAML::Node tests_node; + try { + tests_node = YAML::Load(std::string(yaml_content)); + } catch (const std::exception& e) { + return absl::InvalidArgumentError( + absl::StrCat("Failed to parse YAML: ", e.what())); + } + + cel::expr::conformance::test::TestSuite test_suite; + if (tests_node["description"].IsDefined()) { + test_suite.set_description(tests_node["description"].as()); + } + + YAML::Node sections = tests_node["sections"]; + if (!sections.IsDefined()) { + sections = tests_node["section"]; // support singular format + } + if (!sections.IsDefined()) { + return absl::InvalidArgumentError( + "Missing 'sections' or 'section' in tests YAML"); + } + + for (const auto& section_node : sections) { + auto* section = test_suite.add_sections(); + if (section_node["name"].IsDefined()) { + section->set_name(section_node["name"].as()); + } + if (section_node["description"].IsDefined()) { + section->set_description(section_node["description"].as()); + } + + YAML::Node tests = section_node["tests"]; + if (!tests.IsDefined()) { + tests = section_node["test"]; // support singular format + } + if (!tests.IsDefined()) { + continue; + } + + for (const auto& test_node : tests) { + auto* test_case = section->add_tests(); + if (test_node["name"].IsDefined()) { + test_case->set_name(test_node["name"].as()); + } + if (test_node["description"].IsDefined()) { + test_case->set_description(test_node["description"].as()); + } + if (test_node["context_expr"].IsDefined()) { + test_case->mutable_input_context()->set_context_expr( + test_node["context_expr"].as()); + } + + YAML::Node input_node = test_node["input"]; + if (input_node.IsDefined() && input_node.IsMap()) { + auto* input_map = test_case->mutable_input(); + for (auto it = input_node.begin(); it != input_node.end(); ++it) { + std::string var_name = it->first.as(); + cel::expr::conformance::test::InputValue input_val; + CEL_RETURN_IF_ERROR(ParseInputValue(it->second, &input_val)); + (*input_map)[var_name] = std::move(input_val); + } + } + + YAML::Node output_node = test_node["output"]; + if (output_node.IsDefined()) { + CEL_RETURN_IF_ERROR( + ParseTestOutput(output_node, test_case->mutable_output())); + } + } + } + + return test_suite; +} + +} // namespace + +absl::StatusOr +ParsePolicyTestSuiteYaml(absl::string_view yaml_content) { + try { + return ParsePolicyTestSuiteYamlImpl(yaml_content); + } catch (...) { + return absl::InvalidArgumentError("Failed to parse YAML"); + } +} + +} // namespace cel::test diff --git a/policy/test_util.h b/policy/test_util.h new file mode 100644 index 000000000..5fe306050 --- /dev/null +++ b/policy/test_util.h @@ -0,0 +1,33 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_TEST_UTIL_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_TEST_UTIL_H_ + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "cel/expr/conformance/test/suite.pb.h" + +namespace cel::test { + +// Parses a YAML content representing a policy test suite (tests.yaml) +// and adapts it to the cel.expr.conformance.test.TestSuite protobuf message. +// +// TODO(uncreated-issue/92): Move to the testrunner library. +absl::StatusOr +ParsePolicyTestSuiteYaml(absl::string_view yaml_content); + +} // namespace cel::test + +#endif // THIRD_PARTY_CEL_CPP_POLICY_TEST_UTIL_H_ diff --git a/policy/testdata/BUILD b/policy/testdata/BUILD new file mode 100644 index 000000000..10a26fa0b --- /dev/null +++ b/policy/testdata/BUILD @@ -0,0 +1,19 @@ +package( + default_testonly = True, + default_visibility = ["//visibility:public"], +) + +filegroup( + name = "policy_testdata", + srcs = glob([ + "*.yaml", + "*.baseline", + ]), +) + +exports_files( + srcs = glob([ + "*.yaml", + "*.baseline", + ]), +) diff --git a/policy/testdata/cel_policy.yaml b/policy/testdata/cel_policy.yaml new file mode 100644 index 000000000..010ad8855 --- /dev/null +++ b/policy/testdata/cel_policy.yaml @@ -0,0 +1,42 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Environment: +# spec: TestAllTypes +name: cel_policy +description: A test policy for CEL +display_name: Cel Policy +imports: +- name: cel.expr.conformance.proto3.TestAllTypes +- name: cel.expr.conformance.proto3.TestAllTypes.NestedEnum +rule: + id: test_rule + description: test rule description + variables: + - name: test_var + expression: > + TestAllTypes{single_int64: 10}.single_int64 + match: + - condition: > + spec.single_int32 > TestAllTypes{single_int64: 10}.single_int64 + output: | + "invalid spec, got single_int32=" + string(spec.single_int32) + ", wanted <= 10" + explanation: | + "invalid spec, spec is greater than 10" + - condition: > + spec.standalone_enum == NestedEnum.BAR + output: | + "invalid spec, reference to BAR is not allowed" + - condition: spec.single_int64 == variables.test_var + output: '"invalid spec: exactly matches test_var"' + explanation: '"the spec cannot have single_int64 set to a known bad value"' \ No newline at end of file diff --git a/policy/testdata/cel_policy_parser.baseline b/policy/testdata/cel_policy_parser.baseline new file mode 100644 index 000000000..7a6678bfe --- /dev/null +++ b/policy/testdata/cel_policy_parser.baseline @@ -0,0 +1,89 @@ +POLICY SOURCE: cel_policy.yaml +-------------------------------------------------------------------- +PARSED POLICY: +CelPolicy{ + =========================================================== + # Copyright 2026 Google LLC + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # https://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + # Environment: + # spec: TestAllTypes + #0> name: #1> cel_policy + #2> description: #3> A test policy for CEL + #4> display_name: #5> Cel Policy + #6> imports: + - #7> name: #8> cel.expr.conformance.proto3.TestAllTypes + - #9> name: #10> cel.expr.conformance.proto3.TestAllTypes.NestedEnum + #11> rule: + #13> #12> id: #14> test_rule + #15> description: #16> test rule description + #17> variables: + - #18> name: #19> test_var + #20> expression: #21> > + TestAllTypes{single_int64: 10}.single_int64 + #22> match: + - #24> #23> condition: #25> > + spec.single_int32 > TestAllTypes{single_int64: 10}.single_int64 + #26> output: #27> | + "invalid spec, got single_int32=" + string(spec.single_int32) + ", wanted <= 10" + #28> explanation: #29> | + "invalid spec, spec is greater than 10" + - #31> #30> condition: #32> > + spec.standalone_enum == NestedEnum.BAR + #33> output: #34> | + "invalid spec, reference to BAR is not allowed" + - #36> #35> condition: #37> spec.single_int64 == variables.test_var + #38> output: #39> '"invalid spec: exactly matches test_var"' + #40> explanation: #41> '"the spec cannot have single_int64 set to a known bad value"' + =========================================================== + name: #1> "cel_policy" + description: #3> "A test policy for CEL" + display_name: #5> "Cel Policy" + imports: + #7> name: #8> "cel.expr.conformance.proto3.TestAllTypes" + #9> name: #10> "cel.expr.conformance.proto3.TestAllTypes.NestedEnum" + #12> rule: { + rule_id: #14> "test_rule" + description: #16> "test rule description" + variable: { + name: #19> "test_var" + expression: #21> "TestAllTypes{single_int64: 10}.single_int64 + " + } + #23> match: { + condition: #25> "spec.single_int32 > TestAllTypes{single_int64: 10}.single_int64 + " + result: { + output: #27> ""invalid spec, got single_int32=" + string(spec.single_int32) + ", wanted <= 10" + " + explanation: #29> ""invalid spec, spec is greater than 10" + " + } + } + #30> match: { + condition: #32> "spec.standalone_enum == NestedEnum.BAR + " + result: { + output: #34> ""invalid spec, reference to BAR is not allowed" + " + } + } + #35> match: { + condition: #37> "spec.single_int64 == variables.test_var" + result: { + output: #39> ""invalid spec: exactly matches test_var"" + explanation: #41> ""the spec cannot have single_int64 set to a known bad value"" + } + } + } +} diff --git a/policy/testdata/custom_policy_format.yaml b/policy/testdata/custom_policy_format.yaml new file mode 100644 index 000000000..a67356906 --- /dev/null +++ b/policy/testdata/custom_policy_format.yaml @@ -0,0 +1,29 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +name: cel_policy_custom_tags +description: A custom policy format +imports: +- name: cel.expr.conformance.proto3.TestAllTypes +purpose: test +version: 42 +conditions: +- if: spec.single_string == "none" + then: "'zero'" + else: + if: spec.single_string == "integer" + then: + if: spec.single_int32 > 0 + then: "'positive integer'" + else: "'negative integer'" + else: "'not an integer'" diff --git a/policy/testdata/custom_policy_format_parser.baseline b/policy/testdata/custom_policy_format_parser.baseline new file mode 100644 index 000000000..d5b1a2235 --- /dev/null +++ b/policy/testdata/custom_policy_format_parser.baseline @@ -0,0 +1,75 @@ +POLICY SOURCE: custom_policy_format.yaml +-------------------------------------------------------------------- +PARSED POLICY: +CelPolicy{ + =========================================================== + # Copyright 2026 Google LLC + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # https://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + #0> name: #1> cel_policy_custom_tags + #2> description: #3> A custom policy format + #4> imports: + - #5> name: #6> cel.expr.conformance.proto3.TestAllTypes + #7> purpose: #8> test + #9> version: #10> 42 + #11> conditions: + - #13> #12> if: #14> spec.single_string == "none" + #15> then: #16> "'zero'" + #17> else: + #19> #18> if: #20> spec.single_string == "integer" + #21> then: + #23> #22> if: #24> spec.single_int32 > 0 + #25> then: #26> "'positive integer'" + #27> else: #29> #28> "'negative integer'" + #30> else: #32> #31> "'not an integer'" + + =========================================================== + name: #1> "cel_policy_custom_tags" + description: #3> "A custom policy format" + metadata: { + purpose: #8> "test" + version: 42 + } + imports: + #5> name: #6> "cel.expr.conformance.proto3.TestAllTypes" + rule: { + #12> match: { + condition: #14> "spec.single_string == "none"" + result: { + output: #16> "'zero'" + } + } + #18> match: { + condition: #20> "spec.single_string == "integer"" + result: + rule: { + #22> match: { + condition: #24> "spec.single_int32 > 0" + result: { + output: #26> "'positive integer'" + } + } + #29> match: { + result: { + output: #28> "'negative integer'" + } + } + } + } + #32> match: { + result: { + output: #31> "'not an integer'" + } + } + } +} diff --git a/policy/testdata/custom_policy_format_with_errors.yaml b/policy/testdata/custom_policy_format_with_errors.yaml new file mode 100644 index 000000000..594747c60 --- /dev/null +++ b/policy/testdata/custom_policy_format_with_errors.yaml @@ -0,0 +1,33 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +name: cel_policy_custom_tags +description: A custom policy format +imports: +- name: cel.expr.conformance.proto3.TestAllTypes +purpose: + - testing +version: new +conditions: +- if: + spec.single_string: "none" + then: "'zero'" + else: "'not zero'" +- if: spec.single_string == "number" + then: + if: spec.single_int32 > 0 + then: "'positive integer'" + else: + - ignore +- else: "'negative integer'" + diff --git a/policy/testdata/custom_policy_format_with_errors_parser.baseline b/policy/testdata/custom_policy_format_with_errors_parser.baseline new file mode 100644 index 000000000..978d27bda --- /dev/null +++ b/policy/testdata/custom_policy_format_with_errors_parser.baseline @@ -0,0 +1,16 @@ +POLICY SOURCE: custom_policy_format_with_errors.yaml +-------------------------------------------------------------------- +-------------------------------------------------------------------- +PARSER ISSUES: +ERROR: custom_policy_format_with_errors.yaml:19:3: Policy purpose is not a string + | - testing + | ..^ +ERROR: custom_policy_format_with_errors.yaml:20:10: Policy version is not an integer: new + | version: new + | .........^ +ERROR: custom_policy_format_with_errors.yaml:23:5: Policy 'if' condition is not a string + | spec.single_string: "none" + | ....^ +ERROR: custom_policy_format_with_errors.yaml:31:7: Bad syntax in 'if/then' block + | - ignore + | ......^ diff --git a/policy/testdata/nested_rule.yaml b/policy/testdata/nested_rule.yaml new file mode 100644 index 000000000..2b07faa64 --- /dev/null +++ b/policy/testdata/nested_rule.yaml @@ -0,0 +1,37 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +name: nested_rule +rule: + variables: + - name: "permitted_regions" + expression: "['us', 'uk', 'es']" + match: + - rule: + id: "banned regions" + description: > + determine whether the resource origin is in the banned + list. If the region is also in the permitted list, the + ban has no effect. + variables: + - name: "banned_regions" + expression: "{'us': false, 'ru': false, 'ir': false}" + match: + - condition: | + resource.origin in variables.banned_regions && + !(resource.origin in variables.permitted_regions) + output: "{'banned': true}" + - condition: resource.origin in variables.permitted_regions + output: "{'banned': false}" + - output: "{'banned': true}" + explanation: "'resource is in the banned region ' + resource.origin" \ No newline at end of file diff --git a/policy/testdata/nested_rule_parser.baseline b/policy/testdata/nested_rule_parser.baseline new file mode 100644 index 000000000..128f81bda --- /dev/null +++ b/policy/testdata/nested_rule_parser.baseline @@ -0,0 +1,84 @@ +POLICY SOURCE: nested_rule.yaml +-------------------------------------------------------------------- +PARSED POLICY: +CelPolicy{ + =========================================================== + # Copyright 2024 Google LLC + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # https://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + #0> name: #1> nested_rule + #2> rule: + #4> #3> variables: + - #5> name: #6> "permitted_regions" + #7> expression: #8> "['us', 'uk', 'es']" + #9> match: + - #11> #10> rule: + #13> #12> id: #14> "banned regions" + #15> description: #16> > + determine whether the resource origin is in the banned + list. If the region is also in the permitted list, the + ban has no effect. + #17> variables: + - #18> name: #19> "banned_regions" + #20> expression: #21> "{'us': false, 'ru': false, 'ir': false}" + #22> match: + - #24> #23> condition: #25> | + resource.origin in variables.banned_regions && + !(resource.origin in variables.permitted_regions) + #26> output: #27> "{'banned': true}" + - #29> #28> condition: #30> resource.origin in variables.permitted_regions + #31> output: #32> "{'banned': false}" + - #34> #33> output: #35> "{'banned': true}" + #36> explanation: #37> "'resource is in the banned region ' + resource.origin" + =========================================================== + name: #1> "nested_rule" + description: "nested_rule.yaml" + #3> rule: { + variable: { + name: #6> "permitted_regions" + expression: #8> "['us', 'uk', 'es']" + } + #10> match: { + result: + #12> rule: { + rule_id: #14> "banned regions" + description: #16> "determine whether the resource origin is in the banned list. If the region is also in the permitted list, the ban has no effect. + " + variable: { + name: #19> "banned_regions" + expression: #21> "{'us': false, 'ru': false, 'ir': false}" + } + #23> match: { + condition: #25> "resource.origin in variables.banned_regions && + !(resource.origin in variables.permitted_regions) + " + result: { + output: #27> "{'banned': true}" + } + } + } + } + #28> match: { + condition: #30> "resource.origin in variables.permitted_regions" + result: { + output: #32> "{'banned': false}" + } + } + #33> match: { + result: { + output: #35> "{'banned': true}" + explanation: #37> "'resource is in the banned region ' + resource.origin" + } + } + } +} diff --git a/policy/yaml_policy_parser.cc b/policy/yaml_policy_parser.cc new file mode 100644 index 000000000..c838cff33 --- /dev/null +++ b/policy/yaml_policy_parser.cc @@ -0,0 +1,411 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/yaml_policy_parser.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/source.h" +#include "internal/status_macros.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_context.h" +#include "policy/cel_policy_parse_result.h" +#include "policy/cel_policy_parser.h" +#include "yaml-cpp/exceptions.h" +#include "yaml-cpp/node/node.h" +#include "yaml-cpp/node/parse.h" +#include "yaml-cpp/null.h" +#include "yaml-cpp/yaml.h" // IWYU pragma: keep + +namespace cel { + +CelPolicyElementId YamlPolicyParser::CollectMetadata( + CelPolicyParseContext& ctx, const YAML::Node& node) const { + CelPolicyElementId element_id = ctx.next_element_id(); + if (!node.Mark().is_null()) { + ctx.policy_source().NoteSourcePosition(element_id, node.Mark().pos); + } + return element_id; +} + +std::optional YamlPolicyParser::GetValueString( + CelPolicyParseContext& ctx, const YAML::Node& node, + std::string_view error_message) const { + if (!node.IsDefined()) { + // This should never happen since the YAML syntax has already been checked. + return std::nullopt; + } + + CelPolicyElementId id = CollectMetadata(ctx, node); + if (!node.IsScalar()) { + ctx.ReportError(id, error_message); + return std::nullopt; + } + + try { + return ValueString(id, node.as()); + } catch (YAML::Exception& e) { + // This should never happen since we already checked that the node is a + // scalar and all scalars can be converted to strings. + return std::nullopt; + } +} + +absl::Status YamlPolicyParser::ParsePolicy(CelPolicyParseContext& ctx) const { + const Source* source = ctx.policy_source().content(); + if (source == nullptr) { + return absl::OkStatus(); + } + + ctx.policy().set_description(ValueString(-1, source->description())); + std::string text = source->content().ToString(); + YAML::Node node; + try { + node = YAML::Load(text); + } catch (YAML::Exception& e) { + if (!e.mark.is_null()) { + ctx.policy_source().NoteSourcePosition(0, e.mark.pos); + } + ctx.ReportError(0, "Invalid CEL policy YAML syntax"); + return absl::OkStatus(); + } + + if (!node.IsMap()) { + ctx.ReportError(CollectMetadata(ctx, node), "Policy is not a map"); + return absl::OkStatus(); + } + + for (auto it = node.begin(); it != node.end(); ++it) { + const YAML::Node key_node = it->first; + const YAML::Node value_node = it->second; + std::optional key = + GetValueString(ctx, key_node, "Policy tag is not a string"); + if (!key.has_value()) { + continue; + } + CEL_ASSIGN_OR_RETURN(bool handled, ParsePolicyTag(ctx, *key, value_node)); + if (!handled) { + ctx.ReportError( + key->id(), + absl::StrCat("Unrecognized top-level policy tag: ", key->value())); + } + } + + return absl::OkStatus(); +} + +absl::StatusOr YamlPolicyParser::ParsePolicyTag( + CelPolicyParseContext& ctx, const ValueString& tag_name, + const YAML::Node& node) const { + if (tag_name.value() == "imports") { + CEL_RETURN_IF_ERROR(ParseImports(ctx, node)); + return true; + } + if (tag_name.value() == "name") { + std::optional name = + GetValueString(ctx, node, "Policy 'name' is not a string"); + if (name.has_value()) { + ctx.policy().set_name(*name); + } + return true; + } + if (tag_name.value() == "description") { + std::optional description = + GetValueString(ctx, node, "Policy 'description' is not a string"); + if (description.has_value()) { + ctx.policy().set_description(*description); + } + return true; + } + if (tag_name.value() == "display_name") { + std::optional display_name = + GetValueString(ctx, node, "Policy 'display_name' is not a string"); + if (display_name.has_value()) { + ctx.policy().set_display_name(*display_name); + } + return true; + } + if (tag_name.value() == "rule") { + CEL_RETURN_IF_ERROR(ParseRule(ctx, node, ctx.policy().mutable_rule())); + return true; + } + return false; +} + +absl::Status YamlPolicyParser::ParseImports(CelPolicyParseContext& ctx, + const YAML::Node& node) const { + if (!node.IsSequence()) { + ctx.ReportError(CollectMetadata(ctx, node), + "Policy 'imports' is not a sequence"); + return absl::OkStatus(); + } + + for (const YAML::Node& import : node) { + CelPolicyElementId import_id = CollectMetadata(ctx, import); + if (!import.IsMap()) { + ctx.ReportError(import_id, "Import is not a map"); + continue; + } + const YAML::Node& name_node = import["name"]; + if (!name_node.IsDefined()) { + ctx.ReportError(import_id, "No 'name' tag in import"); + continue; + } + std::optional import_name = + GetValueString(ctx, name_node, "Import name is not a string"); + if (import_name.has_value()) { + ctx.policy().mutable_imports().push_back(Import(import_id, *import_name)); + } + } + return absl::OkStatus(); +} + +absl::Status YamlPolicyParser::ParseRule(CelPolicyParseContext& ctx, + const YAML::Node& node, + Rule& rule) const { + if (!node.IsMap()) { + ctx.ReportError(CollectMetadata(ctx, node), "Policy 'rule' is not a map"); + return absl::OkStatus(); + } + rule.set_id(CollectMetadata(ctx, node)); + + for (auto it = node.begin(); it != node.end(); ++it) { + const YAML::Node key_node = it->first; + const YAML::Node value_node = it->second; + std::optional key = + GetValueString(ctx, key_node, "Policy rule tag is not a string"); + if (!key.has_value()) { + continue; + } + CEL_ASSIGN_OR_RETURN(bool handled, + ParseRuleTag(ctx, *key, value_node, rule)); + if (!handled) { + ctx.ReportError(key->id(), absl::StrCat("Unrecognized policy rule tag: ", + key->value())); + } + } + return absl::OkStatus(); +} + +absl::StatusOr YamlPolicyParser::ParseRuleTag(CelPolicyParseContext& ctx, + const ValueString& tag_name, + const YAML::Node& node, + Rule& rule) const { + if (tag_name.value() == "id") { + std::optional rule_id = + GetValueString(ctx, node, "Policy rule 'id' is not a string"); + if (rule_id.has_value()) { + rule.set_rule_id(*rule_id); + } + return true; + } + if (tag_name.value() == "description") { + std::optional description = + GetValueString(ctx, node, "Policy rule 'description' is not a string"); + if (description.has_value()) { + rule.set_description(*description); + } + return true; + } + if (tag_name.value() == "variables") { + if (!node.IsSequence()) { + ctx.ReportError(CollectMetadata(ctx, node), + "Policy rule 'variables' is not a sequence"); + return true; + } + for (const YAML::Node& variable_node : node) { + CEL_ASSIGN_OR_RETURN(Variable variable, + ParseVariable(ctx, variable_node, rule)); + rule.mutable_variables().push_back(std::move(variable)); + } + return true; + } + if (tag_name.value() == "match") { + if (!node.IsSequence()) { + ctx.ReportError(CollectMetadata(ctx, node), + "Policy rule 'match' is not a sequence"); + return true; + } + for (const YAML::Node& match_node : node) { + CEL_ASSIGN_OR_RETURN(Match match, ParseMatch(ctx, match_node, rule)); + rule.mutable_matches().push_back(std::move(match)); + } + return true; + } + return false; +} + +absl::StatusOr YamlPolicyParser::ParseVariable( + CelPolicyParseContext& ctx, const YAML::Node& node, Rule& rule) const { + Variable variable; + if (!node.IsMap()) { + ctx.ReportError(CollectMetadata(ctx, node), + "Policy rule 'variable' is not a map"); + return variable; + } + for (auto it = node.begin(); it != node.end(); ++it) { + const YAML::Node key_node = it->first; + const YAML::Node value_node = it->second; + std::optional key = + GetValueString(ctx, key_node, "Policy variable tag is not a string"); + if (!key.has_value()) { + continue; + } + CEL_ASSIGN_OR_RETURN(bool handled, + ParseVariableTag(ctx, *key, value_node, variable)); + if (!handled) { + ctx.ReportError( + key->id(), + absl::StrCat("Unrecognized policy variable tag: ", key->value())); + } + } + return variable; +} + +absl::StatusOr YamlPolicyParser::ParseVariableTag( + CelPolicyParseContext& ctx, const ValueString& tag_name, + const YAML::Node& node, Variable& variable) const { + if (tag_name.value() == "name") { + std::optional name = + GetValueString(ctx, node, "Policy variable 'name' is not a string"); + if (name.has_value()) { + variable.set_name(*name); + } + return true; + } + if (tag_name.value() == "expression") { + std::optional expression = GetValueString( + ctx, node, "Policy variable 'expression' is not a string"); + if (expression.has_value()) { + variable.set_expression(*expression); + } + return true; + } + return false; +} + +absl::StatusOr YamlPolicyParser::ParseMatch(CelPolicyParseContext& ctx, + const YAML::Node& node, + Rule& rule) const { + Match match; + match.set_id(CollectMetadata(ctx, node)); + if (!node.IsMap()) { + ctx.ReportError(match.id(), "Policy rule 'match' is not a map"); + return match; + } + for (auto it = node.begin(); it != node.end(); ++it) { + const YAML::Node key_node = it->first; + const YAML::Node value_node = it->second; + std::optional key = + GetValueString(ctx, key_node, "Policy match tag is not a string"); + if (!key.has_value()) { + continue; + } + CEL_ASSIGN_OR_RETURN(bool handled, + ParseMatchTag(ctx, *key, value_node, match, rule)); + if (!handled) { + ctx.ReportError(key->id(), absl::StrCat("Unrecognized policy match tag: ", + key->value())); + } + } + + if (match.has_output_block()) { + if (match.output_block().output().value().empty() && + match.output_block().explanation().has_value()) { + ctx.ReportError(match.id(), "Match specifies explanation but no output"); + } + } + + return match; +} + +absl::StatusOr YamlPolicyParser::ParseMatchTag( + CelPolicyParseContext& ctx, const ValueString& tag_name, + const YAML::Node& node, Match& match, Rule& rule) const { + if (tag_name.value() == "condition") { + std::optional condition = + GetValueString(ctx, node, "Policy match 'condition' is not a string"); + if (condition.has_value()) { + match.set_condition(*condition); + } + return true; + } + if (tag_name.value() == "explanation") { + std::optional explanation = + GetValueString(ctx, node, "Policy match 'explanation' is not a string"); + if (explanation.has_value()) { + if (match.has_rule()) { + ctx.ReportError( + tag_name.id(), + "Cannot specify explanation when a nested rule is present"); + } else { + match.mutable_output_block().set_explanation(*explanation); + } + } + return true; + } + if (tag_name.value() == "output") { + std::optional output = + GetValueString(ctx, node, "Policy match 'output' is not a string"); + if (output.has_value()) { + if (match.has_rule()) { + ctx.ReportError(tag_name.id(), + "Cannot specify output when a nested rule is present"); + } else { + match.mutable_output_block().set_output(*output); + } + } + return true; + } + if (tag_name.value() == "rule") { + if (match.has_output_block()) { + ctx.ReportError(tag_name.id(), + "Cannot specify nested rule when output/explanation is " + "present"); + } + auto nested_rule = std::make_unique(); + CEL_RETURN_IF_ERROR(ParseRule(ctx, node, *nested_rule)); + match.set_result(std::move(nested_rule)); + return true; + } + return false; +} + +const CelPolicyParser& GetDefaultYamlPolicyParser() { + static const auto* const parser = new YamlPolicyParser(); + return *parser; +} + +absl::StatusOr ParseYamlCelPolicy( + std::shared_ptr policy_source) { + return ParseYamlCelPolicy(std::move(policy_source), + GetDefaultYamlPolicyParser()); +} + +absl::StatusOr ParseYamlCelPolicy( + std::shared_ptr policy_source, + const CelPolicyParser& parser) { + CelPolicyParseContext ctx(std::move(policy_source)); + CEL_RETURN_IF_ERROR(parser.ParsePolicy(ctx)); + return ctx.GetResult(); +} + +} // namespace cel diff --git a/policy/yaml_policy_parser.h b/policy/yaml_policy_parser.h new file mode 100644 index 000000000..469209333 --- /dev/null +++ b/policy/yaml_policy_parser.h @@ -0,0 +1,135 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_POLICY_YAML_POLICY_PARSER_H_ +#define THIRD_PARTY_CEL_CPP_POLICY_YAML_POLICY_PARSER_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_context.h" +#include "policy/cel_policy_parse_result.h" +#include "policy/cel_policy_parser.h" +#include "yaml-cpp/node/node.h" + +namespace cel { + +// A parser for YAML-based CEL policies. +// +// To support additional or alternative YAML elements, subclass +// `YamlPolicyParser` and override specific parsing methods, `Parse*` +class YamlPolicyParser : public CelPolicyParser { + public: + std::optional GetValueString( + CelPolicyParseContext& ctx, const YAML::Node& node, + std::string_view error_message) const; + + absl::Status ParsePolicy(CelPolicyParseContext& ctx) const override; + + protected: + // Collects metadata (e.g. source position) for the given YAML node, stores it + // in the context, and returns an ID that can be used to refer to it. + virtual CelPolicyElementId CollectMetadata(CelPolicyParseContext& ctx, + const YAML::Node& node) const; + + // Parses a top-level tag in the policy YAML. + // Returns true if the tag was handled. + // + // Note that an OkStatus does not necessarily mean that parsing was successful + // - only that it can continue. + virtual absl::StatusOr ParsePolicyTag(CelPolicyParseContext& ctx, + const ValueString& tag_name, + const YAML::Node& node) const; + + // Parses the imports section of the policy YAML. + // + // Note that an OkStatus does not necessarily mean that parsing was successful + // - only that it can continue. + virtual absl::Status ParseImports(CelPolicyParseContext& ctx, + const YAML::Node& node) const; + + // Parses a rule element of the policy YAML, which may be the top-level rule + // or a sub-rule of a match. + // + // Note that an OkStatus does not necessarily mean that parsing was successful + // - only that it can continue. + virtual absl::Status ParseRule(CelPolicyParseContext& ctx, + const YAML::Node& node, Rule& rule) const; + + // Parses a tag in a policy YAML rule. + // Returns true if the tag was handled. + // + // Note that an OkStatus does not necessarily mean that parsing was successful + // - only that it can continue. + virtual absl::StatusOr ParseRuleTag(CelPolicyParseContext& ctx, + const ValueString& tag_name, + const YAML::Node& node, + Rule& rule) const; + + // Parses a variable element of the policy YAML. + // + // Note that an OkStatus does not necessarily mean that parsing was successful + // - only that it can continue. + virtual absl::StatusOr ParseVariable(CelPolicyParseContext& ctx, + const YAML::Node& node, + Rule& rule) const; + + // Parses a tag in a policy YAML variable. + // Returns true if the tag was handled. + // + // Note that an OkStatus does not necessarily mean that parsing was successful + // - only that it can continue. + virtual absl::StatusOr ParseVariableTag(CelPolicyParseContext& ctx, + const ValueString& tag_name, + const YAML::Node& node, + Variable& variable) const; + + // Parses a match element of the policy YAML. + // + // Note that an OkStatus does not necessarily mean that parsing was successful + // - only that it can continue. + virtual absl::StatusOr ParseMatch(CelPolicyParseContext& ctx, + const YAML::Node& node, + Rule& rule) const; + + // Parses a tag in a policy YAML match. + // Returns true if the tag was handled. + // + // Note that an OkStatus does not necessarily mean that parsing was successful + // - only that it can continue. + virtual absl::StatusOr ParseMatchTag(CelPolicyParseContext& ctx, + const ValueString& tag_name, + const YAML::Node& node, + Match& match, Rule& rule) const; +}; + +// Returns a default implementation of YamlPolicyParser. +const CelPolicyParser& GetDefaultYamlPolicyParser(); + +absl::StatusOr ParseYamlCelPolicy( + std::shared_ptr policy_source, + const CelPolicyParser& parser); + +// YAML CelPolicy parser that uses the default format as implemented by +// `YamlPolicyParser`. +absl::StatusOr ParseYamlCelPolicy( + std::shared_ptr policy_source); + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_POLICY_YAML_POLICY_PARSER_H_ diff --git a/policy/yaml_policy_parser_test.cc b/policy/yaml_policy_parser_test.cc new file mode 100644 index 000000000..4e7dfc49c --- /dev/null +++ b/policy/yaml_policy_parser_test.cc @@ -0,0 +1,305 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "policy/yaml_policy_parser.h" + +#include +#include +#include +#include +#include + +#include "absl/log/absl_log.h" +#include "absl/status/status_matchers.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "common/source.h" +#include "internal/runfiles.h" +#include "internal/testing.h" +#include "policy/cel_policy.h" +#include "policy/cel_policy_parse_result.h" +#include "policy/cel_policy_parser.h" +#include "yaml-cpp/node/node.h" + +namespace cel { + +namespace internal { +const CelPolicyParser& GetTestCustomYamlPolicyParser(); +} // namespace internal + +namespace { + +using ::absl_testing::IsOk; +using ::testing::HasSubstr; +using ::testing::IsNull; + +constexpr absl::string_view kTestPolicyFilePath = +"_main/policy/testdata/"; + +constexpr absl::string_view kBaselineSeparator = + "--------------------------------------------------------------------\n"; + +struct YamlPolicyParserTestCase { + std::string policy_source_file; + std::string baseline_file; + const cel::CelPolicyParser& (*parser_factory)(); +}; + +using YamlPolicyParserTest = testing::TestWithParam; + +TEST_P(YamlPolicyParserTest, Parse) { + std::string contents; + std::string test_file = cel::internal::ResolveRunfilesPath( + absl::StrCat(kTestPolicyFilePath, GetParam().policy_source_file)); + ASSERT_THAT(cel::internal::GetFileContents(test_file, &contents), IsOk()); + + std::string baseline; + std::string baseline_file = cel::internal::ResolveRunfilesPath( + absl::StrCat(kTestPolicyFilePath, GetParam().baseline_file)); + ASSERT_THAT(cel::internal::GetFileContents(baseline_file, &baseline), IsOk()); + baseline = absl::StripAsciiWhitespace(baseline); + + std::ostringstream out; + out << "POLICY SOURCE: " << GetParam().policy_source_file << "\n"; + + ASSERT_OK_AND_ASSIGN(cel::SourcePtr source, + cel::NewSource(contents, GetParam().policy_source_file)); + std::shared_ptr policy_source = + std::make_shared(std::move(source)); + + ASSERT_OK_AND_ASSIGN( + CelPolicyParseResult parse_result, + cel::ParseYamlCelPolicy(policy_source, GetParam().parser_factory())); + + out << kBaselineSeparator; + if (parse_result.IsValid()) { + out << "PARSED POLICY:\n"; + out << parse_result.GetPolicy()->DebugString(); + } else { + ASSERT_THAT(parse_result.GetPolicy(), IsNull()); + out << kBaselineSeparator; + out << "PARSER ISSUES:\n"; + for (const auto& issue : parse_result.GetIssues()) { + out << issue.ToDisplayString(*policy_source) << "\n"; + } + } + + std::string actual(absl::StripAsciiWhitespace(out.str())); + if (actual != baseline) { + // Log the actual result to make it easier to copy/paste into the baseline + // file when updating the tests. + ABSL_LOG(INFO) << "Actual:\n" << actual; + EXPECT_EQ(actual, baseline); + } +} + +INSTANTIATE_TEST_SUITE_P( + Formats, YamlPolicyParserTest, + testing::ValuesIn({ + YamlPolicyParserTestCase{ + .policy_source_file = "cel_policy.yaml", + .baseline_file = "cel_policy_parser.baseline", + .parser_factory = GetDefaultYamlPolicyParser, + }, + YamlPolicyParserTestCase{ + .policy_source_file = "nested_rule.yaml", + .baseline_file = "nested_rule_parser.baseline", + .parser_factory = GetDefaultYamlPolicyParser, + }, + YamlPolicyParserTestCase{ + .policy_source_file = "custom_policy_format.yaml", + .baseline_file = "custom_policy_format_parser.baseline", + .parser_factory = internal::GetTestCustomYamlPolicyParser, + }, + YamlPolicyParserTestCase{ + .policy_source_file = "custom_policy_format_with_errors.yaml", + .baseline_file = "custom_policy_format_with_errors_parser.baseline", + .parser_factory = internal::GetTestCustomYamlPolicyParser, + }, + })); + +struct ParseTestCase { + std::string yaml; + std::string expected_error; +}; + +using YamlPolicyParseErrorTest = testing::TestWithParam; + +TEST_P(YamlPolicyParseErrorTest, YamlSyntaxError) { + const ParseTestCase& param = GetParam(); + ASSERT_OK_AND_ASSIGN(cel::SourcePtr source, + cel::NewSource(param.yaml, "test")); + std::shared_ptr policy_source = + std::make_shared(std::move(source)); + ASSERT_OK_AND_ASSIGN(CelPolicyParseResult parse_result, + cel::ParseYamlCelPolicy(policy_source)); + EXPECT_THAT(parse_result.FormattedIssues(), HasSubstr(param.expected_error)); +} + +std::vector GetParseTestCases() { + return { + ParseTestCase{ + .yaml = R"yaml( ? [ John, Doe ]: age: 30 )yaml", + .expected_error = "1:22: Invalid CEL policy YAML syntax\n" + " | ? [ John, Doe ]: age: 30 \n" + " | .....................^", + }, + ParseTestCase{ + .yaml = R"yaml( invalid yaml )yaml", + .expected_error = "1:2: Policy is not a map\n" + " | invalid yaml \n" + " | .^", + }, + ParseTestCase{ + .yaml = R"yaml( + ? [1, 2, 3] + : "Prime numbers sequence" + )yaml", + .expected_error = "2:23: Policy tag is not a string\n" + " | ? [1, 2, 3]\n" + " | ......................^", + }, + ParseTestCase{ + .yaml = R"yaml( + imports: N/A + )yaml", + .expected_error = "2:28: Policy 'imports' is not a sequence\n" + " | imports: N/A\n" + " | ...........................^", + }, + ParseTestCase{ + .yaml = R"yaml( + imports: + - cel.expr.conformance + )yaml", + .expected_error = "3:21: Import is not a map\n" + " | - cel.expr.conformance\n" + " | ....................^", + }, + ParseTestCase{ + .yaml = R"yaml( + imports: + - name: + - cel.expr.conformance + )yaml", + .expected_error = "4:21: Import name is not a string\n" + " | - cel.expr.conformance\n" + " | ....................^", + }, + ParseTestCase{ + .yaml = R"yaml( + rule: do something + )yaml", + .expected_error = "2:25: Policy 'rule' is not a map\n" + " | rule: do something\n" + " | ........................^", + }, + ParseTestCase{ + .yaml = R"yaml( + rule: + id: + - 22 + )yaml", + .expected_error = "4:21: Policy rule 'id' is not a string\n" + " | - 22\n" + " | ....................^", + }, + ParseTestCase{ + .yaml = R"yaml( + rule: + variables: + no vars + )yaml", + .expected_error = "4:23: Policy rule 'variables' is not a sequence\n" + " | no vars\n" + " | ......................^", + }, + ParseTestCase{ + .yaml = R"yaml( + rule: + variables: + - name: + foo: bar + )yaml", + .expected_error = "5:25: Policy variable 'name' is not a string\n" + " | foo: bar\n" + " | ........................^", + }, + ParseTestCase{ + .yaml = R"yaml( + rule: + variables: + - name: test_var + expression: + - 22 + )yaml", + .expected_error = + "6:23: Policy variable 'expression' is not a string\n" + " | - 22\n" + " | ......................^", + }, + ParseTestCase{ + .yaml = R"yaml( + rule: + variables: + - name: '\u0041\u00a9\u20ac\U0001f680' + - '\u0041\u00a9\u20ac\U0001f680': name + )yaml", + .expected_error = + "5:23: Unrecognized policy variable tag: " + "\\u0041\\u00a9\\u20ac\\U0001f680\n" + " | - '\\u0041\\u00a9\\u20ac\\U0001f680': " + "name\n" + " | ......................^", + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(YamlPolicyParseErrorTest, YamlPolicyParseErrorTest, + ::testing::ValuesIn(GetParseTestCases())); + +TEST(YamlPolicyParserTest, OffsetIssueFormatting) { + // TODO(b/506179116): will need to copy the go implementation in extracting + // the source string from the YAML document instead of the interpreted string + // value to fix up error locations in folded and block literals. + std::string contents; + std::string test_file = cel::internal::ResolveRunfilesPath( + absl::StrCat(kTestPolicyFilePath, "cel_policy.yaml")); + ASSERT_THAT(cel::internal::GetFileContents(test_file, &contents), IsOk()); + + ASSERT_OK_AND_ASSIGN(cel::SourcePtr source, + cel::NewSource(contents, "cel_policy.yaml")); + std::shared_ptr policy_source = + std::make_shared(std::move(source)); + ASSERT_OK_AND_ASSIGN(CelPolicyParseResult parse_result, + cel::ParseYamlCelPolicy(policy_source)); + + ASSERT_TRUE(parse_result.IsValid()); + const CelPolicy* policy = parse_result.GetPolicy(); + + CelPolicyElementId name_id = policy->name().id(); + + CelPolicyIssue issue(name_id, 4, CelPolicyIssue::Severity::kError, + "Test error"); + + std::string formatted = issue.ToDisplayString(*policy_source); + + EXPECT_THAT(formatted, HasSubstr("ERROR: cel_policy.yaml:16:11: Test error")); + EXPECT_THAT(formatted, HasSubstr(" | name: cel_policy")); + EXPECT_THAT(formatted, HasSubstr(" | ..........^")); +} + +} // namespace +} // namespace cel From c8c187fea145b3113b3638d76c727c9052f13511 Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Tue, 16 Jun 2026 10:10:54 -0700 Subject: [PATCH 65/87] Support overload lookup and filtering by signature as an alternative to overload_id PiperOrigin-RevId: 933154803 --- checker/BUILD | 2 + checker/internal/type_checker_builder_impl.cc | 4 +- checker/type_checker_builder.h | 2 +- checker/type_checker_builder_factory_test.cc | 17 +++-- checker/type_checker_subset_factory.cc | 20 ++++-- checker/type_checker_subset_factory_test.cc | 33 +++++++-- common/BUILD | 3 +- common/decl.cc | 68 +++++++++++-------- common/decl.h | 48 +++---------- common/decl_test.cc | 47 +++++++++++++ env/BUILD | 1 + env/config_test.cc | 55 +++++++++++++++ env/env.cc | 56 +++++++++++---- env/env_test.cc | 16 +++++ env/env_yaml_test.cc | 12 ++-- 15 files changed, 276 insertions(+), 108 deletions(-) diff --git a/checker/BUILD b/checker/BUILD index 27a1eb84e..10ed6e363 100644 --- a/checker/BUILD +++ b/checker/BUILD @@ -229,6 +229,8 @@ cc_library( hdrs = ["type_checker_subset_factory.h"], deps = [ ":type_checker_builder", + "//common:decl", + "//common/internal:signature", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", diff --git a/checker/internal/type_checker_builder_impl.cc b/checker/internal/type_checker_builder_impl.cc index 9b91fc926..f0332b999 100644 --- a/checker/internal/type_checker_builder_impl.cc +++ b/checker/internal/type_checker_builder_impl.cc @@ -187,8 +187,8 @@ std::optional FilterDecl(FunctionDecl decl, FunctionDecl filtered; std::string name = decl.release_name(); std::vector overloads = decl.release_overloads(); - for (const auto& ovl : overloads) { - if (subset.should_include_overload(name, ovl.id())) { + for (auto& ovl : overloads) { + if (subset.should_include_overload(name, ovl)) { absl::Status s = filtered.AddOverload(std::move(ovl)); if (!s.ok()) { // Should not be possible to construct the original decl in a way that diff --git a/checker/type_checker_builder.h b/checker/type_checker_builder.h index f145b8a98..c2d0cbf7b 100644 --- a/checker/type_checker_builder.h +++ b/checker/type_checker_builder.h @@ -52,7 +52,7 @@ struct CheckerLibrary { // Represents a declaration to only use a subset of a library. struct TypeCheckerSubset { using FunctionPredicate = absl::AnyInvocable; + absl::string_view function, const OverloadDecl& overload) const>; // The id of the library to subset. Only one subset can be applied per // library id. diff --git a/checker/type_checker_builder_factory_test.cc b/checker/type_checker_builder_factory_test.cc index 9c4775e7f..40406948d 100644 --- a/checker/type_checker_builder_factory_test.cc +++ b/checker/type_checker_builder_factory_test.cc @@ -235,8 +235,8 @@ TEST(TypeCheckerBuilderTest, AddLibraryIncludeSubset) { ASSERT_THAT( builder->AddLibrarySubset( {"testlib", - [](absl::string_view /*function*/, absl::string_view overload_id) { - return (overload_id == "add_int" || overload_id == "sub_int"); + [](absl::string_view /*function*/, const OverloadDecl& overload) { + return (overload.id() == "add_int" || overload.id() == "sub_int"); }}), IsOk()); ASSERT_OK_AND_ASSIGN(auto checker, builder->Build()); @@ -274,9 +274,8 @@ TEST(TypeCheckerBuilderTest, AddLibraryExcludeSubset) { ASSERT_THAT( builder->AddLibrarySubset( {"testlib", - [](absl::string_view /*function*/, absl::string_view overload_id) { - return (overload_id != "add_int" && overload_id != "sub_int"); - ; + [](absl::string_view /*function*/, const OverloadDecl& overload) { + return (overload.id() != "add_int" && overload.id() != "sub_int"); }}), IsOk()); ASSERT_OK_AND_ASSIGN(auto checker, builder->Build()); @@ -313,7 +312,7 @@ TEST(TypeCheckerBuilderTest, AddLibrarySubsetRemoveAllOvl) { ASSERT_THAT(builder->AddLibrary(SubsetTestlib()), IsOk()); ASSERT_THAT(builder->AddLibrarySubset({"testlib", [](absl::string_view function, - absl::string_view /*overload_id*/) { + const OverloadDecl& /*overload*/) { return function != "add"; }}), IsOk()); @@ -352,12 +351,12 @@ TEST(TypeCheckerBuilderTest, AddLibraryOneSubsetPerLibraryId) { ASSERT_THAT( builder->AddLibrarySubset( {"testlib", [](absl::string_view function, - absl::string_view /*overload_id*/) { return true; }}), + const OverloadDecl& /*overload*/) { return true; }}), IsOk()); EXPECT_THAT( builder->AddLibrarySubset( {"testlib", [](absl::string_view function, - absl::string_view /*overload_id*/) { return true; }}), + const OverloadDecl& /*overload*/) { return true; }}), StatusIs(absl::StatusCode::kAlreadyExists)); } @@ -369,7 +368,7 @@ TEST(TypeCheckerBuilderTest, AddLibrarySubsetLibraryIdRequireds) { ASSERT_THAT(builder->AddLibrary(SubsetTestlib()), IsOk()); EXPECT_THAT(builder->AddLibrarySubset({"", [](absl::string_view function, - absl::string_view /*overload_id*/) { + const OverloadDecl& /*overload*/) { return function == "add"; }}), StatusIs(absl::StatusCode::kInvalidArgument)); diff --git a/checker/type_checker_subset_factory.cc b/checker/type_checker_subset_factory.cc index 6a05ce220..e5335e220 100644 --- a/checker/type_checker_subset_factory.cc +++ b/checker/type_checker_subset_factory.cc @@ -21,14 +21,21 @@ #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "checker/type_checker_builder.h" +#include "common/decl.h" +#include "common/internal/signature.h" namespace cel { TypeCheckerSubset::FunctionPredicate IncludeOverloadsByIdPredicate( absl::flat_hash_set overload_ids) { return [overload_ids = std::move(overload_ids)]( - absl::string_view /*function*/, absl::string_view overload_id) { - return overload_ids.contains(overload_id); + absl::string_view function, const OverloadDecl& overload) { + if (overload_ids.contains(overload.id())) { + return true; + } + auto signature = common_internal::MakeOverloadSignature( + function, overload.args(), overload.member()); + return signature.ok() && overload_ids.contains(*signature); }; } @@ -41,8 +48,13 @@ TypeCheckerSubset::FunctionPredicate IncludeOverloadsByIdPredicate( TypeCheckerSubset::FunctionPredicate ExcludeOverloadsByIdPredicate( absl::flat_hash_set overload_ids) { return [overload_ids = std::move(overload_ids)]( - absl::string_view /*function*/, absl::string_view overload_id) { - return !overload_ids.contains(overload_id); + absl::string_view function, const OverloadDecl& overload) { + if (overload_ids.contains(overload.id())) { + return false; + } + auto signature = common_internal::MakeOverloadSignature( + function, overload.args(), overload.member()); + return !signature.ok() || !overload_ids.contains(*signature); }; } diff --git a/checker/type_checker_subset_factory_test.cc b/checker/type_checker_subset_factory_test.cc index fa38e1c0d..5b644ec7c 100644 --- a/checker/type_checker_subset_factory_test.cc +++ b/checker/type_checker_subset_factory_test.cc @@ -43,6 +43,8 @@ TEST(TypeCheckerSubsetFactoryTest, IncludeOverloadsByIdPredicate) { StandardOverloadIds::kEquals, StandardOverloadIds::kNotEquals, StandardOverloadIds::kNotStrictlyFalse, + "matches(string,string)", + "string.matches(string)", }; ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); ASSERT_THAT(builder->GetCheckerBuilder().AddLibrarySubset({ @@ -65,15 +67,19 @@ TEST(TypeCheckerSubsetFactoryTest, IncludeOverloadsByIdPredicate) { EXPECT_TRUE(r.IsValid()); + // Allowed by signature. + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("r'foo.*'.matches('foobar')")); + EXPECT_TRUE(r.IsValid()); + + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("matches(r'foo.*', 'foobar')")); + EXPECT_TRUE(r.IsValid()); + // Not in allowlist. ASSERT_OK_AND_ASSIGN(r, compiler->Compile("1 + 2 < 3")); EXPECT_FALSE(r.IsValid()); ASSERT_OK_AND_ASSIGN(r, compiler->Compile("'abc' + 'def'")); EXPECT_FALSE(r.IsValid()); - - ASSERT_OK_AND_ASSIGN(r, compiler->Compile("r'foo.*'.matches('foobar')")); - EXPECT_FALSE(r.IsValid()); } TEST(TypeCheckerSubsetFactoryTest, ExcludeOverloadsByIdPredicate) { @@ -83,6 +89,8 @@ TEST(TypeCheckerSubsetFactoryTest, ExcludeOverloadsByIdPredicate) { absl::string_view exclude_list[] = { StandardOverloadIds::kMatches, StandardOverloadIds::kMatchesMember, + "size(string)", + "string.size()", }; ASSERT_THAT(builder->AddLibrary(StandardCompilerLibrary()), IsOk()); ASSERT_THAT(builder->GetCheckerBuilder().AddLibrarySubset({ @@ -105,18 +113,35 @@ TEST(TypeCheckerSubsetFactoryTest, ExcludeOverloadsByIdPredicate) { EXPECT_TRUE(r.IsValid()); - // Not in allowlist. + // Allowed. ASSERT_OK_AND_ASSIGN(r, compiler->Compile("1 + 2 < 3")); EXPECT_TRUE(r.IsValid()); ASSERT_OK_AND_ASSIGN(r, compiler->Compile("'abc' + 'def'")); EXPECT_TRUE(r.IsValid()); + // Excluded by ID. ASSERT_OK_AND_ASSIGN(r, compiler->Compile("r'foo.*'.matches('foobar')")); EXPECT_FALSE(r.IsValid()); ASSERT_OK_AND_ASSIGN(r, compiler->Compile("matches(r'foo.*', 'foobar')")); EXPECT_FALSE(r.IsValid()); + + // Excluded by signature (top-level function). + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("size('abc')")); + EXPECT_FALSE(r.IsValid()); + + // Allowed (other overloads of size). + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("size([1, 2, 3])")); + EXPECT_TRUE(r.IsValid()); + + // Excluded by signature (member function). + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("'abc'.size()")); + EXPECT_FALSE(r.IsValid()); + + // Allowed (other overloads of size member). + ASSERT_OK_AND_ASSIGN(r, compiler->Compile("[1, 2, 3].size()")); + EXPECT_TRUE(r.IsValid()); } } // namespace diff --git a/common/BUILD b/common/BUILD index 93410306f..0bd3632dd 100644 --- a/common/BUILD +++ b/common/BUILD @@ -151,9 +151,8 @@ cc_library( "//internal:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/common/decl.cc b/common/decl.cc index b338bfd4f..d2d50964a 100644 --- a/common/decl.cc +++ b/common/decl.cc @@ -20,8 +20,8 @@ #include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -109,43 +109,46 @@ bool SignaturesOverlap(const OverloadDecl& lhs, const OverloadDecl& rhs) { template void AddOverloadInternal(std::string_view function_name, std::vector& insertion_order, - OverloadDeclHashSet& overloads, Overload&& overload, - absl::Status& status) { + absl::flat_hash_map& by_id, + absl::flat_hash_map& by_signature, + Overload&& overload, absl::Status& status) { if (!status.ok()) { return; } - if (overload.id().empty()) { - OverloadDecl overload_decl = overload; - absl::StatusOr overload_id = - common_internal::MakeOverloadSignature( - function_name, overload_decl.args(), overload_decl.member()); - if (!overload_id.ok()) { - status = overload_id.status(); - return; - } - overload_decl.set_id(*overload_id); - AddOverloadInternal(function_name, insertion_order, overloads, - std::move(overload_decl), status); + absl::StatusOr signature = + common_internal::MakeOverloadSignature(function_name, overload.args(), + overload.member()); + if (!signature.ok()) { + status = signature.status(); return; } - if (auto it = overloads.find(overload.id()); it != overloads.end()) { + OverloadDecl mutable_overload = std::forward(overload); + + if (mutable_overload.id().empty()) { + mutable_overload.set_id(*signature); + } + + if (auto it = by_id.find(mutable_overload.id()); it != by_id.end()) { status = absl::AlreadyExistsError( - absl::StrCat("overload already exists: ", overload.id())); + absl::StrCat("overload exists: ", mutable_overload.id())); return; } - for (const auto& existing : overloads) { - if (SignaturesOverlap(overload, existing)) { + + for (const auto& existing : insertion_order) { + if (SignaturesOverlap(mutable_overload, existing)) { status = absl::InvalidArgumentError( absl::StrCat("overload signature collision: ", existing.id(), - " collides with ", overload.id())); + " collides with ", mutable_overload.id())); return; } } - const auto inserted = overloads.insert(std::forward(overload)); - ABSL_DCHECK(inserted.second); - insertion_order.push_back(*inserted.first); + + size_t index = insertion_order.size(); + by_id[mutable_overload.id()] = index; + by_signature[*signature] = index; + insertion_order.push_back(std::move(mutable_overload)); } void CollectTypeParams(absl::flat_hash_set& type_params, @@ -195,14 +198,25 @@ absl::flat_hash_set OverloadDecl::GetTypeParams() const { void FunctionDecl::AddOverloadImpl(const OverloadDecl& overload, absl::Status& status) { - AddOverloadInternal(name_, overloads_.insertion_order, overloads_.set, - overload, status); + AddOverloadInternal(name_, overloads_.insertion_order, overloads_.by_id, + overloads_.by_signature, overload, status); } void FunctionDecl::AddOverloadImpl(OverloadDecl&& overload, absl::Status& status) { - AddOverloadInternal(name_, overloads_.insertion_order, overloads_.set, - std::move(overload), status); + AddOverloadInternal(name_, overloads_.insertion_order, overloads_.by_id, + overloads_.by_signature, std::move(overload), status); +} + +const OverloadDecl* FunctionDecl::FindOverloadById(absl::string_view id) const { + if (auto it = overloads_.by_id.find(id); it != overloads_.by_id.end()) { + return &overloads_.insertion_order[it->second]; + } + if (auto it = overloads_.by_signature.find(id); + it != overloads_.by_signature.end()) { + return &overloads_.insertion_order[it->second]; + } + return nullptr; } } // namespace cel diff --git a/common/decl.h b/common/decl.h index 22ee8cbf0..7aea8c502 100644 --- a/common/decl.h +++ b/common/decl.h @@ -22,11 +22,10 @@ #include "absl/algorithm/container.h" #include "absl/base/attributes.h" +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/hash/hash.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" @@ -264,39 +263,6 @@ OverloadDecl MakeMemberOverloadDecl(absl::string_view id, Type result, return overload_decl; } -struct OverloadDeclHash { - using is_transparent = void; - - size_t operator()(const OverloadDecl& overload_decl) const { - return (*this)(overload_decl.id()); - } - - size_t operator()(absl::string_view id) const { return absl::HashOf(id); } -}; - -struct OverloadDeclEqualTo { - using is_transparent = void; - - bool operator()(const OverloadDecl& lhs, const OverloadDecl& rhs) const { - return (*this)(lhs.id(), rhs.id()); - } - - bool operator()(const OverloadDecl& lhs, absl::string_view rhs) const { - return (*this)(lhs.id(), rhs); - } - - bool operator()(absl::string_view lhs, const OverloadDecl& rhs) const { - return (*this)(lhs, rhs.id()); - } - - bool operator()(absl::string_view lhs, absl::string_view rhs) const { - return lhs == rhs; - } -}; - -using OverloadDeclHashSet = - absl::flat_hash_set; - template absl::StatusOr MakeFunctionDecl(std::string name, Overloads&&... overloads); @@ -346,21 +312,27 @@ class FunctionDecl final { return overloads_.insertion_order; } + ABSL_MUST_USE_RESULT const OverloadDecl* FindOverloadById( + absl::string_view id) const; + std::vector release_overloads() { std::vector released = std::move(overloads_.insertion_order); overloads_.insertion_order.clear(); - overloads_.set.clear(); + overloads_.by_id.clear(); + overloads_.by_signature.clear(); return released; } private: struct Overloads { std::vector insertion_order; - OverloadDeclHashSet set; + absl::flat_hash_map by_id; + absl::flat_hash_map by_signature; void Reserve(size_t size) { insertion_order.reserve(size); - set.reserve(size); + by_id.reserve(size); + by_signature.reserve(size); } }; diff --git a/common/decl_test.cc b/common/decl_test.cc index 510cd5017..72e7f1b93 100644 --- a/common/decl_test.cc +++ b/common/decl_test.cc @@ -165,6 +165,53 @@ TEST(FunctionDecl, Overloads) { StatusIs(absl::StatusCode::kInvalidArgument)); } +TEST(FunctionDecl, AddOverloadInvalidSignature) { + FunctionDecl function_decl; + function_decl.set_name("foo"); + // Member overload must have at least one argument (the receiver). + // This should fail to add because signature generation fails. + EXPECT_THAT(function_decl.AddOverload(MakeMemberOverloadDecl(StringType{})), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(FunctionDecl, AddOverloadDuplicateId) { + ASSERT_OK_AND_ASSIGN( + auto function_decl, + MakeFunctionDecl("hello", + MakeOverloadDecl("foo", StringType{}, StringType{}))); + // Adding another overload with the same ID "foo" should fail. + EXPECT_THAT( + function_decl.AddOverload(MakeOverloadDecl("foo", IntType{}, IntType{})), + StatusIs(absl::StatusCode::kAlreadyExists)); +} + +TEST(FunctionDecl, FindOverload) { + ASSERT_OK_AND_ASSIGN( + auto function_decl, + MakeFunctionDecl( + "hello", MakeOverloadDecl("foo", StringType{}, StringType{}), + MakeMemberOverloadDecl("bar", StringType{}, StringType{}), + MakeOverloadDecl(IntType{}, IntType{}))); + + // Find by explicit ID + const OverloadDecl* overload = function_decl.FindOverloadById("foo"); + ASSERT_NE(overload, nullptr); + EXPECT_EQ(overload->id(), "foo"); + + // Find by ID fallback to signature + overload = function_decl.FindOverloadById("hello(string)"); + ASSERT_NE(overload, nullptr); + EXPECT_EQ(overload->id(), "foo"); + + // Find implicit overload (where ID == signature) + overload = function_decl.FindOverloadById("hello(int)"); + ASSERT_NE(overload, nullptr); + EXPECT_EQ(overload->id(), "hello(int)"); + + // Non-existent + EXPECT_EQ(function_decl.FindOverloadById("non_existent"), nullptr); +} + TEST(FunctionDecl, OverloadId) { google::protobuf::Arena arena; const auto* descriptor = diff --git a/env/BUILD b/env/BUILD index bd82e8ec6..1816238a5 100644 --- a/env/BUILD +++ b/env/BUILD @@ -56,6 +56,7 @@ cc_library( "//common:container", "//common:decl", "//common:type", + "//common/internal:signature", "//compiler", "//compiler:compiler_factory", "//compiler:standard_library", diff --git a/env/config_test.cc b/env/config_test.cc index df0d6f875..8cfc3cf7f 100644 --- a/env/config_test.cc +++ b/env/config_test.cc @@ -88,6 +88,34 @@ INSTANTIATE_TEST_SUITE_P( StandardLibraryConfigTestCase{ .standard_library_config = {}, }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_functions = {{"_+_", "add_int64"}, + {"_+_", "add_list"}}, + }, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_functions = {{"_+_", "add(int,int)"}, + {"_+_", "add(list,list)"}}, + }, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .excluded_functions = {{"_+_", "add_int64"}, + {"_+_", "add_list"}}, + }, + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .excluded_functions = {{"_+_", "add(int,int)"}, + {"_+_", "add(list,list)"}}, + }, + }, StandardLibraryConfigTestCase{ .standard_library_config = { @@ -106,6 +134,15 @@ INSTANTIATE_TEST_SUITE_P( .expected_error = "Cannot set both included and excluded functions.", }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_functions = {{"_+_", "add(int,int)"}}, + .excluded_functions = {{"_-_", ""}}, + }, + .expected_error = + "Cannot set both included and excluded functions.", + }, StandardLibraryConfigTestCase{ .standard_library_config = { @@ -114,6 +151,15 @@ INSTANTIATE_TEST_SUITE_P( .expected_error = "Cannot include function '_+_' and also its " "specific overload 'add_list'", }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .included_functions = {{"_+_", ""}, + {"_+_", "add(int,int)"}}, + }, + .expected_error = "Cannot include function '_+_' and also its " + "specific overload 'add(int,int)'", + }, StandardLibraryConfigTestCase{ .standard_library_config = { @@ -121,6 +167,15 @@ INSTANTIATE_TEST_SUITE_P( }, .expected_error = "Cannot exclude function '_+_' and also its " "specific overload 'add_list'", + }, + StandardLibraryConfigTestCase{ + .standard_library_config = + { + .excluded_functions = {{"_+_", ""}, + {"_+_", "add(int,int)"}}, + }, + .expected_error = "Cannot exclude function '_+_' and also its " + "specific overload 'add(int,int)'", })); TEST(VariableConfigTest, VariableConfig) { diff --git a/env/env.cc b/env/env.cc index 22d24295e..4fa4e7398 100644 --- a/env/env.cc +++ b/env/env.cc @@ -26,6 +26,7 @@ #include "common/constant.h" #include "common/container.h" #include "common/decl.h" +#include "common/internal/signature.h" #include "common/type.h" #include "compiler/compiler.h" #include "compiler/compiler_factory.h" @@ -57,21 +58,47 @@ bool ShouldIncludeMacro(const Config::StandardLibraryConfig& config, bool ShouldIncludeFunction(const Config::StandardLibraryConfig& config, absl::string_view function, - absl::string_view overload_id) { - if (config.excluded_functions.contains( - std::make_pair(std::string(function), std::string(overload_id))) || - config.excluded_functions.contains( - std::make_pair(std::string(function), ""))) { - return false; + const OverloadDecl& overload) { + if (config.excluded_functions.empty() && config.included_functions.empty()) { + return true; + } + + if (!config.excluded_functions.empty()) { + if (config.excluded_functions.contains(std::make_pair( + std::string(function), std::string(overload.id()))) || + config.excluded_functions.contains( + std::make_pair(std::string(function), ""))) { + return false; + } + absl::StatusOr signature = + common_internal::MakeOverloadSignature(function, overload.args(), + overload.member()); + if (signature.ok() && config.excluded_functions.contains(std::make_pair( + std::string(function), *std::move(signature)))) { + return false; + } } - if (!config.included_functions.empty() && - !config.included_functions.contains( - std::make_pair(std::string(function), "")) && - !config.included_functions.contains( - std::make_pair(std::string(function), std::string(overload_id)))) { + + if (!config.included_functions.empty()) { + if (config.included_functions.contains(std::make_pair( + std::string(function), std::string(overload.id()))) || + config.included_functions.contains( + std::make_pair(std::string(function), ""))) { + return true; + } + // Ok to call MakeOverloadSignature() again, because in practice either + // included or excluded functions may be specified, but not both. + absl::StatusOr signature = + common_internal::MakeOverloadSignature(function, overload.args(), + overload.member()); + if (signature.ok() && config.included_functions.contains(std::make_pair( + std::string(function), *std::move(signature)))) { + return true; + } return false; } - return true; + + return true; // Never reached } absl::StatusOr MakeStdlibSubset( @@ -87,9 +114,8 @@ absl::StatusOr MakeStdlibSubset( }; subset.should_include_overload = [&standard_library_config]( absl::string_view function, - absl::string_view overload_id) { - return ShouldIncludeFunction(standard_library_config, function, - overload_id); + const OverloadDecl& overload) { + return ShouldIncludeFunction(standard_library_config, function, overload); }; return subset; } diff --git a/env/env_test.cc b/env/env_test.cc index fda87dfab..00143a857 100644 --- a/env/env_test.cc +++ b/env/env_test.cc @@ -280,6 +280,15 @@ INSTANTIATE_TEST_SUITE_P( .expected_invalid_expressions = {"1 + 2", "[1, 2, 3] + [4, 5, 6]", "'hello' + 'world'"}, }, + StandardLibraryConfigTestCase{ + .standard_library_config = + {.excluded_functions = {{"_+_", "_+_(bytes,bytes)"}, + {"_+_", "_+_(list<~A>,list<~A>)"}, + {"_+_", "_+_(string,string)"}}}, + .expected_valid_expressions = {"1 + 2"}, + .expected_invalid_expressions = {"[1, 2, 3] + [4, 5, 6]", + "'hello' + 'world'"}, + }, StandardLibraryConfigTestCase{ .standard_library_config = {.excluded_functions = {{"_+_", "add_bytes"}, @@ -294,6 +303,13 @@ INSTANTIATE_TEST_SUITE_P( .expected_valid_expressions = {"1 + 2", "[1, 2, 3] + [4, 5, 6]", "'hello' + 'world'"}, }, + StandardLibraryConfigTestCase{ + .standard_library_config = + {.included_functions = {{"_+_", "_+_(int,int)"}, + {"_+_", "_+_(list<~A>,list<~A>)"}}}, + .expected_valid_expressions = {"1 + 2", "[1, 2, 3] + [4, 5, 6]"}, + .expected_invalid_expressions = {"'hello' + 'world'"}, + }, StandardLibraryConfigTestCase{ .standard_library_config = {.included_functions = {{"_+_", "add_int64"}, diff --git a/env/env_yaml_test.cc b/env/env_yaml_test.cc index 9c5b3f04f..38f08e371 100644 --- a/env/env_yaml_test.cc +++ b/env/env_yaml_test.cc @@ -153,7 +153,7 @@ TEST(EnvYamlTest, ParseStdlibConfig_InclusionStyle) { - name: "_+_" overloads: - id: add_bytes - - id: add_list + - id: "_+_(list<~A>,list<~A>)" - name: "matches" - name: "timestamp" overloads: @@ -166,7 +166,7 @@ TEST(EnvYamlTest, ParseStdlibConfig_InclusionStyle) { EXPECT_THAT( stdlib_config.included_functions, UnorderedElementsAre(std::make_pair("_+_", "add_bytes"), - std::make_pair("_+_", "add_list"), + std::make_pair("_+_", "_+_(list<~A>,list<~A>)"), std::make_pair("matches", ""), std::make_pair("timestamp", "string_to_timestamp"))) << " Actual stdlib config: " << stdlib_config; @@ -1405,9 +1405,9 @@ std::vector GetExportTestCases() { .included_functions = { std::make_pair("timestamp", "string_to_timestamp"), - std::make_pair("_+_", "add_list"), + std::make_pair("_+_", "_+_(list<~A>,list<~A>)"), std::make_pair("matches", ""), - std::make_pair("_+_", "add_bytes"), + std::make_pair("_+_", "_+_(bytes,bytes)"), }, })); return config; @@ -1417,8 +1417,8 @@ std::vector GetExportTestCases() { include_functions: - name: "_+_" overloads: - - id: "add_bytes" - - id: "add_list" + - id: "_+_(bytes,bytes)" + - id: "_+_(list<~A>,list<~A>)" - name: "matches" - name: "timestamp" overloads: From 903210f28babe29c18710d89217c0ce468fef9ba Mon Sep 17 00:00:00 2001 From: Justin King Date: Tue, 16 Jun 2026 13:19:17 -0700 Subject: [PATCH 66/87] Internal change PiperOrigin-RevId: 933263369 --- checker/BUILD | 1 + checker/validation_result.h | 1 + common/decl.h | 64 +++++++++++++++++++++++++++++++++++++ 3 files changed, 66 insertions(+) diff --git a/checker/BUILD b/checker/BUILD index 10ed6e363..efca3ff73 100644 --- a/checker/BUILD +++ b/checker/BUILD @@ -49,6 +49,7 @@ cc_library( deps = [ ":type_check_issue", "//common:ast", + "//common:decl", "//common:source", "//common:type", "@com_google_absl//absl/base:nullability", diff --git a/checker/validation_result.h b/checker/validation_result.h index f424e7f6f..7417e9969 100644 --- a/checker/validation_result.h +++ b/checker/validation_result.h @@ -28,6 +28,7 @@ #include "absl/types/span.h" #include "checker/type_check_issue.h" #include "common/ast.h" +#include "common/decl.h" #include "common/source.h" #include "common/type.h" diff --git a/common/decl.h b/common/decl.h index 7aea8c502..b15645236 100644 --- a/common/decl.h +++ b/common/decl.h @@ -377,6 +377,70 @@ bool TypeIsAssignable(const Type& to, const Type& from); } // namespace common_internal +struct VariableDeclEqualTo { + using is_transparent = void; + + bool operator()(const cel::VariableDecl& lhs, + const cel::VariableDecl& rhs) const { + return lhs.name() == rhs.name(); + } + + bool operator()(const cel::VariableDecl& lhs, std::string_view rhs) const { + return lhs.name() == rhs; + } + + bool operator()(std::string_view lhs, const cel::VariableDecl& rhs) const { + return lhs == rhs.name(); + } +}; + +struct VariableDeclHash { + using is_transparent = void; + + size_t operator()(const cel::VariableDecl& decl) const { + return (*this)(decl.name()); + } + + size_t operator()(std::string_view name) const { return absl::HashOf(name); } +}; + +using VariableDeclSet = absl::flat_hash_set; + +struct FunctionDeclEqualTo { + using is_transparent = void; + + bool operator()(const cel::FunctionDecl& lhs, + const cel::FunctionDecl& rhs) const { + return (*this)(lhs.name(), rhs.name()); + } + + bool operator()(const cel::FunctionDecl& lhs, std::string_view rhs) const { + return (*this)(lhs.name(), rhs); + } + + bool operator()(std::string_view lhs, const cel::FunctionDecl& rhs) const { + return (*this)(lhs, rhs.name()); + } + + bool operator()(std::string_view lhs, std::string_view rhs) const { + return lhs == rhs; + } +}; + +struct FunctionDeclHash { + using is_transparent = void; + + size_t operator()(const cel::FunctionDecl& decl) const { + return absl::HashOf(decl.name()); + } + + size_t operator()(std::string_view name) const { return absl::HashOf(name); } +}; + +using FunctionDeclSet = absl::flat_hash_set; + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_DECL_H_ From 19b54d719884e57b5e41be5c79001d0a13c22c3c Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Tue, 16 Jun 2026 15:59:45 -0700 Subject: [PATCH 67/87] Add policy conformance tests to postsubmit windows bazel test workflow. PiperOrigin-RevId: 933348031 --- .github/workflows/windows_bazel_test.yml | 2 +- .github/workflows/windows_bazel_test_post_merge.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/windows_bazel_test.yml b/.github/workflows/windows_bazel_test.yml index 4ac7f2eec..6d12e6861 100644 --- a/.github/workflows/windows_bazel_test.yml +++ b/.github/workflows/windows_bazel_test.yml @@ -25,4 +25,4 @@ jobs: # //... won't work. shell: bash run: | - bazelisk test --config=msvc conformance:all \ No newline at end of file + bazelisk test --config=msvc conformance:all conformance/policy:all \ No newline at end of file diff --git a/.github/workflows/windows_bazel_test_post_merge.yml b/.github/workflows/windows_bazel_test_post_merge.yml index 11801011e..569177fcc 100644 --- a/.github/workflows/windows_bazel_test_post_merge.yml +++ b/.github/workflows/windows_bazel_test_post_merge.yml @@ -9,5 +9,5 @@ jobs: trigger-test: # This prevents the workflow from running automatically when someone # pushes to their fork. - if: github.repository == 'google/cel-cpp' + if: github.repository == 'cel-expr/cel-cpp' uses: ./.github/workflows/windows_bazel_test.yml \ No newline at end of file From 9485d85c8d73d31ed0d21b8b21397c8d74b90584 Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Tue, 16 Jun 2026 16:36:08 -0700 Subject: [PATCH 68/87] Update from google/cel-spec to cel-expr/cel-spec (Part 2) PiperOrigin-RevId: 933365198 --- README.md | 2 +- eval/README.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 41b44388d..7c3c26be0 100644 --- a/README.md +++ b/README.md @@ -15,4 +15,4 @@ parser, and type checker. Released under the [Apache License](LICENSE). -[1]: https://github.com/google/cel-spec +[1]: https://github.com/cel-expr/cel-spec diff --git a/eval/README.md b/eval/README.md index ee6fd0798..32fa4bda4 100644 --- a/eval/README.md +++ b/eval/README.md @@ -3,4 +3,4 @@ A C++ implementation of a [Common Expression Language][1] evaluator. -[1]: https://github.com/google/cel-spec +[1]: https://github.com/cel-expr/cel-spec From 7dc7eacfbccf56c4d74bcb6e0f4dc2e63768c6bf Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 17 Jun 2026 01:08:33 -0700 Subject: [PATCH 69/87] No public description PiperOrigin-RevId: 933560360 --- .../descriptor_pool_type_introspector.cc | 18 +++++++++--------- .../descriptor_pool_type_introspector_test.cc | 8 ++++---- checker/internal/type_check_env.cc | 10 +++++----- checker/internal/type_checker_builder_impl.cc | 2 +- .../internal/type_checker_builder_impl_test.cc | 4 ++-- checker/internal/type_checker_impl.cc | 6 +++--- checker/internal/type_inference_context.cc | 4 ++-- 7 files changed, 26 insertions(+), 26 deletions(-) diff --git a/checker/internal/descriptor_pool_type_introspector.cc b/checker/internal/descriptor_pool_type_introspector.cc index da4f4430b..733e4a3cb 100644 --- a/checker/internal/descriptor_pool_type_introspector.cc +++ b/checker/internal/descriptor_pool_type_introspector.cc @@ -42,7 +42,7 @@ FindStructTypeFieldByNameDirectly( const google::protobuf::Descriptor* absl_nullable descriptor = descriptor_pool->FindMessageTypeByName(type); if (descriptor == nullptr) { - return absl::nullopt; + return std::nullopt; } const google::protobuf::FieldDescriptor* absl_nullable field = descriptor->FindFieldByName(name); @@ -54,7 +54,7 @@ FindStructTypeFieldByNameDirectly( if (field != nullptr) { return StructTypeField(MessageTypeField(field)); } - return absl::nullopt; + return std::nullopt; } // Standard implementation for listing fields. @@ -67,7 +67,7 @@ ListStructTypeFieldsDirectly( const google::protobuf::Descriptor* absl_nullable descriptor = descriptor_pool->FindMessageTypeByName(type); if (descriptor == nullptr) { - return absl::nullopt; + return std::nullopt; } std::vector extensions; @@ -100,7 +100,7 @@ DescriptorPoolTypeIntrospector::FindTypeImpl(absl::string_view name) const { if (enum_descriptor != nullptr) { return Type::Enum(enum_descriptor); } - return absl::nullopt; + return std::nullopt; } absl::StatusOr> @@ -112,7 +112,7 @@ DescriptorPoolTypeIntrospector::FindEnumConstantImpl( const google::protobuf::EnumValueDescriptor* absl_nullable enum_value_descriptor = enum_descriptor->FindValueByName(value); if (enum_value_descriptor == nullptr) { - return absl::nullopt; + return std::nullopt; } return EnumConstant{ .type = Type::Enum(enum_descriptor), @@ -121,7 +121,7 @@ DescriptorPoolTypeIntrospector::FindEnumConstantImpl( .number = enum_value_descriptor->number(), }; } - return absl::nullopt; + return std::nullopt; } absl::StatusOr> @@ -134,7 +134,7 @@ DescriptorPoolTypeIntrospector::FindStructTypeFieldByNameImpl( const FieldTable* field_table = GetFieldTable(type); if (field_table == nullptr) { - return absl::nullopt; + return std::nullopt; } if (auto it = field_table->json_name_map.find(name); @@ -147,7 +147,7 @@ DescriptorPoolTypeIntrospector::FindStructTypeFieldByNameImpl( return field_table->fields[it->second].field; } - return absl::nullopt; + return std::nullopt; } absl::StatusOr< @@ -160,7 +160,7 @@ DescriptorPoolTypeIntrospector::ListFieldsForStructTypeImpl( const FieldTable* field_table = GetFieldTable(type); if (field_table == nullptr) { - return absl::nullopt; + return std::nullopt; } std::vector fields; fields.reserve(field_table->non_extensions.size()); diff --git a/checker/internal/descriptor_pool_type_introspector_test.cc b/checker/internal/descriptor_pool_type_introspector_test.cc index 456798744..db766b347 100644 --- a/checker/internal/descriptor_pool_type_introspector_test.cc +++ b/checker/internal/descriptor_pool_type_introspector_test.cc @@ -47,7 +47,7 @@ TEST(DescriptorPoolTypeIntrospectorTest, FindType) { "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"), IsOkAndHolds(Optional(Property(&Type::IsEnum, true)))); EXPECT_THAT(introspector.FindType("non.existent.Type"), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); } TEST(DescriptorPoolTypeIntrospectorTest, FindEnumConstant) { @@ -84,7 +84,7 @@ TEST(DescriptorPoolTypeIntrospectorTest, auto field = introspector.FindStructTypeFieldByName( "cel.expr.conformance.proto3.TestAllTypes", "singleInt64"); - EXPECT_THAT(field, IsOkAndHolds(Eq(absl::nullopt))); + EXPECT_THAT(field, IsOkAndHolds(Eq(std::nullopt))); } TEST(DescriptorPoolTypeIntrospectorTest, FindExtension) { @@ -108,7 +108,7 @@ TEST(DescriptorPoolTypeIntrospectorTest, FindStructTypeFieldByNameWithJsonOpt) { auto field = introspector.FindStructTypeFieldByName( "cel.expr.conformance.proto3.TestAllTypes", "single_int64"); - ASSERT_THAT(field, IsOkAndHolds(Eq(absl::nullopt))); + ASSERT_THAT(field, IsOkAndHolds(Eq(std::nullopt))); } TEST(DescriptorPoolTypeIntrospectorTest, @@ -168,7 +168,7 @@ TEST(DescriptorPoolTypeIntrospectorTest, ListFieldsForStructTypeNotFound) { internal::GetTestingDescriptorPool()); auto fields = introspector.ListFieldsForStructType( "cel.expr.conformance.proto3.SomeOtherType"); - EXPECT_THAT(fields, IsOkAndHolds(Eq(absl::nullopt))); + EXPECT_THAT(fields, IsOkAndHolds(Eq(std::nullopt))); } } // namespace diff --git a/checker/internal/type_check_env.cc b/checker/internal/type_check_env.cc index 47487220c..8dc83518d 100644 --- a/checker/internal/type_check_env.cc +++ b/checker/internal/type_check_env.cc @@ -58,7 +58,7 @@ absl::StatusOr> TypeCheckEnv::LookupTypeName( return type; } } - return absl::nullopt; + return std::nullopt; } absl::StatusOr> TypeCheckEnv::LookupEnumConstant( @@ -75,7 +75,7 @@ absl::StatusOr> TypeCheckEnv::LookupEnumConstant( return decl; } } - return absl::nullopt; + return std::nullopt; } absl::StatusOr> TypeCheckEnv::LookupTypeConstant( @@ -92,14 +92,14 @@ absl::StatusOr> TypeCheckEnv::LookupTypeConstant( return LookupEnumConstant(enum_name_candidate, value_name_candidate); } - return absl::nullopt; + return std::nullopt; } absl::StatusOr> TypeCheckEnv::LookupStructField( absl::string_view type_name, absl::string_view field_name) const { if (proto_type_mask_registry_ != nullptr && !proto_type_mask_registry_->FieldIsVisible(type_name, field_name)) { - return absl::nullopt; + return std::nullopt; } // Check the type providers in registration order. // Note: this doesn't allow for shadowing a type with a subset type of the @@ -113,7 +113,7 @@ absl::StatusOr> TypeCheckEnv::LookupStructField( return field; } } - return absl::nullopt; + return std::nullopt; } const VariableDecl* absl_nullable VariableScope::LookupLocalVariable( diff --git a/checker/internal/type_checker_builder_impl.cc b/checker/internal/type_checker_builder_impl.cc index f0332b999..4289fb528 100644 --- a/checker/internal/type_checker_builder_impl.cc +++ b/checker/internal/type_checker_builder_impl.cc @@ -198,7 +198,7 @@ std::optional FilterDecl(FunctionDecl decl, } } if (filtered.overloads().empty()) { - return absl::nullopt; + return std::nullopt; } filtered.set_name(std::move(name)); return filtered; diff --git a/checker/internal/type_checker_builder_impl_test.cc b/checker/internal/type_checker_builder_impl_test.cc index 913e704ee..fa7f80960 100644 --- a/checker/internal/type_checker_builder_impl_test.cc +++ b/checker/internal/type_checker_builder_impl_test.cc @@ -349,7 +349,7 @@ TEST(ContextDeclsTest, CustomStructNotSupported) { if (name == "com.example.MyStruct") { return common_internal::MakeBasicStructType("com.example.MyStruct"); } - return absl::nullopt; + return std::nullopt; } }; @@ -370,7 +370,7 @@ TEST(ContextDeclsWithProtoTypeMaskTest, CustomStructNotSupported) { if (name == "com.example.MyStruct") { return common_internal::MakeBasicStructType("com.example.MyStruct"); } - return absl::nullopt; + return std::nullopt; } }; diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index f3a06a28d..2bc05dbf7 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -1098,11 +1098,11 @@ std::optional ResolveVisitor::CheckFieldType(int64_t id, auto field_info = env_->LookupStructField(struct_type.name(), field); if (!field_info.ok()) { status_.Update(field_info.status()); - return absl::nullopt; + return std::nullopt; } if (!field_info->has_value()) { ReportUndefinedField(id, field, struct_type.name()); - return absl::nullopt; + return std::nullopt; } auto type = field_info->value().GetType(); if (type.kind() == TypeKind::kEnum) { @@ -1134,7 +1134,7 @@ std::optional ResolveVisitor::CheckFieldType(int64_t id, "expression of type '", FormatTypeName(inference_context_->FinalizeType(operand_type)), "' cannot be the operand of a select operation"))); - return absl::nullopt; + return std::nullopt; } void ResolveVisitor::ResolveSelectOperation(const Expr& expr, diff --git a/checker/internal/type_inference_context.cc b/checker/internal/type_inference_context.cc index 4681784af..4f738b804 100644 --- a/checker/internal/type_inference_context.cc +++ b/checker/internal/type_inference_context.cc @@ -149,7 +149,7 @@ std::optional WrapperToPrimitive(const Type& t) { case TypeKind::kUintWrapper: return UintType(); default: - return absl::nullopt; + return std::nullopt; } } @@ -579,7 +579,7 @@ TypeInferenceContext::ResolveOverload(const FunctionDecl& decl, } if (!result_type.has_value() || matching_overloads.empty()) { - return absl::nullopt; + return std::nullopt; } return OverloadResolution{ .result_type = FullySubstitute(*result_type, /*free_to_dyn=*/false), From 8b7068abb4062074a135491ad8357139287084f9 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 17 Jun 2026 05:56:44 -0700 Subject: [PATCH 70/87] No public description PiperOrigin-RevId: 933675942 --- internal/message_equality_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/message_equality_test.cc b/internal/message_equality_test.cc index 318138d9b..092edd71b 100644 --- a/internal/message_equality_test.cc +++ b/internal/message_equality_test.cc @@ -399,7 +399,7 @@ absl::optional, PackTestAllTypesProto3Field(const google::protobuf::Message& message, const google::protobuf::FieldDescriptor* absl_nonnull field) { if (field->is_map()) { - return absl::nullopt; + return std::nullopt; } if (field->is_repeated() && field->type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE) { @@ -425,7 +425,7 @@ PackTestAllTypesProto3Field(const google::protobuf::Message& message, cel::to_address(packed), any_field)); return std::pair{packed, any_field}; } - return absl::nullopt; + return std::nullopt; } TEST_P(UnaryMessageFieldEqualsTest, Equals) { From 071fca088461260c8667b075d2c3c2a78d7b1fd8 Mon Sep 17 00:00:00 2001 From: Mike Kruskal Date: Wed, 17 Jun 2026 11:03:40 -0700 Subject: [PATCH 71/87] Add -fexceptions to the antlr gencode to avoid breaking users building with -fno-exceptions PiperOrigin-RevId: 933819478 --- bazel/antlr.bzl | 1 + 1 file changed, 1 insertion(+) diff --git a/bazel/antlr.bzl b/bazel/antlr.bzl index 2abbb6dbd..a4d28cdf8 100644 --- a/bazel/antlr.bzl +++ b/bazel/antlr.bzl @@ -55,6 +55,7 @@ def antlr_cc_library(name, src, package): generated, "@antlr4-cpp-runtime//:antlr4-cpp-runtime", ], + copts = ["-fexceptions"], linkstatic = 1, ) From e2ed5270e50f90a15e1008a18662cefbac303650 Mon Sep 17 00:00:00 2001 From: Justin King Date: Wed, 17 Jun 2026 14:08:48 -0700 Subject: [PATCH 72/87] Internal change PiperOrigin-RevId: 933914348 --- checker/internal/type_checker_impl.cc | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 2bc05dbf7..9c2806e89 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -199,6 +199,7 @@ class ResolveVisitor : public AstVisitorBase { struct AttributeResolution { const VariableDecl* decl; bool requires_disambiguation; + bool local; }; ResolveVisitor(NamespaceGenerator namespace_generator, @@ -1001,7 +1002,7 @@ void ResolveVisitor::ResolveSimpleIdentifier(const Expr& expr, const VariableDecl* local_decl = LookupLocalIdentifier(name); if (local_decl != nullptr && !absl::StartsWith(name, ".")) { - attributes_[&expr] = {local_decl, false}; + attributes_[&expr] = {local_decl, false, /*local=*/true}; types_[&expr] = inference_context_->InstantiateTypeParams(local_decl->type()); return; @@ -1016,8 +1017,13 @@ void ResolveVisitor::ResolveSimpleIdentifier(const Expr& expr, }); if (decl != nullptr) { - attributes_[&expr] = {decl, - /* requires_disambiguation= */ local_decl != nullptr}; + attributes_[&expr] = { + decl, + /* requires_disambiguation= */ local_decl != nullptr, + // There is some oddity here, `.` prefixed idents should never be local. + // So LookupLocalIdentifier above should never return a valid decl. + // Perhaps this is a refactor holdover? + /*local=*/false}; types_[&expr] = inference_context_->InstantiateTypeParams(decl->type()); return; } @@ -1072,9 +1078,14 @@ void ResolveVisitor::ResolveQualifiedIdentifier( root = &root->select_expr().operand(); } - attributes_[root] = {decl, - /* requires_disambiguation= */ decl != local_decl && - local_decl != nullptr}; + attributes_[root] = { + decl, + /* requires_disambiguation= */ decl != local_decl && + local_decl != nullptr, + // There is some oddity here, `.` prefixed idents should never be local. + // So LookupLocalIdentifier above should never return a valid decl. + // Perhaps this is a refactor holdover? + /*local=*/local_decl == decl}; types_[root] = inference_context_->InstantiateTypeParams(decl->type()); // fix-up select operations that were deferred. From 95e4b10bcc0ee5d3bd99650dc2a95e8fa2fee99a Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 17 Jun 2026 21:27:58 -0700 Subject: [PATCH 73/87] No public description PiperOrigin-RevId: 934095048 --- codelab/network_functions.cc | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/codelab/network_functions.cc b/codelab/network_functions.cc index 64f199cb3..6cc1505a9 100644 --- a/codelab/network_functions.cc +++ b/codelab/network_functions.cc @@ -368,7 +368,7 @@ std::optional NetworkAddressRep::Unwrap( auto opaque = value.AsOpaque(); if (!opaque.has_value() || opaque->GetTypeId() != cel::TypeId()) { - return absl::nullopt; + return std::nullopt; } // Note: safety depends on: @@ -385,10 +385,10 @@ std::optional NetworkAddressRep::Parse( char ipv6[16]; auto version = ParseAddressImpl(str, &ipv4, ipv6); if (!version.ok()) { - return absl::nullopt; + return std::nullopt; } if (*version != IpVersion::kIPv4) { - return absl::nullopt; + return std::nullopt; } NetworkAddressRep rep; rep.version_ = *version; @@ -422,7 +422,7 @@ std::optional NetworkAddressMatcher::Parse( int dash_pos = str.find('-'); if (dash_pos == absl::string_view::npos) { // TODO(uncreated-issue/86): CIDR style addr/prefix-length - return absl::nullopt; + return std::nullopt; } absl::string_view min_str = str.substr(0, dash_pos); absl::string_view max_str = str.substr(dash_pos + 1); @@ -431,23 +431,23 @@ std::optional NetworkAddressMatcher::Parse( NetworkRangev6 v6; auto min_parse = ParseAddressImpl(min_str, &v4.min_incl, v6.min_incl); if (!min_parse.ok()) { - return absl::nullopt; + return std::nullopt; } auto max_parse = ParseAddressImpl(max_str, &v4.max_incl, v6.max_incl); if (!max_parse.ok()) { - return absl::nullopt; + return std::nullopt; } if (*min_parse != *max_parse) { - return absl::nullopt; + return std::nullopt; } NetworkAddressMatcher rep; if (*min_parse == IpVersion::kIPv4) { if (v4.min_incl > v4.max_incl) { - return absl::nullopt; + return std::nullopt; } rep.ranges_v4_.push_back(v4); } else if (*min_parse == IpVersion::kIPv6) { - return absl::nullopt; + return std::nullopt; } return rep; From 51faf9f1f2889bc1bbc582755ba37ed1d105f393 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 17 Jun 2026 22:43:18 -0700 Subject: [PATCH 74/87] No public description PiperOrigin-RevId: 934123553 --- extensions/select_optimization.cc | 12 ++++++------ extensions/select_optimization_test.cc | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/extensions/select_optimization.cc b/extensions/select_optimization.cc index 42cad0f92..0cc64311a 100644 --- a/extensions/select_optimization.cc +++ b/extensions/select_optimization.cc @@ -158,11 +158,11 @@ std::optional GetSelectInstruction( absl::string_view field_name) { auto field_or = planner_context.type_reflector() .FindStructTypeFieldByName(runtime_type, field_name) - .value_or(absl::nullopt); + .value_or(std::nullopt); if (field_or.has_value()) { return SelectInstruction{field_or->number(), std::string(field_or->name())}; } - return absl::nullopt; + return std::nullopt; } absl::StatusOr SelectQualifierFromList(const ListExpr& list) { @@ -410,7 +410,7 @@ class RewriterImpl : public AstRewriterBase { std::optional rt_type = (checker_type.has_message_type()) ? GetRuntimeType(checker_type.message_type().type()) - : absl::nullopt; + : std::nullopt; if (rt_type.has_value() && (*rt_type).Is()) { const StructType& runtime_type = rt_type->GetStruct(); std::optional field_or = @@ -540,7 +540,7 @@ class RewriterImpl : public AstRewriterBase { std::optional GetRuntimeType(absl::string_view type_name) { return planner_context_.type_reflector().FindType(type_name).value_or( - absl::nullopt); + std::nullopt); } void SetProgressStatus(const absl::Status& status) { @@ -600,7 +600,7 @@ class OptimizedSelectImpl { absl::StatusOr> CheckForMarkedAttributes( ExecutionFrameBase& frame, const AttributeTrail& attribute_trail) { if (attribute_trail.empty()) { - return absl::nullopt; + return std::nullopt; } if (frame.unknown_processing_enabled() && @@ -624,7 +624,7 @@ absl::StatusOr> CheckForMarkedAttributes( attribute_trail.attribute()); } - return absl::nullopt; + return std::nullopt; } absl::StatusOr OptimizedSelectImpl::ApplySelect( diff --git a/extensions/select_optimization_test.cc b/extensions/select_optimization_test.cc index 9d4024098..c14c4d461 100644 --- a/extensions/select_optimization_test.cc +++ b/extensions/select_optimization_test.cc @@ -257,7 +257,7 @@ class MockAccessApis : public LegacyTypeInfoApis, public LegacyTypeAccessApis { std::optional< google::api::expr::runtime::LegacyTypeInfoApis::FieldDescription> FindFieldByName(absl::string_view field_name) const override { - return absl::nullopt; + return std::nullopt; } MOCK_METHOD(absl::StatusOr, GetField, From 80fc78aca1945cdf2653e6d888c5ff15058892c4 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 18 Jun 2026 00:27:29 -0700 Subject: [PATCH 75/87] No public description PiperOrigin-RevId: 934169077 --- base/operators.cc | 32 ++++++++++++++++---------------- base/operators_test.cc | 32 ++++++++++++++++---------------- 2 files changed, 32 insertions(+), 32 deletions(-) diff --git a/base/operators.cc b/base/operators.cc index 805acc5a1..b7df40b27 100644 --- a/base/operators.cc +++ b/base/operators.cc @@ -179,13 +179,13 @@ CEL_INTERNAL_TERNARY_OPERATORS_ENUM(CEL_TERNARY_OPERATOR) absl::optional Operator::FindByName(absl::string_view input) { absl::call_once(operators_once_flag, InitializeOperators); if (input.empty()) { - return absl::nullopt; + return std::nullopt; } auto it = std::lower_bound(operators_by_name.cbegin(), operators_by_name.cend(), input, OperatorDataNameComparer{}); if (it == operators_by_name.cend() || (*it)->name != input) { - return absl::nullopt; + return std::nullopt; } return Operator(*it); } @@ -193,13 +193,13 @@ absl::optional Operator::FindByName(absl::string_view input) { absl::optional Operator::FindByDisplayName(absl::string_view input) { absl::call_once(operators_once_flag, InitializeOperators); if (input.empty()) { - return absl::nullopt; + return std::nullopt; } auto it = std::lower_bound(operators_by_display_name.cbegin(), operators_by_display_name.cend(), input, OperatorDataDisplayNameComparer{}); if (it == operators_by_name.cend() || (*it)->display_name != input) { - return absl::nullopt; + return std::nullopt; } return Operator(*it); } @@ -208,13 +208,13 @@ absl::optional UnaryOperator::FindByName( absl::string_view input) { absl::call_once(operators_once_flag, InitializeOperators); if (input.empty()) { - return absl::nullopt; + return std::nullopt; } auto it = std::lower_bound(unary_operators_by_name.cbegin(), unary_operators_by_name.cend(), input, OperatorDataNameComparer{}); if (it == unary_operators_by_name.cend() || (*it)->name != input) { - return absl::nullopt; + return std::nullopt; } return UnaryOperator(*it); } @@ -223,14 +223,14 @@ absl::optional UnaryOperator::FindByDisplayName( absl::string_view input) { absl::call_once(operators_once_flag, InitializeOperators); if (input.empty()) { - return absl::nullopt; + return std::nullopt; } auto it = std::lower_bound(unary_operators_by_display_name.cbegin(), unary_operators_by_display_name.cend(), input, OperatorDataDisplayNameComparer{}); if (it == unary_operators_by_display_name.cend() || (*it)->display_name != input) { - return absl::nullopt; + return std::nullopt; } return UnaryOperator(*it); } @@ -239,13 +239,13 @@ absl::optional BinaryOperator::FindByName( absl::string_view input) { absl::call_once(operators_once_flag, InitializeOperators); if (input.empty()) { - return absl::nullopt; + return std::nullopt; } auto it = std::lower_bound(binary_operators_by_name.cbegin(), binary_operators_by_name.cend(), input, OperatorDataNameComparer{}); if (it == binary_operators_by_name.cend() || (*it)->name != input) { - return absl::nullopt; + return std::nullopt; } return BinaryOperator(*it); } @@ -254,14 +254,14 @@ absl::optional BinaryOperator::FindByDisplayName( absl::string_view input) { absl::call_once(operators_once_flag, InitializeOperators); if (input.empty()) { - return absl::nullopt; + return std::nullopt; } auto it = std::lower_bound(binary_operators_by_display_name.cbegin(), binary_operators_by_display_name.cend(), input, OperatorDataDisplayNameComparer{}); if (it == binary_operators_by_display_name.cend() || (*it)->display_name != input) { - return absl::nullopt; + return std::nullopt; } return BinaryOperator(*it); } @@ -270,13 +270,13 @@ absl::optional TernaryOperator::FindByName( absl::string_view input) { absl::call_once(operators_once_flag, InitializeOperators); if (input.empty()) { - return absl::nullopt; + return std::nullopt; } auto it = std::lower_bound(ternary_operators_by_name.cbegin(), ternary_operators_by_name.cend(), input, OperatorDataNameComparer{}); if (it == ternary_operators_by_name.cend() || (*it)->name != input) { - return absl::nullopt; + return std::nullopt; } return TernaryOperator(*it); } @@ -285,14 +285,14 @@ absl::optional TernaryOperator::FindByDisplayName( absl::string_view input) { absl::call_once(operators_once_flag, InitializeOperators); if (input.empty()) { - return absl::nullopt; + return std::nullopt; } auto it = std::lower_bound(ternary_operators_by_display_name.cbegin(), ternary_operators_by_display_name.cend(), input, OperatorDataDisplayNameComparer{}); if (it == ternary_operators_by_display_name.cend() || (*it)->display_name != input) { - return absl::nullopt; + return std::nullopt; } return TernaryOperator(*it); } diff --git a/base/operators_test.cc b/base/operators_test.cc index fdf95e7ae..6049f76c8 100644 --- a/base/operators_test.cc +++ b/base/operators_test.cc @@ -130,55 +130,55 @@ CEL_INTERNAL_TERNARY_OPERATORS_ENUM(CEL_TERNARY_OPERATOR) TEST(Operator, FindByName) { EXPECT_THAT(Operator::FindByName("@in"), Optional(Eq(Operator::In()))); EXPECT_THAT(Operator::FindByName("_in_"), Optional(Eq(Operator::OldIn()))); - EXPECT_THAT(Operator::FindByName("in"), Eq(absl::nullopt)); - EXPECT_THAT(Operator::FindByName(""), Eq(absl::nullopt)); + EXPECT_THAT(Operator::FindByName("in"), Eq(std::nullopt)); + EXPECT_THAT(Operator::FindByName(""), Eq(std::nullopt)); } TEST(Operator, FindByDisplayName) { EXPECT_THAT(Operator::FindByDisplayName("-"), Optional(Eq(Operator::Subtract()))); - EXPECT_THAT(Operator::FindByDisplayName("@in"), Eq(absl::nullopt)); - EXPECT_THAT(Operator::FindByDisplayName(""), Eq(absl::nullopt)); + EXPECT_THAT(Operator::FindByDisplayName("@in"), Eq(std::nullopt)); + EXPECT_THAT(Operator::FindByDisplayName(""), Eq(std::nullopt)); } TEST(UnaryOperator, FindByName) { EXPECT_THAT(UnaryOperator::FindByName("-_"), Optional(Eq(Operator::Negate()))); - EXPECT_THAT(UnaryOperator::FindByName("_-_"), Eq(absl::nullopt)); - EXPECT_THAT(UnaryOperator::FindByName(""), Eq(absl::nullopt)); + EXPECT_THAT(UnaryOperator::FindByName("_-_"), Eq(std::nullopt)); + EXPECT_THAT(UnaryOperator::FindByName(""), Eq(std::nullopt)); } TEST(UnaryOperator, FindByDisplayName) { EXPECT_THAT(UnaryOperator::FindByDisplayName("-"), Optional(Eq(Operator::Negate()))); - EXPECT_THAT(UnaryOperator::FindByDisplayName("&&"), Eq(absl::nullopt)); - EXPECT_THAT(UnaryOperator::FindByDisplayName(""), Eq(absl::nullopt)); + EXPECT_THAT(UnaryOperator::FindByDisplayName("&&"), Eq(std::nullopt)); + EXPECT_THAT(UnaryOperator::FindByDisplayName(""), Eq(std::nullopt)); } TEST(BinaryOperator, FindByName) { EXPECT_THAT(BinaryOperator::FindByName("_-_"), Optional(Eq(Operator::Subtract()))); - EXPECT_THAT(BinaryOperator::FindByName("-_"), Eq(absl::nullopt)); - EXPECT_THAT(BinaryOperator::FindByName(""), Eq(absl::nullopt)); + EXPECT_THAT(BinaryOperator::FindByName("-_"), Eq(std::nullopt)); + EXPECT_THAT(BinaryOperator::FindByName(""), Eq(std::nullopt)); } TEST(BinaryOperator, FindByDisplayName) { EXPECT_THAT(BinaryOperator::FindByDisplayName("-"), Optional(Eq(Operator::Subtract()))); - EXPECT_THAT(BinaryOperator::FindByDisplayName("!"), Eq(absl::nullopt)); - EXPECT_THAT(BinaryOperator::FindByDisplayName(""), Eq(absl::nullopt)); + EXPECT_THAT(BinaryOperator::FindByDisplayName("!"), Eq(std::nullopt)); + EXPECT_THAT(BinaryOperator::FindByDisplayName(""), Eq(std::nullopt)); } TEST(TernaryOperator, FindByName) { EXPECT_THAT(TernaryOperator::FindByName("_?_:_"), Optional(Eq(TernaryOperator::Conditional()))); - EXPECT_THAT(TernaryOperator::FindByName("-_"), Eq(absl::nullopt)); - EXPECT_THAT(TernaryOperator::FindByName(""), Eq(absl::nullopt)); + EXPECT_THAT(TernaryOperator::FindByName("-_"), Eq(std::nullopt)); + EXPECT_THAT(TernaryOperator::FindByName(""), Eq(std::nullopt)); } TEST(TernaryOperator, FindByDisplayName) { - EXPECT_THAT(TernaryOperator::FindByDisplayName(""), Eq(absl::nullopt)); - EXPECT_THAT(TernaryOperator::FindByDisplayName("!"), Eq(absl::nullopt)); + EXPECT_THAT(TernaryOperator::FindByDisplayName(""), Eq(std::nullopt)); + EXPECT_THAT(TernaryOperator::FindByDisplayName("!"), Eq(std::nullopt)); } TEST(Operator, SupportsAbslHash) { From 8515127e730d338066a0f39d13317b0b4cebc32a Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 18 Jun 2026 03:35:24 -0700 Subject: [PATCH 76/87] No public description PiperOrigin-RevId: 934248225 --- tools/flatbuffers_backed_impl.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tools/flatbuffers_backed_impl.cc b/tools/flatbuffers_backed_impl.cc index 10c0b1cb8..2ee226859 100644 --- a/tools/flatbuffers_backed_impl.cc +++ b/tools/flatbuffers_backed_impl.cc @@ -127,7 +127,7 @@ class ObjectStringIndexedMapImpl : public CelMap { arena_, **it, schema_, object_, arena_)); } } - return absl::nullopt; + return std::nullopt; } absl::StatusOr ListKeys() const override { return &keys_; } @@ -188,7 +188,7 @@ absl::optional FlatBuffersMapImpl::operator[]( } auto field = keys_.fields->LookupByKey(cel_key.StringOrDie().value().data()); if (field == nullptr) { - return absl::nullopt; + return std::nullopt; } switch (field->type()->base_type()) { case reflection::Byte: @@ -323,15 +323,15 @@ absl::optional FlatBuffersMapImpl::operator[]( } default: // Unsupported vector base types - return absl::nullopt; + return std::nullopt; } break; } default: // Unsupported types: enums, unions, arrays - return absl::nullopt; + return std::nullopt; } - return absl::nullopt; + return std::nullopt; } const CelMap* CreateFlatBuffersBackedObject(const uint8_t* flatbuf, From 3244f4a55e713555fc0f40ecfea5e38b5e3ad492 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 18 Jun 2026 03:35:53 -0700 Subject: [PATCH 77/87] No public description PiperOrigin-RevId: 934248421 --- testutil/test_macros.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/testutil/test_macros.cc b/testutil/test_macros.cc index 672439dc5..19e9a4844 100644 --- a/testutil/test_macros.cc +++ b/testutil/test_macros.cc @@ -40,7 +40,7 @@ bool IsCelNamespace(const Expr& target) { std::optional CelBlockMacroExpander(MacroExprFactory& factory, Expr& target, absl::Span args) { if (!IsCelNamespace(target)) { - return absl::nullopt; + return std::nullopt; } Expr& bindings_arg = args[0]; if (!bindings_arg.has_list_expr()) { @@ -53,7 +53,7 @@ std::optional CelBlockMacroExpander(MacroExprFactory& factory, std::optional CelIndexMacroExpander(MacroExprFactory& factory, Expr& target, absl::Span args) { if (!IsCelNamespace(target)) { - return absl::nullopt; + return std::nullopt; } Expr& index_arg = args[0]; if (!index_arg.has_const_expr() || !index_arg.const_expr().has_int_value()) { @@ -72,7 +72,7 @@ std::optional CelIterVarMacroExpander(MacroExprFactory& factory, Expr& target, absl::Span args) { if (!IsCelNamespace(target)) { - return absl::nullopt; + return std::nullopt; } Expr& depth_arg = args[0]; if (!depth_arg.has_const_expr() || !depth_arg.const_expr().has_int_value() || @@ -96,7 +96,7 @@ std::optional CelAccuVarMacroExpander(MacroExprFactory& factory, Expr& target, absl::Span args) { if (!IsCelNamespace(target)) { - return absl::nullopt; + return std::nullopt; } Expr& depth_arg = args[0]; if (!depth_arg.has_const_expr() || !depth_arg.const_expr().has_int_value() || From 3456ec81c08ec5ce3980c75b9522e917683e58ed Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 18 Jun 2026 05:34:35 -0700 Subject: [PATCH 78/87] No public description PiperOrigin-RevId: 934294071 --- extensions/protobuf/bind_proto_to_activation_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extensions/protobuf/bind_proto_to_activation_test.cc b/extensions/protobuf/bind_proto_to_activation_test.cc index fd79508ac..680b4b353 100644 --- a/extensions/protobuf/bind_proto_to_activation_test.cc +++ b/extensions/protobuf/bind_proto_to_activation_test.cc @@ -76,10 +76,10 @@ TEST_F(BindProtoToActivationTest, BindProtoToActivationSkip) { EXPECT_THAT(activation.FindVariable("single_int32", descriptor_pool(), message_factory(), arena()), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); EXPECT_THAT(activation.FindVariable("single_sint32", descriptor_pool(), message_factory(), arena()), - IsOkAndHolds(Eq(absl::nullopt))); + IsOkAndHolds(Eq(std::nullopt))); } TEST_F(BindProtoToActivationTest, BindProtoToActivationDefault) { From 774340d632ae13dc74c64f2eade30d8209686089 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 18 Jun 2026 06:50:51 -0700 Subject: [PATCH 79/87] No public description PiperOrigin-RevId: 934323942 --- parser/macro_registry.cc | 4 ++-- parser/macro_registry_test.cc | 4 ++-- parser/parser.cc | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/parser/macro_registry.cc b/parser/macro_registry.cc index 3a816b10e..d36761e87 100644 --- a/parser/macro_registry.cc +++ b/parser/macro_registry.cc @@ -55,7 +55,7 @@ absl::optional MacroRegistry::FindMacro(absl::string_view name, bool receiver_style) const { // :: if (name.empty() || absl::StrContains(name, ':')) { - return absl::nullopt; + return std::nullopt; } // Try argument count specific key first. auto key = absl::StrCat(name, ":", arg_count, ":", @@ -68,7 +68,7 @@ absl::optional MacroRegistry::FindMacro(absl::string_view name, if (auto it = macros_.find(key); it != macros_.end()) { return it->second; } - return absl::nullopt; + return std::nullopt; } std::vector MacroRegistry::ListMacros() const { diff --git a/parser/macro_registry_test.cc b/parser/macro_registry_test.cc index 9e6da87a4..db8a99ab2 100644 --- a/parser/macro_registry_test.cc +++ b/parser/macro_registry_test.cc @@ -30,14 +30,14 @@ using ::testing::Ne; TEST(MacroRegistry, RegisterAndFind) { MacroRegistry macros; EXPECT_THAT(macros.RegisterMacro(HasMacro()), IsOk()); - EXPECT_THAT(macros.FindMacro("has", 1, false), Ne(absl::nullopt)); + EXPECT_THAT(macros.FindMacro("has", 1, false), Ne(std::nullopt)); } TEST(MacroRegistry, RegisterRollsback) { MacroRegistry macros; EXPECT_THAT(macros.RegisterMacros({HasMacro(), AllMacro(), AllMacro()}), StatusIs(absl::StatusCode::kAlreadyExists)); - EXPECT_THAT(macros.FindMacro("has", 1, false), Eq(absl::nullopt)); + EXPECT_THAT(macros.FindMacro("has", 1, false), Eq(std::nullopt)); } } // namespace diff --git a/parser/parser.cc b/parser/parser.cc index a858337a4..24b4ca079 100644 --- a/parser/parser.cc +++ b/parser/parser.cc @@ -1468,11 +1468,11 @@ Expr ParserVisitor::GlobalCallOrMacroImpl(int64_t expr_id, } } factory_.BeginMacro(factory_.GetSourceRange(expr_id)); - auto expr = macro->Expand(factory_, absl::nullopt, absl::MakeSpan(args)); + auto expr = macro->Expand(factory_, std::nullopt, absl::MakeSpan(args)); factory_.EndMacro(); if (expr) { if (add_macro_calls_) { - factory_.AddMacroCall(expr->id(), function, absl::nullopt, + factory_.AddMacroCall(expr->id(), function, std::nullopt, std::move(macro_args)); } // We did not end up using `expr_id`. Delete metadata. From 83ccadfe7e0d284956332f3b90804c23b6ef3caf Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 18 Jun 2026 15:33:26 -0700 Subject: [PATCH 80/87] Refactor local variable resolution in checker. No functional changes, but makes intent a little clearer w.r.t. checking if we need to preserve a leading '.' or not. PiperOrigin-RevId: 934579650 --- checker/internal/type_checker_impl.cc | 97 ++++++++++++++------------- 1 file changed, 52 insertions(+), 45 deletions(-) diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 9c2806e89..bca187417 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -960,11 +960,10 @@ void ResolveVisitor::ResolveFunctionOverloads(const Expr& expr, const VariableDecl* absl_nullable ResolveVisitor::LookupLocalIdentifier( absl::string_view name) { - // Note: if we see a leading dot, this shouldn't resolve to a local variable, - // but we need to check whether we need to disambiguate against a global in - // the reference map. if (absl::StartsWith(name, ".")) { - name = name.substr(1); + // Should not happen for normally parsed CEL, but prevent lookup in case + // of hand-crafted ASTs. + return nullptr; } return current_scope_->LookupLocalVariable(name); } @@ -999,13 +998,15 @@ void ResolveVisitor::ResolveSimpleIdentifier(const Expr& expr, absl::string_view name) { // Local variables (comprehension, bind) are simple identifiers so we can // skip generating the different namespace-qualified candidates. - const VariableDecl* local_decl = LookupLocalIdentifier(name); - - if (local_decl != nullptr && !absl::StartsWith(name, ".")) { - attributes_[&expr] = {local_decl, false, /*local=*/true}; - types_[&expr] = - inference_context_->InstantiateTypeParams(local_decl->type()); - return; + if (!absl::StartsWith(name, ".")) { + const VariableDecl* local_decl = LookupLocalIdentifier(name); + if (local_decl != nullptr) { + attributes_[&expr] = {local_decl, /*requires_disambiguation=*/false, + /*local=*/true}; + types_[&expr] = + inference_context_->InstantiateTypeParams(local_decl->type()); + return; + } } const VariableDecl* decl = nullptr; @@ -1016,14 +1017,13 @@ void ResolveVisitor::ResolveSimpleIdentifier(const Expr& expr, return decl == nullptr; }); + bool requires_disambiguation = false; + if (absl::StartsWith(name, ".")) { + requires_disambiguation = LookupLocalIdentifier(name.substr(1)) != nullptr; + } + if (decl != nullptr) { - attributes_[&expr] = { - decl, - /* requires_disambiguation= */ local_decl != nullptr, - // There is some oddity here, `.` prefixed idents should never be local. - // So LookupLocalIdentifier above should never return a valid decl. - // Perhaps this is a refactor holdover? - /*local=*/false}; + attributes_[&expr] = {decl, requires_disambiguation, /*local=*/false}; types_[&expr] = inference_context_->InstantiateTypeParams(decl->type()); return; } @@ -1039,35 +1039,49 @@ void ResolveVisitor::ResolveQualifiedIdentifier( return; } + int matched_segment_index = -1; + const VariableDecl* decl = nullptr; + bool requires_disambiguation = false; + bool is_local = false; // Local variables (comprehension, bind) are simple identifiers so we can // skip generating the different namespace-qualified candidates. - const VariableDecl* local_decl = LookupLocalIdentifier(qualifiers[0]); - const VariableDecl* decl = nullptr; - - int matched_segment_index = -1; - - if (local_decl != nullptr && !absl::StartsWith(qualifiers[0], ".")) { - decl = local_decl; - matched_segment_index = 0; - } else { - namespace_generator_.GenerateCandidates( - qualifiers, [&decl, &matched_segment_index, this]( - absl::string_view candidate, int segment_index) { - decl = LookupGlobalIdentifier(candidate); - if (decl != nullptr) { - matched_segment_index = segment_index; - return false; - } - return true; - }); + if (!absl::StartsWith(qualifiers[0], ".")) { + const VariableDecl* local_decl = LookupLocalIdentifier(qualifiers[0]); + if (local_decl != nullptr) { + decl = local_decl; + matched_segment_index = 0; + is_local = true; + goto resolve_select_trail; + } } + namespace_generator_.GenerateCandidates( + qualifiers, [&decl, &matched_segment_index, this]( + absl::string_view candidate, int segment_index) { + decl = LookupGlobalIdentifier(candidate); + if (decl != nullptr) { + matched_segment_index = segment_index; + return false; + } + return true; + }); + if (decl == nullptr) { ReportMissingReference(expr, FormatCandidate(qualifiers)); types_[&expr] = ErrorType(); return; } + if (absl::StartsWith(qualifiers[0], ".")) { + const VariableDecl* local_decl = + LookupLocalIdentifier(qualifiers[0].substr(1)); + if (local_decl != nullptr) { + requires_disambiguation = true; + } + } + +resolve_select_trail: + const int num_select_opts = qualifiers.size() - matched_segment_index - 1; const Expr* root = &expr; @@ -1078,14 +1092,7 @@ void ResolveVisitor::ResolveQualifiedIdentifier( root = &root->select_expr().operand(); } - attributes_[root] = { - decl, - /* requires_disambiguation= */ decl != local_decl && - local_decl != nullptr, - // There is some oddity here, `.` prefixed idents should never be local. - // So LookupLocalIdentifier above should never return a valid decl. - // Perhaps this is a refactor holdover? - /*local=*/local_decl == decl}; + attributes_[root] = {decl, requires_disambiguation, is_local}; types_[root] = inference_context_->InstantiateTypeParams(decl->type()); // fix-up select operations that were deferred. From ef57455367567055465d8d86032c54fb9db27aff Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 18 Jun 2026 19:41:49 -0700 Subject: [PATCH 81/87] No public description PiperOrigin-RevId: 934677950 --- extensions/bindings_ext.cc | 2 +- extensions/lists_functions.cc | 2 +- extensions/math_ext_macros.cc | 6 +++--- extensions/math_ext_test.cc | 4 ++-- extensions/proto_ext.cc | 12 ++++++------ 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/extensions/bindings_ext.cc b/extensions/bindings_ext.cc index c59f724bd..4823c077c 100644 --- a/extensions/bindings_ext.cc +++ b/extensions/bindings_ext.cc @@ -73,7 +73,7 @@ std::vector bindings_macros() { [](MacroExprFactory& factory, Expr& target, absl::Span args) -> absl::optional { if (!IsTargetNamespace(target)) { - return absl::nullopt; + return std::nullopt; } if (!args[0].has_ident_expr()) { return factory.ReportErrorAt( diff --git a/extensions/lists_functions.cc b/extensions/lists_functions.cc index 10bc717ed..bfe05d887 100644 --- a/extensions/lists_functions.cc +++ b/extensions/lists_functions.cc @@ -454,7 +454,7 @@ Macro ListSortByMacro() { MakeMapComprehension(factory, factory.Copy(sortby_input_ident), std::move(key_ident), std::move(key_expr)); if (!map_compr.has_value()) { - return absl::nullopt; + return std::nullopt; } // Build the call expression: diff --git a/extensions/math_ext_macros.cc b/extensions/math_ext_macros.cc index a66720a60..08b163132 100644 --- a/extensions/math_ext_macros.cc +++ b/extensions/math_ext_macros.cc @@ -72,7 +72,7 @@ absl::optional CheckInvalidArgs(MacroExprFactory &factory, } } - return absl::nullopt; + return std::nullopt; } bool IsListLiteralWithValidArgs(const Expr &arg) { @@ -99,7 +99,7 @@ std::vector math_macros() { [](MacroExprFactory &factory, Expr &target, absl::Span arguments) -> absl::optional { if (!IsTargetNamespace(target)) { - return absl::nullopt; + return std::nullopt; } switch (arguments.size()) { @@ -143,7 +143,7 @@ std::vector math_macros() { [](MacroExprFactory &factory, Expr &target, absl::Span arguments) -> absl::optional { if (!IsTargetNamespace(target)) { - return absl::nullopt; + return std::nullopt; } switch (arguments.size()) { diff --git a/extensions/math_ext_test.cc b/extensions/math_ext_test.cc index ea9331970..ce05ae6ed 100644 --- a/extensions/math_ext_test.cc +++ b/extensions/math_ext_test.cc @@ -93,7 +93,7 @@ TestCase MinCase(CelValue v1, CelValue v2, CelValue result) { } TestCase MinCase(CelValue list, CelValue result) { - return TestCase{kMathMin, list, absl::nullopt, result}; + return TestCase{kMathMin, list, std::nullopt, result}; } TestCase MaxCase(CelValue v1, CelValue v2, CelValue result) { @@ -101,7 +101,7 @@ TestCase MaxCase(CelValue v1, CelValue v2, CelValue result) { } TestCase MaxCase(CelValue list, CelValue result) { - return TestCase{kMathMax, list, absl::nullopt, result}; + return TestCase{kMathMax, list, std::nullopt, result}; } struct MacroTestCase { diff --git a/extensions/proto_ext.cc b/extensions/proto_ext.cc index f38039002..48618f7ae 100644 --- a/extensions/proto_ext.cc +++ b/extensions/proto_ext.cc @@ -45,11 +45,11 @@ absl::optional ValidateExtensionIdentifier(const Expr& expr) { absl::Overload( [](const SelectExpr& select_expr) -> absl::optional { if (select_expr.test_only()) { - return absl::nullopt; + return std::nullopt; } auto op_name = ValidateExtensionIdentifier(select_expr.operand()); if (!op_name.has_value()) { - return absl::nullopt; + return std::nullopt; } return absl::StrCat(*op_name, ".", select_expr.field()); }, @@ -57,7 +57,7 @@ absl::optional ValidateExtensionIdentifier(const Expr& expr) { return ident_expr.name(); }, [](const auto&) -> absl::optional { - return absl::nullopt; + return std::nullopt; }), expr.kind()); } @@ -68,7 +68,7 @@ absl::optional GetExtensionFieldName(const Expr& expr) { select_expr) { return ValidateExtensionIdentifier(expr); } - return absl::nullopt; + return std::nullopt; } bool IsExtensionCall(const Expr& target) { @@ -95,7 +95,7 @@ std::vector proto_macros() { [](MacroExprFactory& factory, Expr& target, absl::Span arguments) -> absl::optional { if (!IsExtensionCall(target)) { - return absl::nullopt; + return std::nullopt; } auto extFieldName = GetExtensionFieldName(arguments[1]); if (!extFieldName.has_value()) { @@ -109,7 +109,7 @@ std::vector proto_macros() { [](MacroExprFactory& factory, Expr& target, absl::Span arguments) -> absl::optional { if (!IsExtensionCall(target)) { - return absl::nullopt; + return std::nullopt; } auto extFieldName = GetExtensionFieldName(arguments[1]); if (!extFieldName.has_value()) { From 7641fac87020aae83f71cca82de1d7e759672b32 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Thu, 18 Jun 2026 23:58:02 -0700 Subject: [PATCH 82/87] No public description PiperOrigin-RevId: 934760042 --- eval/eval/attribute_utility.cc | 2 +- eval/eval/container_access_step.cc | 2 +- eval/eval/function_step.cc | 4 ++-- eval/eval/select_step.cc | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/eval/eval/attribute_utility.cc b/eval/eval/attribute_utility.cc index 117516caf..1e044627e 100644 --- a/eval/eval/attribute_utility.cc +++ b/eval/eval/attribute_utility.cc @@ -123,7 +123,7 @@ absl::optional AttributeUtility::MergeUnknowns( } if (!result_set.has_value()) { - return absl::nullopt; + return std::nullopt; } return UnknownValue(cel::Unknown(result_set->unknown_attributes(), diff --git a/eval/eval/container_access_step.cc b/eval/eval/container_access_step.cc index fda51e34f..4cf4ebf4d 100644 --- a/eval/eval/container_access_step.cc +++ b/eval/eval/container_access_step.cc @@ -55,7 +55,7 @@ absl::optional CelNumberFromValue(const Value& value) { case ValueKind::kDouble: return Number::FromDouble(value.GetDouble().NativeValue()); default: - return absl::nullopt; + return std::nullopt; } } diff --git a/eval/eval/function_step.cc b/eval/eval/function_step.cc index fcf429378..12c5af8a7 100644 --- a/eval/eval/function_step.cc +++ b/eval/eval/function_step.cc @@ -291,7 +291,7 @@ absl::StatusOr ResolveStatic( return overload; } } - return absl::nullopt; + return std::nullopt; } absl::StatusOr ResolveLazy( @@ -299,7 +299,7 @@ absl::StatusOr ResolveLazy( bool receiver_style, absl::Span providers, const ExecutionFrameBase& frame) { - ResolveResult result = absl::nullopt; + ResolveResult result = std::nullopt; std::vector arg_types(input_args.size()); diff --git a/eval/eval/select_step.cc b/eval/eval/select_step.cc index 420f3ac31..b815f5d87 100644 --- a/eval/eval/select_step.cc +++ b/eval/eval/select_step.cc @@ -72,7 +72,7 @@ absl::optional CheckForMarkedAttributes(const AttributeTrail& trail, return cel::ErrorValue(std::move(result).status()); } - return absl::nullopt; + return std::nullopt; } void TestOnlySelect(const StructValue& msg, const std::string& field, From aea9b2adcddea6595548b65ad4e8a70d9e4de04c Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Fri, 19 Jun 2026 03:50:29 -0700 Subject: [PATCH 83/87] No public description PiperOrigin-RevId: 934844207 --- eval/compiler/flat_expr_builder.cc | 2 +- eval/compiler/flat_expr_builder_extensions.cc | 6 +++--- eval/compiler/qualified_reference_resolver.cc | 10 +++++----- eval/compiler/regex_precompilation_optimization.cc | 4 ++-- eval/compiler/resolver.cc | 4 ++-- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index fc6d87b16..aa9a8858c 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -1078,7 +1078,7 @@ class FlatExprVisitor : public cel::AstVisitor { // eligible for recursion, or nullopt if it is not. std::optional RecursionEligible() { if (!PlanRecursiveProgram() || program_builder_.current() == nullptr) { - return absl::nullopt; + return std::nullopt; } return program_builder_.current()->RecursiveDependencyDepth(); } diff --git a/eval/compiler/flat_expr_builder_extensions.cc b/eval/compiler/flat_expr_builder_extensions.cc index e51b64023..ee106ff4a 100644 --- a/eval/compiler/flat_expr_builder_extensions.cc +++ b/eval/compiler/flat_expr_builder_extensions.cc @@ -102,15 +102,15 @@ std::optional Subexpression::RecursiveDependencyDepth() const { auto* tree = absl::get_if(&program_); int depth = 0; if (tree == nullptr) { - return absl::nullopt; + return std::nullopt; } for (const auto& element : *tree) { auto* subexpression = absl::get_if(&element); if (subexpression == nullptr) { - return absl::nullopt; + return std::nullopt; } if (!(*subexpression)->IsRecursive()) { - return absl::nullopt; + return std::nullopt; } depth = std::max(depth, (*subexpression)->recursive_program().depth); } diff --git a/eval/compiler/qualified_reference_resolver.cc b/eval/compiler/qualified_reference_resolver.cc index 67c14d9b2..158e492be 100644 --- a/eval/compiler/qualified_reference_resolver.cc +++ b/eval/compiler/qualified_reference_resolver.cc @@ -99,7 +99,7 @@ std::optional BestOverloadMatch(const Resolver& resolver, return *name; } } - return absl::nullopt; + return std::nullopt; } // Rewriter visitor for resolving references. @@ -267,22 +267,22 @@ class ReferenceResolver : public cel::AstRewriterBase { if (rewritten_reference_.find(expr.id()) != rewritten_reference_.end()) { // The target expr matches a reference (resolved to an ident decl). // This should not be treated as a function qualifier. - return absl::nullopt; + return std::nullopt; } if (expr.has_ident_expr()) { return expr.ident_expr().name(); } else if (expr.has_select_expr()) { if (expr.select_expr().test_only()) { - return absl::nullopt; + return std::nullopt; } maybe_parent_namespace = ToNamespace(expr.select_expr().operand()); if (!maybe_parent_namespace.has_value()) { - return absl::nullopt; + return std::nullopt; } return absl::StrCat(*maybe_parent_namespace, ".", expr.select_expr().field()); } else { - return absl::nullopt; + return std::nullopt; } } diff --git a/eval/compiler/regex_precompilation_optimization.cc b/eval/compiler/regex_precompilation_optimization.cc index 455796131..38ef842b9 100644 --- a/eval/compiler/regex_precompilation_optimization.cc +++ b/eval/compiler/regex_precompilation_optimization.cc @@ -178,7 +178,7 @@ class RegexPrecompilationOptimization : public ProgramOptimizer { if (subexpression == nullptr || subexpression->IsFlattened()) { // Already modified, can't recover the input pattern. - return absl::nullopt; + return std::nullopt; } std::optional constant; if (subexpression->IsRecursive()) { @@ -206,7 +206,7 @@ class RegexPrecompilationOptimization : public ProgramOptimizer { return Cast(*constant).ToString(); } - return absl::nullopt; + return std::nullopt; } absl::Status RewritePlan( diff --git a/eval/compiler/resolver.cc b/eval/compiler/resolver.cc index 17f60eaad..cca72964a 100644 --- a/eval/compiler/resolver.cc +++ b/eval/compiler/resolver.cc @@ -128,7 +128,7 @@ std::optional Resolver::FindConstant(absl::string_view name, return TypeValue(**type_value); } } - return absl::nullopt; + return std::nullopt; } std::vector Resolver::FindOverloads( @@ -216,7 +216,7 @@ Resolver::FindType(absl::string_view name, int64_t expr_id) const { return std::make_pair(std::move(qualified_name), std::move(*maybe_type)); } } - return absl::nullopt; + return std::nullopt; } } // namespace google::api::expr::runtime From e54b8bd8322454b485852227512a6b458a26f4a8 Mon Sep 17 00:00:00 2001 From: Justin King Date: Fri, 19 Jun 2026 11:29:23 -0700 Subject: [PATCH 84/87] Actually reject non-simple variable names in comprehensions PiperOrigin-RevId: 934996645 --- extensions/comprehensions_v2_macros.cc | 43 +++++++++++++++----------- parser/macro.cc | 22 ++++++++----- parser/parser_test.cc | 21 +++++++++++++ 3 files changed, 60 insertions(+), 26 deletions(-) diff --git a/extensions/comprehensions_v2_macros.cc b/extensions/comprehensions_v2_macros.cc index 134fb80ff..a054626f9 100644 --- a/extensions/comprehensions_v2_macros.cc +++ b/extensions/comprehensions_v2_macros.cc @@ -14,12 +14,14 @@ #include "extensions/comprehensions_v2_macros.h" +#include #include #include #include "absl/base/no_destructor.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/types/optional.h" #include "absl/types/span.h" @@ -38,16 +40,21 @@ namespace { using ::google::api::expr::common::CelOperator; +bool IsSimpleIdentifier(const Expr& expr) { + return expr.has_ident_expr() && !expr.ident_expr().name().empty() && + !absl::StartsWith(expr.ident_expr().name(), "."); +} + absl::optional ExpandAllMacro2(MacroExprFactory& factory, Expr& target, absl::Span args) { if (args.size() != 3) { return factory.ReportError("all() requires 3 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "all() first variable name must be a simple identifier"); } - if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[1])) { return factory.ReportErrorAt( args[1], "all() second variable name must be a simple identifier"); } @@ -89,11 +96,11 @@ absl::optional ExpandExistsMacro2(MacroExprFactory& factory, Expr& target, if (args.size() != 3) { return factory.ReportError("exists() requires 3 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "exists() first variable name must be a simple identifier"); } - if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[1])) { return factory.ReportErrorAt( args[1], "exists() second variable name must be a simple identifier"); } @@ -138,11 +145,11 @@ absl::optional ExpandExistsOneMacro2(MacroExprFactory& factory, if (args.size() != 3) { return factory.ReportError("existsOne() requires 3 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "existsOne() first variable name must be a simple identifier"); } - if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[1])) { return factory.ReportErrorAt( args[1], "existsOne() second variable name must be a simple identifier"); @@ -190,12 +197,12 @@ absl::optional ExpandTransformList3Macro(MacroExprFactory& factory, if (args.size() != 3) { return factory.ReportError("transformList() requires 3 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "transformList() first variable name must be a simple identifier"); } - if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[1])) { return factory.ReportErrorAt( args[1], "transformList() second variable name must be a simple identifier"); @@ -239,12 +246,12 @@ absl::optional ExpandTransformList4Macro(MacroExprFactory& factory, if (args.size() != 4) { return factory.ReportError("transformList() requires 4 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "transformList() first variable name must be a simple identifier"); } - if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[1])) { return factory.ReportErrorAt( args[1], "transformList() second variable name must be a simple identifier"); @@ -290,12 +297,12 @@ absl::optional ExpandTransformMap3Macro(MacroExprFactory& factory, if (args.size() != 3) { return factory.ReportError("transformMap() requires 3 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "transformMap() first variable name must be a simple identifier"); } - if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[1])) { return factory.ReportErrorAt( args[1], "transformMap() second variable name must be a simple identifier"); @@ -338,12 +345,12 @@ absl::optional ExpandTransformMap4Macro(MacroExprFactory& factory, if (args.size() != 4) { return factory.ReportError("transformMap() requires 4 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "transformMap() first variable name must be a simple identifier"); } - if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[1])) { return factory.ReportErrorAt( args[1], "transformMap() second variable name must be a simple identifier"); @@ -388,12 +395,12 @@ absl::optional ExpandTransformMapEntry3Macro(MacroExprFactory& factory, if (args.size() != 3) { return factory.ReportError("transformMapEntry() requires 3 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "transformMapEntry() first variable name must be a simple identifier"); } - if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[1])) { return factory.ReportErrorAt( args[1], "transformMapEntry() second variable name must be a simple identifier"); @@ -438,12 +445,12 @@ absl::optional ExpandTransformMapEntry4Macro(MacroExprFactory& factory, if (args.size() != 4) { return factory.ReportError("transformMapEntry() requires 4 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "transformMapEntry() first variable name must be a simple identifier"); } - if (!args[1].has_ident_expr() || args[1].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[1])) { return factory.ReportErrorAt( args[1], "transformMapEntry() second variable name must be a simple identifier"); diff --git a/parser/macro.cc b/parser/macro.cc index 8f8c9e596..815b07401 100644 --- a/parser/macro.cc +++ b/parser/macro.cc @@ -25,6 +25,7 @@ #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" @@ -40,6 +41,11 @@ namespace { using google::api::expr::common::CelOperator; +bool IsSimpleIdentifier(const Expr& expr) { + return expr.has_ident_expr() && !expr.ident_expr().name().empty() && + !absl::StartsWith(expr.ident_expr().name(), "."); +} + inline MacroExpander ToMacroExpander(GlobalMacroExpander expander) { ABSL_DCHECK(expander); return [expander = std::move(expander)]( @@ -87,7 +93,7 @@ absl::optional ExpandAllMacro(MacroExprFactory& factory, Expr& target, if (args.size() != 2) { return factory.ReportError("all() requires 2 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "all() variable name must be a simple identifier"); } @@ -119,7 +125,7 @@ absl::optional ExpandExistsMacro(MacroExprFactory& factory, Expr& target, if (args.size() != 2) { return factory.ReportError("exists() requires 2 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "exists() variable name must be a simple identifier"); } @@ -153,7 +159,7 @@ absl::optional ExpandExistsOneMacro(MacroExprFactory& factory, if (args.size() != 2) { return factory.ReportError("exists_one() requires 2 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "exists_one() variable name must be a simple identifier"); } @@ -192,7 +198,7 @@ absl::optional ExpandMap2Macro(MacroExprFactory& factory, Expr& target, if (args.size() != 2) { return factory.ReportError("map() requires 2 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "map() variable name must be a simple identifier"); } @@ -225,7 +231,7 @@ absl::optional ExpandMap3Macro(MacroExprFactory& factory, Expr& target, if (args.size() != 3) { return factory.ReportError("map() requires 3 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "map() variable name must be a simple identifier"); } @@ -260,7 +266,7 @@ absl::optional ExpandFilterMacro(MacroExprFactory& factory, Expr& target, if (args.size() != 2) { return factory.ReportError("filter() requires 2 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "filter() variable name must be a simple identifier"); } @@ -298,7 +304,7 @@ absl::optional ExpandOptMapMacro(MacroExprFactory& factory, Expr& target, if (args.size() != 2) { return factory.ReportError("optMap() requires 2 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "optMap() variable name must be a simple identifier"); } @@ -337,7 +343,7 @@ absl::optional ExpandOptFlatMapMacro(MacroExprFactory& factory, if (args.size() != 2) { return factory.ReportError("optFlatMap() requires 2 arguments"); } - if (!args[0].has_ident_expr() || args[0].ident_expr().name().empty()) { + if (!IsSimpleIdentifier(args[0])) { return factory.ReportErrorAt( args[0], "optFlatMap() variable name must be a simple identifier"); } diff --git a/parser/parser_test.cc b/parser/parser_test.cc index 1add80f84..35f11b413 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -631,6 +631,27 @@ std::vector test_cases = { "ERROR: :1:7: all() variable name must be a simple identifier\n" " | 1.all(2, 3)\n" " | ......^"}, + {"[].all(.x, x)", "", + "ERROR: :1:9: all() variable name must be a simple identifier\n" + " | [].all(.x, x)\n" + " | ........^"}, + {"[].exists(.x, x)", "", + "ERROR: :1:12: exists() variable name must be a simple identifier\n" + " | [].exists(.x, x)\n" + " | ...........^"}, + {"[].exists_one(.x, x)", "", + "ERROR: :1:16: exists_one() variable name must be a simple " + "identifier\n" + " | [].exists_one(.x, x)\n" + " | ...............^"}, + {"[].map(.x, x, x)", "", + "ERROR: :1:9: map() variable name must be a simple identifier\n" + " | [].map(.x, x, x)\n" + " | ........^"}, + {"[].filter(.x, x)", "", + "ERROR: :1:12: filter() variable name must be a simple identifier\n" + " | [].filter(.x, x)\n" + " | ...........^"}, {"x[\"a\"].single_int32 == 23", "_==_(\n" " _[_](\n" From 16074dcc55fddb53294aacd34afddd3db4a3138a Mon Sep 17 00:00:00 2001 From: Clayton Knittel Date: Sun, 21 Jun 2026 00:27:48 -0700 Subject: [PATCH 85/87] No public description PiperOrigin-RevId: 935529310 --- eval/public/message_wrapper_test.cc | 3 +-- .../structs/proto_message_type_adapter_test.cc | 4 +++- eval/public/testing/matchers.cc | 4 +--- internal/proto_matchers.h | 13 ++++++------- 4 files changed, 11 insertions(+), 13 deletions(-) diff --git a/eval/public/message_wrapper_test.cc b/eval/public/message_wrapper_test.cc index ff0e691ab..15e5e88da 100644 --- a/eval/public/message_wrapper_test.cc +++ b/eval/public/message_wrapper_test.cc @@ -18,7 +18,6 @@ #include "eval/public/structs/trivial_legacy_type_info.h" #include "eval/testutil/test_message.pb.h" -#include "internal/casts.h" #include "internal/testing.h" #include "google/protobuf/message.h" #include "google/protobuf/message_lite.h" @@ -60,7 +59,7 @@ TEST(MessageWrapperBuilder, Builder) { static_cast(&test_message)); auto mutable_message = - cel::internal::down_cast(builder.message_ptr()); + google::protobuf::DownCastMessage(builder.message_ptr()); mutable_message->set_int64_value(20); mutable_message->set_double_value(12.3); diff --git a/eval/public/structs/proto_message_type_adapter_test.cc b/eval/public/structs/proto_message_type_adapter_test.cc index 32608bc3f..270fd3ce1 100644 --- a/eval/public/structs/proto_message_type_adapter_test.cc +++ b/eval/public/structs/proto_message_type_adapter_test.cc @@ -36,6 +36,7 @@ #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" +#include "google/protobuf/message_lite.h" namespace google::api::expr::runtime { namespace { @@ -725,7 +726,8 @@ TEST(ProtoMesssageTypeAdapter, TypeInfoMutator) { ASSERT_OK_AND_ASSIGN(MessageWrapper::Builder builder, api->NewInstance(manager)); - EXPECT_NE(dynamic_cast(builder.message_ptr()), nullptr); + EXPECT_NE(google::protobuf::DynamicCastMessage(builder.message_ptr()), + nullptr); } TEST(ProtoMesssageTypeAdapter, TypeInfoAccesor) { diff --git a/eval/public/testing/matchers.cc b/eval/public/testing/matchers.cc index f79071fce..4f728c730 100644 --- a/eval/public/testing/matchers.cc +++ b/eval/public/testing/matchers.cc @@ -7,7 +7,6 @@ #include "absl/strings/string_view.h" #include "eval/public/cel_value.h" #include "eval/public/set_util.h" -#include "internal/casts.h" #include "internal/testing.h" #include "google/protobuf/message.h" @@ -76,8 +75,7 @@ class CelValueMatcherImpl CelValue::MessageWrapper arg; return v.GetValue(&arg) && arg.HasFullProto() && underlying_type_matcher_.Matches( - cel::internal::down_cast( - arg.message_ptr())); + google::protobuf::DownCastMessage(arg.message_ptr())); } void DescribeTo(std::ostream* os) const override { diff --git a/internal/proto_matchers.h b/internal/proto_matchers.h index 76d844036..02250634b 100644 --- a/internal/proto_matchers.h +++ b/internal/proto_matchers.h @@ -21,7 +21,6 @@ #include "absl/log/absl_check.h" #include "absl/memory/memory.h" -#include "internal/casts.h" #include "internal/testing.h" #include "google/protobuf/message.h" #include "google/protobuf/message_lite.h" @@ -43,13 +42,13 @@ class TextProtoMatcher { bool MatchAndExplain(const google::protobuf::MessageLite& p, ::testing::MatchResultListener* listener) const { - return MatchAndExplain(cel::internal::down_cast(p), + return MatchAndExplain(google::protobuf::DownCastMessage(p), listener); } bool MatchAndExplain(const google::protobuf::MessageLite* p, ::testing::MatchResultListener* listener) const { - return MatchAndExplain(cel::internal::down_cast(p), + return MatchAndExplain(google::protobuf::DownCastMessage(p), listener); } @@ -58,7 +57,7 @@ class TextProtoMatcher { auto message = absl::WrapUnique(p.New()); ABSL_CHECK(google::protobuf::TextFormat::ParseFromString(expected_, message.get())); return google::protobuf::util::MessageDifferencer::Equals( - *message, cel::internal::down_cast(p)); + *message, google::protobuf::DownCastMessage(p)); } bool MatchAndExplain(const google::protobuf::Message* p, @@ -66,7 +65,7 @@ class TextProtoMatcher { auto message = absl::WrapUnique(p->New()); ABSL_CHECK(google::protobuf::TextFormat::ParseFromString(expected_, message.get())); return google::protobuf::util::MessageDifferencer::Equals( - *message, cel::internal::down_cast(*p)); + *message, google::protobuf::DownCastMessage(*p)); } inline void DescribeTo(::std::ostream* os) const { *os << expected_; } @@ -93,13 +92,13 @@ class ProtoMatcher { bool MatchAndExplain(const google::protobuf::MessageLite& p, ::testing::MatchResultListener* listener) const { - return MatchAndExplain(cel::internal::down_cast(p), + return MatchAndExplain(google::protobuf::DownCastMessage(p), listener); } bool MatchAndExplain(const google::protobuf::MessageLite* p, ::testing::MatchResultListener* listener) const { - return MatchAndExplain(cel::internal::down_cast(p), + return MatchAndExplain(google::protobuf::DownCastMessage(p), listener); } From a32e1418a8385c4b9c1ebe3b5072aaa8fbdf23ce Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Mon, 22 Jun 2026 11:42:04 -0700 Subject: [PATCH 86/87] When generating YAML from config, skip overload_id if identical to signature PiperOrigin-RevId: 936168430 --- env/env_yaml.cc | 21 +++++++++++++------ env/env_yaml_test.cc | 50 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 6 deletions(-) diff --git a/env/env_yaml.cc b/env/env_yaml.cc index e7b8a7885..d5e3ad059 100644 --- a/env/env_yaml.cc +++ b/env/env_yaml.cc @@ -725,6 +725,9 @@ absl::StatusOr ParseFunctionOverloadConfig( function_name, "\"")); } overload_config.is_member_function = parsed_signature.is_member; + if (overload_config.overload_id.empty()) { + overload_config.overload_id = signature; + } if (!parsed_signature.signature_type.has_function()) { return absl::InternalError(absl::StrCat( "Function overload signature has no function type: ", signature)); @@ -1101,11 +1104,8 @@ void EmitFunctionOverloadConfig( const Config::FunctionOverloadConfig& overload_config, YAML::Emitter& out, const EnvConfigToYamlOptions& options) { out << YAML::BeginMap; - if (!overload_config.overload_id.empty()) { - out << YAML::Key << "id"; - out << YAML::Value << YAML::DoubleQuoted << overload_config.overload_id; - } bool signature_generated = false; + std::string signature_str; if (options.use_type_signatures) { bool param_type_spec_generated = true; std::vector params; @@ -1123,12 +1123,21 @@ void EmitFunctionOverloadConfig( common_internal::MakeOverloadSignature( function_name, params, overload_config.is_member_function); if (signature.ok()) { - out << YAML::Key << "signature"; - out << YAML::Value << YAML::DoubleQuoted << *signature; + signature_str = std::move(*signature); signature_generated = true; } } } + if (!overload_config.overload_id.empty()) { + if (!signature_generated || overload_config.overload_id != signature_str) { + out << YAML::Key << "id"; + out << YAML::Value << YAML::DoubleQuoted << overload_config.overload_id; + } + } + if (signature_generated) { + out << YAML::Key << "signature"; + out << YAML::Value << YAML::DoubleQuoted << signature_str; + } if (!signature_generated) { if (overload_config.is_member_function) { out << YAML::Key << "target" << YAML::Value; diff --git a/env/env_yaml_test.cc b/env/env_yaml_test.cc index 38f08e371..c5bd1b787 100644 --- a/env/env_yaml_test.cc +++ b/env/env_yaml_test.cc @@ -545,12 +545,15 @@ std::vector GetParseFunctionTestCases() { .overload_configs = { Config::FunctionOverloadConfig{ + .overload_id = + "google.protobuf.StringValue.isEmpty()", .examples = {"''.isEmpty() // true"}, .is_member_function = true, .parameters = {{.name = "string_wrapper"}}, .return_type = {.name = "bool"}, }, Config::FunctionOverloadConfig{ + .overload_id = "list<~T>.isEmpty()", .examples = {"[].isEmpty() // true", "[1].isEmpty() // false"}, .is_member_function = true, @@ -635,6 +638,7 @@ std::vector GetParseFunctionTestCases() { .overload_configs = { Config::FunctionOverloadConfig{ + .overload_id = "contains(list<~T>, ~T)", .examples = {"contains([1, 2, 3], 2) // true"}, .is_member_function = false, .parameters = @@ -1740,6 +1744,45 @@ std::vector GetExportTestCases() { - type_name: "int" )yaml", }, + ExportTestCase{ + .config = []() -> absl::StatusOr { + Config config; + CEL_RETURN_IF_ERROR(config.AddFunctionConfig( + {.name = "foo", + .overload_configs = { + {.overload_id = "timestamp.foo(A<~B>)", + .is_member_function = true, + .parameters = {{.name = "timestamp"}, + {.name = "A", + .params = {{.name = "B", + .is_type_param = true}}}}, + .return_type = {.name = "int"}}, + }})); + return config; + }(), + .expected_yaml = R"yaml( + functions: + - name: "foo" + overloads: + - signature: "timestamp.foo(A<~B>)" + return: "int" + )yaml", + .expected_alt_yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "timestamp.foo(A<~B>)" + target: + type_name: "timestamp" + args: + - type_name: "A" + params: + - type_name: "B" + is_type_param: true + return: + type_name: "int" + )yaml", + }, }; }; @@ -1888,6 +1931,13 @@ std::vector GetSignatureRoundTripTestCases() { signature: "foo(timestamp,A<~B>)" return: "list" )yaml", + R"yaml( + functions: + - name: "foo" + overloads: + - signature: "timestamp.foo(A<~B>)" + return: "int" + )yaml", }; } From 76ae0b3c1768d93a10270f904101de338867bdb1 Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Mon, 22 Jun 2026 12:33:00 -0700 Subject: [PATCH 87/87] Move signature.h and signature.cc from common/internal to common This is done to make it visible to cel/python PiperOrigin-RevId: 936194518 --- checker/BUILD | 2 +- checker/type_checker_subset_factory.cc | 10 +++---- common/BUILD | 38 ++++++++++++++++++++++++- common/decl.cc | 5 ++-- common/internal/BUILD | 36 ----------------------- common/{internal => }/signature.cc | 6 ++-- common/{internal => }/signature.h | 10 +++---- common/{internal => }/signature_test.cc | 33 ++++++++++----------- env/BUILD | 4 +-- env/env.cc | 8 ++---- env/env_yaml.cc | 25 +++++++--------- 11 files changed, 83 insertions(+), 94 deletions(-) rename common/{internal => }/signature.cc (99%) rename common/{internal => }/signature.h (93%) rename common/{internal => }/signature_test.cc (96%) diff --git a/checker/BUILD b/checker/BUILD index efca3ff73..7f3ccfef7 100644 --- a/checker/BUILD +++ b/checker/BUILD @@ -231,7 +231,7 @@ cc_library( deps = [ ":type_checker_builder", "//common:decl", - "//common/internal:signature", + "//common:signature", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", diff --git a/checker/type_checker_subset_factory.cc b/checker/type_checker_subset_factory.cc index e5335e220..1b146c5a5 100644 --- a/checker/type_checker_subset_factory.cc +++ b/checker/type_checker_subset_factory.cc @@ -22,7 +22,7 @@ #include "absl/types/span.h" #include "checker/type_checker_builder.h" #include "common/decl.h" -#include "common/internal/signature.h" +#include "common/signature.h" namespace cel { @@ -33,8 +33,8 @@ TypeCheckerSubset::FunctionPredicate IncludeOverloadsByIdPredicate( if (overload_ids.contains(overload.id())) { return true; } - auto signature = common_internal::MakeOverloadSignature( - function, overload.args(), overload.member()); + auto signature = + MakeOverloadSignature(function, overload.args(), overload.member()); return signature.ok() && overload_ids.contains(*signature); }; } @@ -52,8 +52,8 @@ TypeCheckerSubset::FunctionPredicate ExcludeOverloadsByIdPredicate( if (overload_ids.contains(overload.id())) { return false; } - auto signature = common_internal::MakeOverloadSignature( - function, overload.args(), overload.member()); + auto signature = + MakeOverloadSignature(function, overload.args(), overload.member()); return !signature.ok() || !overload_ids.contains(*signature); }; } diff --git a/common/BUILD b/common/BUILD index 0bd3632dd..0426c0827 100644 --- a/common/BUILD +++ b/common/BUILD @@ -79,6 +79,42 @@ cc_test( ], ) +cc_library( + name = "signature", + srcs = ["signature.cc"], + hdrs = ["signature.h"], + deps = [ + ":ast", + ":type", + ":type_spec_resolver", + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "signature_test", + srcs = ["signature_test.cc"], + deps = [ + ":ast", + ":signature", + ":type", + ":type_kind", + ":type_spec_resolver", + "//internal:testing", + "//internal:testing_descriptor_pool", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + ], +) + cc_library( name = "expr", srcs = ["expr.cc"], @@ -145,9 +181,9 @@ cc_library( hdrs = ["decl.h"], deps = [ ":constant", + ":signature", ":type", ":type_kind", - "//common/internal:signature", "//internal:status_macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", diff --git a/common/decl.cc b/common/decl.cc index d2d50964a..858e6fb49 100644 --- a/common/decl.cc +++ b/common/decl.cc @@ -26,7 +26,7 @@ #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" -#include "common/internal/signature.h" +#include "common/signature.h" #include "common/type.h" #include "common/type_kind.h" @@ -117,8 +117,7 @@ void AddOverloadInternal(std::string_view function_name, } absl::StatusOr signature = - common_internal::MakeOverloadSignature(function_name, overload.args(), - overload.member()); + MakeOverloadSignature(function_name, overload.args(), overload.member()); if (!signature.ok()) { status = signature.status(); return; diff --git a/common/internal/BUILD b/common/internal/BUILD index b07faf229..3be350754 100644 --- a/common/internal/BUILD +++ b/common/internal/BUILD @@ -135,39 +135,3 @@ cc_library( "@com_google_protobuf//src/google/protobuf/io", ], ) - -cc_library( - name = "signature", - srcs = ["signature.cc"], - hdrs = ["signature.h"], - deps = [ - "//common:ast", - "//common:type", - "//common:type_spec_resolver", - "//internal:status_macros", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", - "@com_google_protobuf//:protobuf", - ], -) - -cc_test( - name = "signature_test", - srcs = ["signature_test.cc"], - deps = [ - ":signature", - "//common:ast", - "//common:type", - "//common:type_kind", - "//common:type_spec_resolver", - "//internal:testing", - "//internal:testing_descriptor_pool", - "@com_google_absl//absl/base:no_destructor", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:status_matchers", - "@com_google_absl//absl/status:statusor", - "@com_google_protobuf//:protobuf", - ], -) diff --git a/common/internal/signature.cc b/common/signature.cc similarity index 99% rename from common/internal/signature.cc rename to common/signature.cc index fe315bb04..e497e780d 100644 --- a/common/internal/signature.cc +++ b/common/signature.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "common/internal/signature.h" +#include "common/signature.h" #include #include @@ -34,7 +34,7 @@ #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" -namespace cel::common_internal { +namespace cel { // Signature generator helper functions. namespace { @@ -637,4 +637,4 @@ absl::StatusOr ParseType(std::string_view signature, google::protobuf::Are return cel::ConvertTypeSpecToType(type_spec, arena, pool); } -} // namespace cel::common_internal +} // namespace cel diff --git a/common/internal/signature.h b/common/signature.h similarity index 93% rename from common/internal/signature.h rename to common/signature.h index 8a44fbd5c..777f03439 100644 --- a/common/internal/signature.h +++ b/common/signature.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_SIGNATURE_H_ -#define THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_SIGNATURE_H_ +#ifndef THIRD_PARTY_CEL_CPP_COMMON_SIGNATURE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_SIGNATURE_H_ #include #include @@ -25,7 +25,7 @@ #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" -namespace cel::common_internal { +namespace cel { // Generates a signature for a `cel::Type`, which is a string representation of // the type. @@ -96,6 +96,6 @@ struct ParsedFunctionOverload { absl::StatusOr ParseFunctionSignature( std::string_view signature); -} // namespace cel::common_internal +} // namespace cel -#endif // THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_SIGNATURE_H_ +#endif // THIRD_PARTY_CEL_CPP_COMMON_SIGNATURE_H_ diff --git a/common/internal/signature_test.cc b/common/signature_test.cc similarity index 96% rename from common/internal/signature_test.cc rename to common/signature_test.cc index 17b628d88..ea51eb566 100644 --- a/common/internal/signature_test.cc +++ b/common/signature_test.cc @@ -1,4 +1,4 @@ -#include "common/internal/signature.h" +#include "common/signature.h" // Copyright 2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -30,7 +30,7 @@ #include "internal/testing_descriptor_pool.h" #include "google/protobuf/arena.h" -namespace cel::common_internal { +namespace cel { namespace { using ::absl_testing::IsOkAndHolds; @@ -77,8 +77,7 @@ using TypeSignatureTest = testing::TestWithParam; TEST_P(TypeSignatureTest, TypeSignature) { const auto& param = GetParam(); - absl::StatusOr signature = - common_internal::MakeTypeSpecSignature(param.type); + absl::StatusOr signature = MakeTypeSpecSignature(param.type); if (!param.expected_error.empty()) { EXPECT_THAT(signature, StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr(param.expected_error))); @@ -257,25 +256,24 @@ INSTANTIATE_TEST_SUITE_P(TypeSignatureTest, TypeSignatureTest, ValuesIn(GetTypeSignatureTestCases())); TEST(TypeSignatureTest, UnsupportedTypes) { - EXPECT_THAT(common_internal::MakeTypeSignature(UnknownType{}), + EXPECT_THAT(MakeTypeSignature(UnknownType{}), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Unsupported Type kind: *unknown*"))); - EXPECT_THAT(common_internal::MakeTypeSignature(ErrorType{}), + EXPECT_THAT(MakeTypeSignature(ErrorType{}), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Unsupported type in signature: *error*"))); - EXPECT_THAT(common_internal::MakeTypeSpecSignature( - TypeSpec(static_cast(999))), + EXPECT_THAT(MakeTypeSpecSignature(TypeSpec(static_cast(999))), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Unsupported primitive type"))); - EXPECT_THAT(common_internal::MakeTypeSpecSignature( - TypeSpec(static_cast(999))), - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Unsupported well-known type"))); + EXPECT_THAT( + MakeTypeSpecSignature(TypeSpec(static_cast(999))), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Unsupported well-known type"))); - EXPECT_THAT(common_internal::MakeTypeSpecSignature(TypeSpec( + EXPECT_THAT(MakeTypeSpecSignature(TypeSpec( PrimitiveTypeWrapper(static_cast(999)))), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Unsupported wrapper type"))); @@ -308,8 +306,7 @@ TEST_P(OverloadSignatureTest, OverloadSignature) { const auto& param = GetParam(); absl::StatusOr signature = - common_internal::MakeOverloadSignature(param.function_name, param.args, - param.is_member); + MakeOverloadSignature(param.function_name, param.args, param.is_member); if (!param.expected_error.empty()) { EXPECT_THAT(signature, StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr(param.expected_error))); @@ -433,8 +430,8 @@ std::vector GetOverloadSignatureTestCases() { } TEST(OverloadSignatureTest, MemberFunctionNoReceiverError) { - auto signature = common_internal::MakeOverloadSignature( - "hello", std::vector{}, true); + auto signature = + MakeOverloadSignature("hello", std::vector{}, true); EXPECT_THAT(signature, StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Member function with no receiver"))); @@ -784,4 +781,4 @@ TEST(OverloadSignatureTest, ArgumentTypeVector) { } } // namespace -} // namespace cel::common_internal +} // namespace cel diff --git a/env/BUILD b/env/BUILD index 1816238a5..0c17d6305 100644 --- a/env/BUILD +++ b/env/BUILD @@ -55,8 +55,8 @@ cc_library( "//common:constant", "//common:container", "//common:decl", + "//common:signature", "//common:type", - "//common/internal:signature", "//compiler", "//compiler:compiler_factory", "//compiler:standard_library", @@ -124,7 +124,7 @@ cc_library( ":config", "//common:ast", "//common:constant", - "//common/internal:signature", + "//common:signature", "//internal:status_macros", "//internal:strings", "@com_google_absl//absl/algorithm:container", diff --git a/env/env.cc b/env/env.cc index 4fa4e7398..85c5139da 100644 --- a/env/env.cc +++ b/env/env.cc @@ -26,7 +26,7 @@ #include "common/constant.h" #include "common/container.h" #include "common/decl.h" -#include "common/internal/signature.h" +#include "common/signature.h" #include "common/type.h" #include "compiler/compiler.h" #include "compiler/compiler_factory.h" @@ -71,8 +71,7 @@ bool ShouldIncludeFunction(const Config::StandardLibraryConfig& config, return false; } absl::StatusOr signature = - common_internal::MakeOverloadSignature(function, overload.args(), - overload.member()); + MakeOverloadSignature(function, overload.args(), overload.member()); if (signature.ok() && config.excluded_functions.contains(std::make_pair( std::string(function), *std::move(signature)))) { return false; @@ -89,8 +88,7 @@ bool ShouldIncludeFunction(const Config::StandardLibraryConfig& config, // Ok to call MakeOverloadSignature() again, because in practice either // included or excluded functions may be specified, but not both. absl::StatusOr signature = - common_internal::MakeOverloadSignature(function, overload.args(), - overload.member()); + MakeOverloadSignature(function, overload.args(), overload.member()); if (signature.ok() && config.included_functions.contains(std::make_pair( std::string(function), *std::move(signature)))) { return true; diff --git a/env/env_yaml.cc b/env/env_yaml.cc index d5e3ad059..281cf3ff1 100644 --- a/env/env_yaml.cc +++ b/env/env_yaml.cc @@ -38,7 +38,7 @@ #include "absl/time/time.h" #include "common/ast.h" #include "common/constant.h" -#include "common/internal/signature.h" +#include "common/signature.h" #include "env/config.h" #include "env/type_info.h" #include "internal/status_macros.h" @@ -434,8 +434,7 @@ absl::StatusOr ParseTypeInfo(const YAML::Node& node, if (!type.IsScalar()) { return YamlError(yaml, type, "Node 'type' is not a string"); } - CEL_ASSIGN_OR_RETURN(auto type_spec, - common_internal::ParseTypeSpec(GetString(yaml, type))); + CEL_ASSIGN_OR_RETURN(auto type_spec, ParseTypeSpec(GetString(yaml, type))); CEL_ASSIGN_OR_RETURN(auto type_config, TypeSpecToTypeInfo(type_spec)); return type_config; } @@ -714,9 +713,8 @@ absl::StatusOr ParseFunctionOverloadConfig( } std::string signature = GetString(yaml, signature_node); - CEL_ASSIGN_OR_RETURN( - common_internal::ParsedFunctionOverload parsed_signature, - common_internal::ParseFunctionSignature(signature)); + CEL_ASSIGN_OR_RETURN(ParsedFunctionOverload parsed_signature, + ParseFunctionSignature(signature)); if (parsed_signature.function_name != function_name) { return YamlError(yaml, signature_node, absl::StrCat("Function overload name \"", @@ -767,8 +765,8 @@ absl::StatusOr ParseFunctionOverloadConfig( const YAML::Node return_type = overload["return"]; if (return_type.IsDefined()) { if (return_type.IsScalar()) { - CEL_ASSIGN_OR_RETURN(auto type_spec, common_internal::ParseTypeSpec( - GetString(yaml, return_type))); + CEL_ASSIGN_OR_RETURN(auto type_spec, + ParseTypeSpec(GetString(yaml, return_type))); CEL_ASSIGN_OR_RETURN(overload_config.return_type, TypeSpecToTypeInfo(type_spec)); } else if (return_type.IsMap()) { @@ -993,8 +991,7 @@ void EmitTypeInfo(const Config::TypeInfo& type_info, YAML::Emitter& out, if (options.use_type_signatures) { absl::StatusOr type_spec = TypeInfoToTypeSpec(type_info); if (type_spec.ok()) { - absl::StatusOr signature = - common_internal::MakeTypeSpecSignature(*type_spec); + absl::StatusOr signature = MakeTypeSpecSignature(*type_spec); if (signature.ok()) { out << YAML::Key << "type"; out << YAML::Value << YAML::DoubleQuoted << *signature; @@ -1119,9 +1116,8 @@ void EmitFunctionOverloadConfig( params.push_back(std::move(*type_spec)); } if (param_type_spec_generated) { - absl::StatusOr signature = - common_internal::MakeOverloadSignature( - function_name, params, overload_config.is_member_function); + absl::StatusOr signature = MakeOverloadSignature( + function_name, params, overload_config.is_member_function); if (signature.ok()) { signature_str = std::move(*signature); signature_generated = true; @@ -1177,8 +1173,7 @@ void EmitFunctionOverloadConfig( absl::StatusOr type_spec = TypeInfoToTypeSpec(overload_config.return_type); if (type_spec.ok()) { - absl::StatusOr signature = - common_internal::MakeTypeSpecSignature(*type_spec); + absl::StatusOr signature = MakeTypeSpecSignature(*type_spec); if (signature.ok()) { out << YAML::Key << "return"; out << YAML::Value << YAML::DoubleQuoted << *signature;