Skip to content

Commit fe53e95

Browse files
jnthntatumcopybara-github
authored andcommitted
Refactor comprehension planning in flat expr builder to use comprehension specific visitor callbacks.
No functional changes. PiperOrigin-RevId: 561377141
1 parent e6858ae commit fe53e95

1 file changed

Lines changed: 44 additions & 18 deletions

File tree

eval/compiler/flat_expr_builder.cc

Lines changed: 44 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ using ::cel::TypeManager;
7575
using ::cel::Value;
7676
using ::cel::ValueFactory;
7777
using ::cel::ast_internal::AstImpl;
78+
using ::cel::ast_internal::AstTraverse;
7879

7980
constexpr int64_t kExprIdNotFromAst = -1;
8081

@@ -169,7 +170,7 @@ class ExhaustiveTernaryCondVisitor : public CondVisitor {
169170
};
170171

171172
// Visitor Comprehension expression.
172-
class ComprehensionVisitor : public CondVisitor {
173+
class ComprehensionVisitor {
173174
public:
174175
explicit ComprehensionVisitor(FlatExprVisitor* visitor, bool short_circuiting,
175176
bool enable_vulnerability_check)
@@ -179,9 +180,10 @@ class ComprehensionVisitor : public CondVisitor {
179180
short_circuiting_(short_circuiting),
180181
enable_vulnerability_check_(enable_vulnerability_check) {}
181182

182-
void PreVisit(const cel::ast_internal::Expr* expr) override;
183-
void PostVisitArg(int arg_num, const cel::ast_internal::Expr* expr) override;
184-
void PostVisit(const cel::ast_internal::Expr* expr) override;
183+
void PreVisit(const cel::ast_internal::Expr* expr);
184+
void PostVisitArg(cel::ast_internal::ComprehensionArg arg_num,
185+
const cel::ast_internal::Expr* comprehension_expr);
186+
void PostVisit(const cel::ast_internal::Expr* expr);
185187

186188
private:
187189
FlatExprVisitor* visitor_;
@@ -585,15 +587,13 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor {
585587
ValidateOrError(comprehension->has_result(),
586588
"Invalid comprehension: 'result' must be set");
587589
comprehension_stack_.push(
588-
{comprehension,
590+
{expr, comprehension,
589591
IsOptimizableListAppend(comprehension,
590-
options_.enable_comprehension_list_append)});
591-
cond_visitor_stack_.push(
592-
{expr, std::make_unique<ComprehensionVisitor>(
593-
this, options_.short_circuiting,
594-
enable_comprehension_vulnerability_check_)});
595-
auto cond_visitor = FindCondVisitor(expr);
596-
cond_visitor->PreVisit(expr);
592+
options_.enable_comprehension_list_append),
593+
std::make_unique<ComprehensionVisitor>(
594+
this, options_.short_circuiting,
595+
enable_comprehension_vulnerability_check_)});
596+
comprehension_stack_.top().visitor->PreVisit(expr);
597597
}
598598

599599
// Invoked after all child nodes are processed.
@@ -604,11 +604,32 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor {
604604
if (!progress_status_.ok()) {
605605
return;
606606
}
607+
608+
if (comprehension_stack_.empty() ||
609+
comprehension_stack_.top().comprehension != comprehension_expr) {
610+
return;
611+
}
612+
613+
comprehension_stack_.top().visitor->PostVisit(expr);
607614
comprehension_stack_.pop();
615+
}
608616

609-
auto cond_visitor = FindCondVisitor(expr);
610-
cond_visitor->PostVisit(expr);
611-
cond_visitor_stack_.pop();
617+
void PostVisitComprehensionSubexpression(
618+
const cel::ast_internal::Expr* subexpr,
619+
const cel::ast_internal::Comprehension* compr,
620+
cel::ast_internal::ComprehensionArg comprehension_arg,
621+
const cel::ast_internal::SourcePosition*) override {
622+
if (!progress_status_.ok()) {
623+
return;
624+
}
625+
626+
if (comprehension_stack_.empty() ||
627+
comprehension_stack_.top().comprehension != compr) {
628+
return;
629+
}
630+
631+
comprehension_stack_.top().visitor->PostVisitArg(
632+
comprehension_arg, comprehension_stack_.top().expr);
612633
}
613634

614635
// Invoked after each argument node processed.
@@ -739,8 +760,10 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor {
739760

740761
private:
741762
struct ComprehensionStackRecord {
763+
const cel::ast_internal::Expr* expr;
742764
const cel::ast_internal::Comprehension* comprehension;
743765
bool is_optimizable_list_append;
766+
std::unique_ptr<ComprehensionVisitor> visitor;
744767
};
745768

746769
const Resolver& resolver_;
@@ -1089,8 +1112,9 @@ void ComprehensionVisitor::PreVisit(const cel::ast_internal::Expr*) {
10891112
kExprIdNotFromAst, false));
10901113
}
10911114

1092-
void ComprehensionVisitor::PostVisitArg(int arg_num,
1093-
const cel::ast_internal::Expr* expr) {
1115+
void ComprehensionVisitor::PostVisitArg(
1116+
cel::ast_internal::ComprehensionArg arg_num,
1117+
const cel::ast_internal::Expr* expr) {
10941118
const auto* comprehension = &expr->comprehension_expr();
10951119
const auto& accu_var = comprehension->accu_var();
10961120
const auto& iter_var = comprehension->iter_var();
@@ -1204,7 +1228,9 @@ absl::StatusOr<FlatExpression> FlatExprBuilder::CreateExpressionImpl(
12041228
ast_impl.reference_map(), execution_path, value_factory, warnings_builder,
12051229
program_tree, extension_context);
12061230

1207-
AstTraverse(&ast_impl.root_expr(), &ast_impl.source_info(), &visitor);
1231+
cel::ast_internal::TraversalOptions opts;
1232+
opts.use_comprehension_callbacks = true;
1233+
AstTraverse(&ast_impl.root_expr(), &ast_impl.source_info(), &visitor, opts);
12081234

12091235
if (!visitor.progress_status().ok()) {
12101236
return visitor.progress_status();

0 commit comments

Comments
 (0)