Skip to content

Commit 9e339f2

Browse files
authored
Simplify If Expressions (#183)
1 parent a2f5975 commit 9e339f2

File tree

4 files changed

+154
-0
lines changed

4 files changed

+154
-0
lines changed

liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionFolding.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
import liquidjava.rj_language.ast.BinaryExpression;
44
import liquidjava.rj_language.ast.Expression;
55
import liquidjava.rj_language.ast.GroupExpression;
6+
import liquidjava.rj_language.ast.Ite;
67
import liquidjava.rj_language.ast.LiteralBoolean;
78
import liquidjava.rj_language.ast.LiteralInt;
89
import liquidjava.rj_language.ast.LiteralReal;
910
import liquidjava.rj_language.ast.UnaryExpression;
1011
import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode;
1112
import liquidjava.rj_language.opt.derivation_node.DerivationNode;
13+
import liquidjava.rj_language.opt.derivation_node.IteDerivationNode;
1214
import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode;
1315
import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;
1416

@@ -26,6 +28,9 @@ public static ValDerivationNode fold(ValDerivationNode node) {
2628
if (exp instanceof UnaryExpression)
2729
return foldUnary(node);
2830

31+
if (exp instanceof Ite)
32+
return foldIte(node);
33+
2934
if (exp instanceof GroupExpression group) {
3035
if (group.getChildren().size() == 1) {
3136
return fold(new ValDerivationNode(group.getChildren().get(0), node.getOrigin()));
@@ -197,4 +202,45 @@ private static ValDerivationNode foldUnary(ValDerivationNode node) {
197202
DerivationNode origin = operandNode.getOrigin() != null ? new UnaryDerivationNode(operandNode, operator) : null;
198203
return new ValDerivationNode(unaryExp, origin);
199204
}
205+
206+
/**
207+
* Folds ternary expressions by checking if condition is a boolean literal or both branches are the same
208+
*/
209+
private static ValDerivationNode foldIte(ValDerivationNode node) {
210+
Ite iteExp = (Ite) node.getValue();
211+
212+
ValDerivationNode condNode = fold(new ValDerivationNode(iteExp.getCondition(), null));
213+
ValDerivationNode thenNode = fold(new ValDerivationNode(iteExp.getThen(), null));
214+
ValDerivationNode elseNode = fold(new ValDerivationNode(iteExp.getElse(), null));
215+
216+
Expression condition = condNode.getValue();
217+
Expression thenExp = thenNode.getValue();
218+
Expression elseExp = elseNode.getValue();
219+
220+
iteExp.setChild(0, condition);
221+
iteExp.setChild(1, thenExp);
222+
iteExp.setChild(2, elseExp);
223+
224+
// if condition is a boolean literal, select the corresponding branch: true ? a : b => a, false ? a : b => b
225+
if (condition instanceof LiteralBoolean boolCond) {
226+
Expression selected = boolCond.isBooleanTrue() ? thenExp : elseExp;
227+
DerivationNode origin = new IteDerivationNode(condNode, thenNode, elseNode);
228+
return new ValDerivationNode(selected, origin);
229+
}
230+
231+
// if both branches are the same, return one of them (e.g. cond ? b : b => b)
232+
if (thenExp.equals(elseExp)) {
233+
DerivationNode origin = new IteDerivationNode(condNode, thenNode, elseNode);
234+
return new ValDerivationNode(thenExp, origin);
235+
}
236+
237+
// no folding, but keep track of the folding steps in the origin
238+
DerivationNode origin = hasIteChildOrigin(condNode, thenNode, elseNode)
239+
? new IteDerivationNode(condNode, thenNode, elseNode) : node.getOrigin();
240+
return new ValDerivationNode(iteExp, origin);
241+
}
242+
243+
private static boolean hasIteChildOrigin(ValDerivationNode cond, ValDerivationNode then, ValDerivationNode els) {
244+
return cond.getOrigin() != null || then.getOrigin() != null || els.getOrigin() != null;
245+
}
200246
}

liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariablePropagation.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import liquidjava.rj_language.ast.Var;
77
import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode;
88
import liquidjava.rj_language.opt.derivation_node.DerivationNode;
9+
import liquidjava.rj_language.opt.derivation_node.IteDerivationNode;
910
import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode;
1011
import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;
1112
import liquidjava.rj_language.opt.derivation_node.VarDerivationNode;
@@ -134,6 +135,10 @@ private static void extractVarOrigins(ValDerivationNode node, Map<String, Deriva
134135
extractVarOrigins(binOrigin.getRight(), varOrigins);
135136
} else if (origin instanceof UnaryDerivationNode unaryOrigin) {
136137
extractVarOrigins(unaryOrigin.getOperand(), varOrigins);
138+
} else if (origin instanceof IteDerivationNode iteOrigin) {
139+
extractVarOrigins(iteOrigin.getCondition(), varOrigins);
140+
extractVarOrigins(iteOrigin.getThenBranch(), varOrigins);
141+
extractVarOrigins(iteOrigin.getElseBranch(), varOrigins);
137142
} else if (origin instanceof ValDerivationNode valOrigin) {
138143
extractVarOrigins(valOrigin, varOrigins);
139144
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package liquidjava.rj_language.opt.derivation_node;
2+
3+
public class IteDerivationNode extends DerivationNode {
4+
5+
private final ValDerivationNode condition;
6+
private final ValDerivationNode thenBranch;
7+
private final ValDerivationNode elseBranch;
8+
9+
public IteDerivationNode(ValDerivationNode condition, ValDerivationNode thenBranch, ValDerivationNode elseBranch) {
10+
this.condition = condition;
11+
this.thenBranch = thenBranch;
12+
this.elseBranch = elseBranch;
13+
}
14+
15+
public ValDerivationNode getCondition() {
16+
return condition;
17+
}
18+
19+
public ValDerivationNode getThenBranch() {
20+
return thenBranch;
21+
}
22+
23+
public ValDerivationNode getElseBranch() {
24+
return elseBranch;
25+
}
26+
}

liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44

55
import liquidjava.rj_language.ast.BinaryExpression;
66
import liquidjava.rj_language.ast.Expression;
7+
import liquidjava.rj_language.ast.Ite;
78
import liquidjava.rj_language.ast.LiteralBoolean;
89
import liquidjava.rj_language.ast.LiteralInt;
910
import liquidjava.rj_language.ast.UnaryExpression;
1011
import liquidjava.rj_language.ast.Var;
1112
import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode;
1213
import liquidjava.rj_language.opt.derivation_node.DerivationNode;
14+
import liquidjava.rj_language.opt.derivation_node.IteDerivationNode;
1315
import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode;
1416
import liquidjava.rj_language.opt.derivation_node.ValDerivationNode;
1517
import liquidjava.rj_language.opt.derivation_node.VarDerivationNode;
@@ -877,6 +879,76 @@ void testInternalToInternalNoFurtherResolution() {
877879
"#a_3 (lower counter) replaced by #b_7 (higher counter); equality collapses to trivial");
878880
}
879881

882+
@Test
883+
void testIteTrueConditionSimplifiesToThenBranch() {
884+
// Given: true ? a : b
885+
// Expected: a
886+
887+
Expression expr = new Ite(new LiteralBoolean(true), new Var("a"), new Var("b"));
888+
889+
// When
890+
ValDerivationNode result = ExpressionSimplifier.simplify(expr);
891+
892+
// Then
893+
assertNotNull(result, "Result should not be null");
894+
assertEquals("a", result.getValue().toString(), "Expected result to be a");
895+
896+
ValDerivationNode conditionNode = new ValDerivationNode(new LiteralBoolean(true), null);
897+
ValDerivationNode thenNode = new ValDerivationNode(new Var("a"), null);
898+
ValDerivationNode elseNode = new ValDerivationNode(new Var("b"), null);
899+
IteDerivationNode iteOrigin = new IteDerivationNode(conditionNode, thenNode, elseNode);
900+
ValDerivationNode expected = new ValDerivationNode(new Var("a"), iteOrigin);
901+
902+
assertDerivationEquals(expected, result, "");
903+
}
904+
905+
@Test
906+
void testIteFalseConditionSimplifiesToElseBranch() {
907+
// Given: false ? a : b
908+
// Expected: b
909+
910+
Expression expr = new Ite(new LiteralBoolean(false), new Var("a"), new Var("b"));
911+
912+
// When
913+
ValDerivationNode result = ExpressionSimplifier.simplify(expr);
914+
915+
// Then
916+
assertNotNull(result, "Result should not be null");
917+
assertEquals("b", result.getValue().toString(), "Expected result to be b");
918+
919+
ValDerivationNode conditionNode = new ValDerivationNode(new LiteralBoolean(false), null);
920+
ValDerivationNode thenNode = new ValDerivationNode(new Var("a"), null);
921+
ValDerivationNode elseNode = new ValDerivationNode(new Var("b"), null);
922+
IteDerivationNode iteOrigin = new IteDerivationNode(conditionNode, thenNode, elseNode);
923+
ValDerivationNode expected = new ValDerivationNode(new Var("b"), iteOrigin);
924+
925+
assertDerivationEquals(expected, result, "");
926+
}
927+
928+
@Test
929+
void testIteEqualBranchesSimplifiesToBranch() {
930+
// Given: cond ? b : b
931+
// Expected: b
932+
933+
Expression branch = new Var("b");
934+
Expression expr = new Ite(new Var("cond"), branch, branch.clone());
935+
936+
// When
937+
ValDerivationNode result = ExpressionSimplifier.simplify(expr);
938+
939+
// Then
940+
assertNotNull(result, "Result should not be null");
941+
assertEquals("b", result.getValue().toString(), "Expected result to be b");
942+
943+
ValDerivationNode conditionNode = new ValDerivationNode(new Var("cond"), null);
944+
ValDerivationNode thenNode = new ValDerivationNode(new Var("b"), null);
945+
ValDerivationNode elseNode = new ValDerivationNode(new Var("b"), null);
946+
IteDerivationNode iteOrigin = new IteDerivationNode(conditionNode, thenNode, elseNode);
947+
ValDerivationNode expected = new ValDerivationNode(new Var("b"), iteOrigin);
948+
949+
assertDerivationEquals(expected, result, "");
950+
}
951+
880952
/**
881953
* Helper method to compare two derivation nodes recursively
882954
*/
@@ -903,6 +975,11 @@ private void assertDerivationEquals(DerivationNode expected, DerivationNode actu
903975
UnaryDerivationNode actualUnary = (UnaryDerivationNode) actual;
904976
assertEquals(expectedUnary.getOp(), actualUnary.getOp(), message + ": operators should match");
905977
assertDerivationEquals(expectedUnary.getOperand(), actualUnary.getOperand(), message + " > operand");
978+
} else if (expected instanceof IteDerivationNode expectedIte) {
979+
IteDerivationNode actualIte = (IteDerivationNode) actual;
980+
assertDerivationEquals(expectedIte.getCondition(), actualIte.getCondition(), message + " > condition");
981+
assertDerivationEquals(expectedIte.getThenBranch(), actualIte.getThenBranch(), message + " > then");
982+
assertDerivationEquals(expectedIte.getElseBranch(), actualIte.getElseBranch(), message + " > else");
906983
}
907984
}
908985
}

0 commit comments

Comments
 (0)