Skip to content

Commit a5c834b

Browse files
committed
[HOTFIX][SYSTEMML-1663] Fix and disable element-wise mult chain rewrite
This patch fixes the custom hop comparator to find an ordering of element-wise multiplication chains (scalars, vectors, matrices), which fixes the test issue of PR549. Due to additional issues that could cause result incorrectness or runtime errors, I'm temporarily disabling this rewrite and related tests.
1 parent 9e7ce7b commit a5c834b

3 files changed

Lines changed: 23 additions & 14 deletions

File tree

src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ public ProgramRewriter( boolean staticRewrites, boolean dynamicRewrites )
9696
_dagRuleSet.add( new RewriteRemoveUnnecessaryCasts() );
9797
if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
9898
_dagRuleSet.add( new RewriteCommonSubexpressionElimination() );
99-
if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES)
100-
_dagRuleSet.add( new RewriteElementwiseMultChainOptimization() ); //dependency: cse
99+
//if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES)
100+
// _dagRuleSet.add( new RewriteElementwiseMultChainOptimization() ); //dependency: cse
101101
if( OptimizerUtils.ALLOW_CONSTANT_FOLDING )
102102
_dagRuleSet.add( new RewriteConstantFolding() ); //dependency: cse
103103
if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
@@ -108,7 +108,7 @@ public ProgramRewriter( boolean staticRewrites, boolean dynamicRewrites )
108108
_dagRuleSet.add( new RewriteIndexingVectorization() ); //dependency: cse, simplifications
109109
_dagRuleSet.add( new RewriteInjectSparkPReadCheckpointing() ); //dependency: reblock
110110

111-
//add statment block rewrite rules
111+
//add statement block rewrite rules
112112
if( OptimizerUtils.ALLOW_BRANCH_REMOVAL )
113113
_sbRuleSet.add( new RewriteRemoveUnnecessaryBranches() ); //dependency: constant folding
114114
if( OptimizerUtils.ALLOW_SPLIT_HOP_DAGS )

src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
*
4343
* Rewrite a chain of element-wise multiply hops that contain identical elements.
4444
* For example `(B * A) * B` is rewritten to `A * (B^2)` (or `(B^2) * A`), where `^` is element-wise power.
45-
* The order of the multiplicands depends on their data types, dimentions (matrix or vector), and sparsity.
45+
* The order of the multiplicands depends on their data types, dimensions (matrix or vector), and sparsity.
4646
*
4747
* Does not rewrite in the presence of foreign parents in the middle of the e-wise multiply chain,
4848
* since foreign parents may rely on the individual results.
@@ -136,6 +136,8 @@ private static Hop constructReplacement(final Set<BinaryOp> emults, final Multis
136136
// sorted contains all leaves, sorted by data type, stripped from their parents
137137

138138
// Construct right-deep EMult tree
139+
// TODO compile binary outer mult for transition from row and column vectors to matrices
140+
// TODO compile subtree for column vectors to avoid blow-up of intermediates on row-col vector transition
139141
final Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator();
140142
Hop first = constructPower(iterator.next());
141143

@@ -160,13 +162,15 @@ private static Hop constructPower(final Map.Entry<Hop, Integer> entry) {
160162
}
161163

162164
/**
163-
* A Comparator that orders Hops by their data type, dimention, and sparsity.
165+
* A Comparator that orders Hops by their data type, dimension, and sparsity.
164166
* The order is as follows:
165167
* scalars > row vectors > col vectors >
166168
* non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) >
167169
* other data types.
168170
* Disambiguate by Hop ID.
169171
*/
172+
//TODO replace by ComparableHop wrapper around hop that implements equals and compareTo
173+
//in order to ensure comparisons that are 'consistent with equals'
170174
private static final Comparator<Hop> compareByDataType = new Comparator<Hop>() {
171175
private final int[] orderDataType = new int[Expression.DataType.values().length];
172176
{
@@ -190,17 +194,17 @@ public final int compare(Hop o1, Hop o2) {
190194
case MATRIX:
191195
// two matrices; check for vectors
192196
if (o1.getDim1() == 1) { // row vector
193-
if (o2.getDim1() != 1) return 1; // row vectors are greatest of matrices
194-
return compareBySparsityThenId(o1, o2); // both row vectors
197+
if (o2.getDim1() != 1) return 1; // row vectors are greatest of matrices
198+
return compareBySparsityThenId(o1, o2); // both row vectors
195199
} else if (o2.getDim1() == 1) { // 2 is row vector; 1 is not
196-
return -1; // row vectors are the greatest matrices
200+
return -1; // row vectors are the greatest matrices
197201
} else if (o1.getDim2() == 1) { // col vector
198-
if (o2.getDim2() != 1) return 1; // col vectors greater than non-vectors
199-
return compareBySparsityThenId(o1, o2); // both col vectors
202+
if (o2.getDim2() != 1) return 1; // col vectors greater than non-vectors
203+
return compareBySparsityThenId(o1, o2); // both col vectors
200204
} else if (o2.getDim2() == 1) { // 2 is col vector; 1 is not
201-
return 1; // col vectors greater than non-vectors
205+
return -1; // col vectors greater than non-vectors
202206
} else { // both non-vectors
203-
return compareBySparsityThenId(o1, o2);
207+
return compareBySparsityThenId(o1, o2);
204208
}
205209
default:
206210
return Long.compare(o1.getHopID(), o2.getHopID());
@@ -243,7 +247,10 @@ private static boolean checkForeignParent(final Set<BinaryOp> emults, final Bina
243247
private static void findEMultsAndLeaves(final BinaryOp root, final Set<BinaryOp> emults, final Multiset<Hop> leaves) {
244248
// Because RewriteCommonSubexpressionElimination already ran, it is safe to compare by equality.
245249
emults.add(root);
246-
250+
251+
// TODO proper handling of DAGs (avoid collecting the same leaf multiple times)
252+
// TODO exclude hops with unknown dimensions and move rewrites to dynamic rewrites
253+
247254
final ArrayList<Hop> inputs = root.getInput();
248255
final Hop left = inputs.get(0), right = inputs.get(1);
249256

src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ public void setUp() {
5050
TestUtils.clearAssertionInformation();
5151
addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
5252
}
53-
53+
5454
@Test
5555
public void testMatrixMultChainOptNoRewritesCP() {
5656
testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.CP);
@@ -61,6 +61,7 @@ public void testMatrixMultChainOptNoRewritesSP() {
6161
testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.SPARK);
6262
}
6363

64+
/* TODO enable together with RewriteElementwiseMultChainOptimization
6465
@Test
6566
public void testMatrixMultChainOptRewritesCP() {
6667
testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.CP);
@@ -70,6 +71,7 @@ public void testMatrixMultChainOptRewritesCP() {
7071
public void testMatrixMultChainOptRewritesSP() {
7172
testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.SPARK);
7273
}
74+
*/
7375

7476
private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et)
7577
{

0 commit comments

Comments
 (0)