1616
1717import static com.google.common.base.Preconditions.checkNotNull;
1818import static com.google.common.collect.ImmutableList.toImmutableList;
19+ import static java.util.stream.Collectors.toCollection;
1920
2021import com.google.auto.value.AutoValue;
21- import com.google.common.collect.ImmutableList;
2222import com.google.common.collect.Lists;
2323import dev.cel.bundle.Cel;
2424import dev.cel.common.CelAbstractSyntaxTree;
2525import dev.cel.common.CelMutableAst;
26+ import dev.cel.common.CelValidationException;
2627import dev.cel.common.ast.CelConstant.Kind;
2728import dev.cel.extensions.CelOptionalLibrary.Function;
2829import dev.cel.optimizer.AstMutator;
2930import dev.cel.optimizer.CelAstOptimizer;
3031import dev.cel.parser.Operator;
3132import dev.cel.policy.CelCompiledRule.CelCompiledMatch;
33+ import dev.cel.policy.CelCompiledRule.CelCompiledMatch.OutputValue;
3234import dev.cel.policy.CelCompiledRule.CelCompiledVariable;
35+ import java.util.ArrayList;
36+ import java.util.Arrays;
37+ import java.util.List;
3338
3439/** Package-private class for composing various rules into a single expression using optimizer. */
3540final class RuleComposer implements CelAstOptimizer {
@@ -39,13 +44,8 @@ final class RuleComposer implements CelAstOptimizer {
3944
4045 @Override
4146 public OptimizationResult optimize(CelAbstractSyntaxTree ast, Cel cel) {
42- RuleOptimizationResult result = optimizeRule(compiledRule);
43- return OptimizationResult.create(
44- result.ast().toParsedAst(),
45- compiledRule.variables().stream()
46- .map(CelCompiledVariable::celVarDecl)
47- .collect(toImmutableList()),
48- ImmutableList.of());
47+ RuleOptimizationResult result = optimizeRule(cel, compiledRule);
48+ return OptimizationResult.create(result.ast().toParsedAst());
4949 }
5050
5151 @AutoValue
@@ -59,20 +59,33 @@ static RuleOptimizationResult create(CelMutableAst ast, boolean isOptionalResult
5959 }
6060 }
6161
62- private RuleOptimizationResult optimizeRule(CelCompiledRule compiledRule) {
62+ private RuleOptimizationResult optimizeRule(Cel cel, CelCompiledRule compiledRule) {
63+ cel =
64+ cel.toCelBuilder()
65+ .addVarDeclarations(
66+ compiledRule.variables().stream()
67+ .map(CelCompiledVariable::celVarDecl)
68+ .collect(toImmutableList()))
69+ .build();
70+
6371 CelMutableAst matchAst = astMutator.newGlobalCall(Function.OPTIONAL_NONE.getFunction());
6472 boolean isOptionalResult = true;
73+ // Keep track of the last output ID that might cause type-check failure while attempting to
74+ // compose the subgraphs.
75+ long lastOutputId = 0;
6576 for (CelCompiledMatch match : Lists.reverse(compiledRule.matches())) {
6677 CelAbstractSyntaxTree conditionAst = match.condition();
6778 boolean isTriviallyTrue =
6879 conditionAst.getExpr().constantOrDefault().getKind().equals(Kind.BOOLEAN_VALUE)
6980 && conditionAst.getExpr().constant().booleanValue();
7081 switch (match.result().kind()) {
7182 case OUTPUT:
72- CelMutableAst outAst = CelMutableAst.fromCelAst(match.result().output());
83+ OutputValue matchOutput = match.result().output();
84+ CelMutableAst outAst = CelMutableAst.fromCelAst(matchOutput.ast());
7385 if (isTriviallyTrue) {
7486 matchAst = outAst;
7587 isOptionalResult = false;
88+ lastOutputId = matchOutput.id();
7689 continue;
7790 }
7891 if (isOptionalResult) {
@@ -85,9 +98,13 @@ private RuleOptimizationResult optimizeRule(CelCompiledRule compiledRule) {
8598 CelMutableAst.fromCelAst(conditionAst),
8699 outAst,
87100 matchAst);
101+ assertComposedAstIsValid(
102+ cel, matchAst, "conflicting output types found.", matchOutput.id(), lastOutputId);
103+ lastOutputId = matchOutput.id();
88104 continue;
89105 case RULE:
90- RuleOptimizationResult nestedRule = optimizeRule(match.result().rule());
106+ CelCompiledRule matchNestedRule = match.result().rule();
107+ RuleOptimizationResult nestedRule = optimizeRule(cel, matchNestedRule);
91108 CelMutableAst nestedRuleAst = nestedRule.ast();
92109 if (isOptionalResult && !nestedRule.isOptionalResult()) {
93110 nestedRuleAst =
@@ -101,6 +118,13 @@ private RuleOptimizationResult optimizeRule(CelCompiledRule compiledRule) {
101118 throw new IllegalArgumentException("Subrule early terminates policy");
102119 }
103120 matchAst = astMutator.newMemberCall(nestedRuleAst, Function.OR.getFunction(), matchAst);
121+ assertComposedAstIsValid(
122+ cel,
123+ matchAst,
124+ String.format(
125+ "failed composing the subrule '%s' due to conflicting output types.",
126+ matchNestedRule.id().map(ValueString::value).orElse("")),
127+ lastOutputId);
104128 break;
105129 }
106130 }
@@ -127,9 +151,38 @@ static RuleComposer newInstance(
127151 return new RuleComposer(compiledRule, variablePrefix, iterationLimit);
128152 }
129153
154+ private void assertComposedAstIsValid(
155+ Cel cel, CelMutableAst composedAst, String failureMessage, Long... ids) {
156+ assertComposedAstIsValid(cel, composedAst, failureMessage, Arrays.asList(ids));
157+ }
158+
159+ private void assertComposedAstIsValid(
160+ Cel cel, CelMutableAst composedAst, String failureMessage, List<Long> ids) {
161+ try {
162+ cel.check(composedAst.toParsedAst()).getAst();
163+ } catch (CelValidationException e) {
164+ ids = ids.stream().filter(id -> id > 0).collect(toCollection(ArrayList::new));
165+ throw new RuleCompositionException(failureMessage, e, ids);
166+ }
167+ }
168+
130169 private RuleComposer(CelCompiledRule compiledRule, String variablePrefix, int iterationLimit) {
131170 this.compiledRule = checkNotNull(compiledRule);
132171 this.variablePrefix = variablePrefix;
133172 this.astMutator = AstMutator.newInstance(iterationLimit);
134173 }
174+
175+ static final class RuleCompositionException extends RuntimeException {
176+ final String failureReason;
177+ final List<Long> errorIds;
178+ final CelValidationException compileException;
179+
180+ private RuleCompositionException(
181+ String failureReason, CelValidationException e, List<Long> errorIds) {
182+ super(e);
183+ this.failureReason = failureReason;
184+ this.errorIds = errorIds;
185+ this.compileException = e;
186+ }
187+ }
135188}
0 commit comments