44
55import liquidjava .rj_language .ast .BinaryExpression ;
66import liquidjava .rj_language .ast .Expression ;
7+ import liquidjava .rj_language .ast .Ite ;
78import liquidjava .rj_language .ast .LiteralBoolean ;
89import liquidjava .rj_language .ast .LiteralInt ;
910import liquidjava .rj_language .ast .UnaryExpression ;
1011import liquidjava .rj_language .ast .Var ;
1112import liquidjava .rj_language .opt .derivation_node .BinaryDerivationNode ;
1213import liquidjava .rj_language .opt .derivation_node .DerivationNode ;
14+ import liquidjava .rj_language .opt .derivation_node .IteDerivationNode ;
1315import liquidjava .rj_language .opt .derivation_node .UnaryDerivationNode ;
1416import liquidjava .rj_language .opt .derivation_node .ValDerivationNode ;
1517import liquidjava .rj_language .opt .derivation_node .VarDerivationNode ;
@@ -877,6 +879,76 @@ void testInternalToInternalNoFurtherResolution() {
877879 "#a_3 (lower counter) replaced by #b_7 (higher counter); equality collapses to trivial" );
878880 }
879881
882+ @ Test
883+ void testIteTrueConditionSimplifiesToThenBranch () {
884+ // Given: true ? a : b
885+ // Expected: a
886+
887+ Expression expr = new Ite (new LiteralBoolean (true ), new Var ("a" ), new Var ("b" ));
888+
889+ // When
890+ ValDerivationNode result = ExpressionSimplifier .simplify (expr );
891+
892+ // Then
893+ assertNotNull (result , "Result should not be null" );
894+ assertEquals ("a" , result .getValue ().toString (), "Expected result to be a" );
895+
896+ ValDerivationNode conditionNode = new ValDerivationNode (new LiteralBoolean (true ), null );
897+ ValDerivationNode thenNode = new ValDerivationNode (new Var ("a" ), null );
898+ ValDerivationNode elseNode = new ValDerivationNode (new Var ("b" ), null );
899+ IteDerivationNode iteOrigin = new IteDerivationNode (conditionNode , thenNode , elseNode );
900+ ValDerivationNode expected = new ValDerivationNode (new Var ("a" ), iteOrigin );
901+
902+ assertDerivationEquals (expected , result , "" );
903+ }
904+
905+ @ Test
906+ void testIteFalseConditionSimplifiesToElseBranch () {
907+ // Given: false ? a : b
908+ // Expected: b
909+
910+ Expression expr = new Ite (new LiteralBoolean (false ), new Var ("a" ), new Var ("b" ));
911+
912+ // When
913+ ValDerivationNode result = ExpressionSimplifier .simplify (expr );
914+
915+ // Then
916+ assertNotNull (result , "Result should not be null" );
917+ assertEquals ("b" , result .getValue ().toString (), "Expected result to be b" );
918+
919+ ValDerivationNode conditionNode = new ValDerivationNode (new LiteralBoolean (false ), null );
920+ ValDerivationNode thenNode = new ValDerivationNode (new Var ("a" ), null );
921+ ValDerivationNode elseNode = new ValDerivationNode (new Var ("b" ), null );
922+ IteDerivationNode iteOrigin = new IteDerivationNode (conditionNode , thenNode , elseNode );
923+ ValDerivationNode expected = new ValDerivationNode (new Var ("b" ), iteOrigin );
924+
925+ assertDerivationEquals (expected , result , "" );
926+ }
927+
928+ @ Test
929+ void testIteEqualBranchesSimplifiesToBranch () {
930+ // Given: cond ? b : b
931+ // Expected: b
932+
933+ Expression branch = new Var ("b" );
934+ Expression expr = new Ite (new Var ("cond" ), branch , branch .clone ());
935+
936+ // When
937+ ValDerivationNode result = ExpressionSimplifier .simplify (expr );
938+
939+ // Then
940+ assertNotNull (result , "Result should not be null" );
941+ assertEquals ("b" , result .getValue ().toString (), "Expected result to be b" );
942+
943+ ValDerivationNode conditionNode = new ValDerivationNode (new Var ("cond" ), null );
944+ ValDerivationNode thenNode = new ValDerivationNode (new Var ("b" ), null );
945+ ValDerivationNode elseNode = new ValDerivationNode (new Var ("b" ), null );
946+ IteDerivationNode iteOrigin = new IteDerivationNode (conditionNode , thenNode , elseNode );
947+ ValDerivationNode expected = new ValDerivationNode (new Var ("b" ), iteOrigin );
948+
949+ assertDerivationEquals (expected , result , "" );
950+ }
951+
880952 /**
881953 * Helper method to compare two derivation nodes recursively
882954 */
@@ -903,6 +975,11 @@ private void assertDerivationEquals(DerivationNode expected, DerivationNode actu
903975 UnaryDerivationNode actualUnary = (UnaryDerivationNode ) actual ;
904976 assertEquals (expectedUnary .getOp (), actualUnary .getOp (), message + ": operators should match" );
905977 assertDerivationEquals (expectedUnary .getOperand (), actualUnary .getOperand (), message + " > operand" );
978+ } else if (expected instanceof IteDerivationNode expectedIte ) {
979+ IteDerivationNode actualIte = (IteDerivationNode ) actual ;
980+ assertDerivationEquals (expectedIte .getCondition (), actualIte .getCondition (), message + " > condition" );
981+ assertDerivationEquals (expectedIte .getThenBranch (), actualIte .getThenBranch (), message + " > then" );
982+ assertDerivationEquals (expectedIte .getElseBranch (), actualIte .getElseBranch (), message + " > else" );
906983 }
907984 }
908985}
0 commit comments