4242 *
4343 * Rewrite a chain of element-wise multiply hops that contain identical elements.
4444 * For example `(B * A) * B` is rewritten to `A * (B^2)` (or `(B^2) * A`), where `^` is element-wise power.
45- * The order of the multiplicands depends on their data types, dimentions (matrix or vector), and sparsity.
45+ * The order of the multiplicands depends on their data types, dimensions (matrix or vector), and sparsity.
4646 *
4747 * Does not rewrite in the presence of foreign parents in the middle of the e-wise multiply chain,
4848 * since foreign parents may rely on the individual results.
@@ -136,6 +136,8 @@ private static Hop constructReplacement(final Set<BinaryOp> emults, final Multis
136136 // sorted contains all leaves, sorted by data type, stripped from their parents
137137
138138 // Construct right-deep EMult tree
139+ // TODO compile binary outer mult for transition from row and column vectors to matrices
140+ // TODO compile subtree for column vectors to avoid blow-up of intermediates on row-col vector transition
139141 final Iterator <Map .Entry <Hop , Integer >> iterator = sorted .entrySet ().iterator ();
140142 Hop first = constructPower (iterator .next ());
141143
@@ -160,13 +162,15 @@ private static Hop constructPower(final Map.Entry<Hop, Integer> entry) {
160162 }
161163
162164 /**
163- * A Comparator that orders Hops by their data type, dimention , and sparsity.
165+ * A Comparator that orders Hops by their data type, dimension , and sparsity.
164166 * The order is as follows:
165167 * scalars > row vectors > col vectors >
166168 * non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) >
167169 * other data types.
168170 * Disambiguate by Hop ID.
169171 */
172+ //TODO replace by ComparableHop wrapper around hop that implements equals and compareTo
173+ //in order to ensure comparisons that are 'consistent with equals'
170174 private static final Comparator <Hop > compareByDataType = new Comparator <Hop >() {
171175 private final int [] orderDataType = new int [Expression .DataType .values ().length ];
172176 {
@@ -190,17 +194,17 @@ public final int compare(Hop o1, Hop o2) {
190194 case MATRIX :
191195 // two matrices; check for vectors
192196 if (o1 .getDim1 () == 1 ) { // row vector
193- if (o2 .getDim1 () != 1 ) return 1 ; // row vectors are greatest of matrices
194- return compareBySparsityThenId (o1 , o2 ); // both row vectors
197+ if (o2 .getDim1 () != 1 ) return 1 ; // row vectors are greatest of matrices
198+ return compareBySparsityThenId (o1 , o2 ); // both row vectors
195199 } else if (o2 .getDim1 () == 1 ) { // 2 is row vector; 1 is not
196- return -1 ; // row vectors are the greatest matrices
200+ return -1 ; // row vectors are the greatest matrices
197201 } else if (o1 .getDim2 () == 1 ) { // col vector
198- if (o2 .getDim2 () != 1 ) return 1 ; // col vectors greater than non-vectors
199- return compareBySparsityThenId (o1 , o2 ); // both col vectors
202+ if (o2 .getDim2 () != 1 ) return 1 ; // col vectors greater than non-vectors
203+ return compareBySparsityThenId (o1 , o2 ); // both col vectors
200204 } else if (o2 .getDim2 () == 1 ) { // 2 is col vector; 1 is not
201- return 1 ; // col vectors greater than non-vectors
205+ return - 1 ; // col vectors greater than non-vectors
202206 } else { // both non-vectors
203- return compareBySparsityThenId (o1 , o2 );
207+ return compareBySparsityThenId (o1 , o2 );
204208 }
205209 default :
206210 return Long .compare (o1 .getHopID (), o2 .getHopID ());
@@ -243,7 +247,10 @@ private static boolean checkForeignParent(final Set<BinaryOp> emults, final Bina
243247 private static void findEMultsAndLeaves (final BinaryOp root , final Set <BinaryOp > emults , final Multiset <Hop > leaves ) {
244248 // Because RewriteCommonSubexpressionElimination already ran, it is safe to compare by equality.
245249 emults .add (root );
246-
250+
251+ // TODO proper handling of DAGs (avoid collecting the same leaf multiple times)
252+ // TODO exclude hops with unknown dimensions and move rewrites to dynamic rewrites
253+
247254 final ArrayList <Hop > inputs = root .getInput ();
248255 final Hop left = inputs .get (0 ), right = inputs .get (1 );
249256
0 commit comments