@@ -75,6 +75,7 @@ using ::cel::TypeManager;
7575using ::cel::Value;
7676using ::cel::ValueFactory;
7777using ::cel::ast_internal::AstImpl;
78+ using ::cel::ast_internal::AstTraverse;
7879
7980constexpr int64_t kExprIdNotFromAst = -1 ;
8081
@@ -169,7 +170,7 @@ class ExhaustiveTernaryCondVisitor : public CondVisitor {
169170};
170171
171172// Visitor Comprehension expression.
172- class ComprehensionVisitor : public CondVisitor {
173+ class ComprehensionVisitor {
173174 public:
174175 explicit ComprehensionVisitor (FlatExprVisitor* visitor, bool short_circuiting,
175176 bool enable_vulnerability_check)
@@ -179,9 +180,10 @@ class ComprehensionVisitor : public CondVisitor {
179180 short_circuiting_(short_circuiting),
180181 enable_vulnerability_check_(enable_vulnerability_check) {}
181182
182- void PreVisit (const cel::ast_internal::Expr* expr) override ;
183- void PostVisitArg (int arg_num, const cel::ast_internal::Expr* expr) override ;
184- void PostVisit (const cel::ast_internal::Expr* expr) override ;
183+ void PreVisit (const cel::ast_internal::Expr* expr);
184+ void PostVisitArg (cel::ast_internal::ComprehensionArg arg_num,
185+ const cel::ast_internal::Expr* comprehension_expr);
186+ void PostVisit (const cel::ast_internal::Expr* expr);
185187
186188 private:
187189 FlatExprVisitor* visitor_;
@@ -585,15 +587,13 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor {
585587 ValidateOrError (comprehension->has_result (),
586588 " Invalid comprehension: 'result' must be set" );
587589 comprehension_stack_.push (
588- {comprehension,
590+ {expr, comprehension,
589591 IsOptimizableListAppend (comprehension,
590- options_.enable_comprehension_list_append )});
591- cond_visitor_stack_.push (
592- {expr, std::make_unique<ComprehensionVisitor>(
593- this , options_.short_circuiting ,
594- enable_comprehension_vulnerability_check_)});
595- auto cond_visitor = FindCondVisitor (expr);
596- cond_visitor->PreVisit (expr);
592+ options_.enable_comprehension_list_append ),
593+ std::make_unique<ComprehensionVisitor>(
594+ this , options_.short_circuiting ,
595+ enable_comprehension_vulnerability_check_)});
596+ comprehension_stack_.top ().visitor ->PreVisit (expr);
597597 }
598598
599599 // Invoked after all child nodes are processed.
@@ -604,11 +604,32 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor {
604604 if (!progress_status_.ok ()) {
605605 return ;
606606 }
607+
608+ if (comprehension_stack_.empty () ||
609+ comprehension_stack_.top ().comprehension != comprehension_expr) {
610+ return ;
611+ }
612+
613+ comprehension_stack_.top ().visitor ->PostVisit (expr);
607614 comprehension_stack_.pop ();
615+ }
608616
609- auto cond_visitor = FindCondVisitor (expr);
610- cond_visitor->PostVisit (expr);
611- cond_visitor_stack_.pop ();
617+ void PostVisitComprehensionSubexpression (
618+ const cel::ast_internal::Expr* subexpr,
619+ const cel::ast_internal::Comprehension* compr,
620+ cel::ast_internal::ComprehensionArg comprehension_arg,
621+ const cel::ast_internal::SourcePosition*) override {
622+ if (!progress_status_.ok ()) {
623+ return ;
624+ }
625+
626+ if (comprehension_stack_.empty () ||
627+ comprehension_stack_.top ().comprehension != compr) {
628+ return ;
629+ }
630+
631+ comprehension_stack_.top ().visitor ->PostVisitArg (
632+ comprehension_arg, comprehension_stack_.top ().expr );
612633 }
613634
614635 // Invoked after each argument node processed.
@@ -739,8 +760,10 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor {
739760
740761 private:
741762 struct ComprehensionStackRecord {
763+ const cel::ast_internal::Expr* expr;
742764 const cel::ast_internal::Comprehension* comprehension;
743765 bool is_optimizable_list_append;
766+ std::unique_ptr<ComprehensionVisitor> visitor;
744767 };
745768
746769 const Resolver& resolver_;
@@ -1089,8 +1112,9 @@ void ComprehensionVisitor::PreVisit(const cel::ast_internal::Expr*) {
10891112 kExprIdNotFromAst , false ));
10901113}
10911114
1092- void ComprehensionVisitor::PostVisitArg (int arg_num,
1093- const cel::ast_internal::Expr* expr) {
1115+ void ComprehensionVisitor::PostVisitArg (
1116+ cel::ast_internal::ComprehensionArg arg_num,
1117+ const cel::ast_internal::Expr* expr) {
10941118 const auto * comprehension = &expr->comprehension_expr ();
10951119 const auto & accu_var = comprehension->accu_var ();
10961120 const auto & iter_var = comprehension->iter_var ();
@@ -1204,7 +1228,9 @@ absl::StatusOr<FlatExpression> FlatExprBuilder::CreateExpressionImpl(
12041228 ast_impl.reference_map (), execution_path, value_factory, warnings_builder,
12051229 program_tree, extension_context);
12061230
1207- AstTraverse (&ast_impl.root_expr (), &ast_impl.source_info (), &visitor);
1231+ cel::ast_internal::TraversalOptions opts;
1232+ opts.use_comprehension_callbacks = true ;
1233+ AstTraverse (&ast_impl.root_expr (), &ast_impl.source_info (), &visitor, opts);
12081234
12091235 if (!visitor.progress_status ().ok ()) {
12101236 return visitor.progress_status ();
0 commit comments