forked from WebAssembly/binaryen
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathwasm-analyze.cpp
More file actions
705 lines (618 loc) · 25 KB
/
Copy pathwasm-analyze.cpp
File metadata and controls
705 lines (618 loc) · 25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
/*
* Copyright 2016 WebAssembly Community Group participants
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
//
// Expression analyzer utility
//
// Superoptimization is based on Bansal, Sorav; Aiken, Alex (21–25 October 2006): "Automatic Generation of Peephole Superoptimizers".
//
#include "support/colors.h"
#include "support/command-line.h"
#include "support/file.h"
#include "support/hash.h"
#include "support/permutation.h"
#include "wasm-s-parser.h"
#include "wasm-traversal.h"
#include "wasm-printing.h"
#include "wasm-interpreter.h"
#include "wasm-io.h"
#include "ast_utils.h"
#include "ast/cost.h"
using namespace cashew;
using namespace wasm;
// limits on what we care about
// normalize() discards normalized expressions whose Measurer::measure size
// exceeds this.
#define MAX_EXPRESSION_SIZE 20
// Number of local slots tracked per expression (ScanLocals::localTypes etc.).
// normalize() discards any expression needing MAX_LOCAL or more locals, so at
// most MAX_LOCAL - 1 slots are actually used by scanned expressions.
#define MAX_LOCAL 4
// special values to make sure to consider in execution hashing
// NUM_LIMITS boundary values per type, handed out by LocalGenerator for the
// lowest "special" seeds. Note that for float/double, numeric_limits<T>::min()
// is the smallest positive normalized value, not the most negative one, and
// the unsigned extremes are stored reinterpreted as signed values.
#define NUM_LIMITS 6
static int32_t LIMIT_I32S[NUM_LIMITS] = { std::numeric_limits<int32_t>::min(), std::numeric_limits<int32_t>::max(), int32_t(std::numeric_limits<uint32_t>::min()), int32_t(std::numeric_limits<uint32_t>::max()), 0xfffff, -0xfffff };
static int64_t LIMIT_I64S[NUM_LIMITS] = { std::numeric_limits<int64_t>::min(), std::numeric_limits<int64_t>::max(), int64_t(std::numeric_limits<uint64_t>::min()), int64_t(std::numeric_limits<uint64_t>::max()), 0xfffffLL, -0xfffffLL };
static float LIMIT_F32S[NUM_LIMITS] = { std::numeric_limits<float>::min(), std::numeric_limits<float>::max(), std::numeric_limits<float>::quiet_NaN(), std::numeric_limits<float>::infinity(), float(0xfffff), float(-0xfffff) };
static double LIMIT_F64S[NUM_LIMITS] = { std::numeric_limits<double>::min(), std::numeric_limits<double>::max(), std::numeric_limits<double>::quiet_NaN(), std::numeric_limits<double>::infinity(), double(0xfffff), double(-0xfffff) };
// The small integers -MAX_SMALL..MAX_SMALL are also used as special values.
#define MAX_SMALL 260
#define NUM_SMALLS (MAX_SMALL + MAX_SMALL + 1) /* negatives, positives, and zero */
// Total number of special values: the limits followed by the small integers.
#define NUM_SPECIALS (NUM_LIMITS + NUM_SMALLS)
// Extra executions beyond the shared special ones; for these seeds
// LocalGenerator gives individual locals their own special values or
// pseudo-random values.
#define NUM_RANDOMS 1000
// Number of executions (LocalGenerator seeds 0..NUM_EXECUTIONS-1) per expression.
#define NUM_EXECUTIONS (NUM_SPECIALS + NUM_RANDOMS)
// An expression with a cached hash value.
//
// The hash is computed once, with ExpressionAnalyzer::hash, when the wrapper is
// built from a non-null expression, so that repeated map lookups do not re-walk
// the expression tree. A wrapper around a null expression has hash 0.
// The converting constructor is intentionally implicit: callers look up maps
// keyed by HashedExpression directly with an Expression*.
struct HashedExpression {
  Expression* expr;
  // Default-initialized so that wrapping (and then copying) a null expression
  // never reads an indeterminate value.
  size_t hash = 0;

  HashedExpression(Expression* expr) : expr(expr) {
    if (expr) {
      hash = ExpressionAnalyzer::hash(expr);
    }
  }
  HashedExpression(const HashedExpression& other) = default;
};
struct ExpressionHasher {
size_t operator()(const HashedExpression value) const {
return value.hash;
}
};
struct ExpressionComparer {
bool operator()(const HashedExpression a, const HashedExpression b) const {
if (a.hash != b.hash) return false;
return ExpressionAnalyzer::equal(a.expr, b.expr);
}
};
// expression -> a count
// Keys are looked up by their cached hash and compared structurally.
class ExpressionIntMap : public std::unordered_map<HashedExpression, size_t, ExpressionHasher, ExpressionComparer> {};
// global expression state
// Persistent copies of scanned expressions (and of rule generalizations) are
// allocated in |global| so they outlive the per-input modules read in main().
Module global; // a module that persists til the end
// Frequencies count occurrences in non-advice inputs; entries that come from
// advice inputs, or that are local permutations, are recorded with 0.
ExpressionIntMap freqs; // expression -> its frequency
// Normalize an expression, replacing irrelevant bits with
// generalizations to get_local, and make get_locals start
// from 0. Returns nullptr if the expression is irrelevant.
//
// The normalized copy is allocated in |wasm|. While copying:
//  * a node without a concrete type is replaced by `unreachable`;
//  * get_locals are renumbered densely, in the order they are visited;
//  * a tee (a set_local with a concrete type) is replaced by its value;
//  * a load is replaced by a fresh get_local, standing for an arbitrary value;
//  * host ops, calls, get_globals, returns, breaks and switches are replaced
//    by `unreachable`.
// The copy is rejected if it has no concrete type, needs MAX_LOCAL or more
// locals, or measures more than MAX_EXPRESSION_SIZE.
static Expression* normalize(Expression* expr, Module& wasm) {
  struct Normalizer {
    Module& wasm;
    Builder builder;
    Index nextLocal = 0;
    std::unordered_map<Index, Index> localMap; // old local index => new

    Normalizer(Module& wasm) : wasm(wasm), builder(wasm) {}

    // Copies |curr| with flexibleCopy, using copy() as the custom copier.
    Expression* parentCopy(Expression* curr) {
      return ExpressionManipulator::flexibleCopy(curr, wasm, [&](Expression* curr) { return this->copy(curr); });
    }

    // Custom copier: returns a replacement for |curr|, or nullptr to let the
    // default copy proceed.
    Expression* copy(Expression* curr) {
      // For now, we only handle math-type expressions: having a return value and no side effects
      // TODO: do more stuff, modeling side effects etc.
      if (!isConcreteWasmType(curr->type)) {
        return builder.makeUnreachable();
      }
      if (auto* get = curr->dynCast<GetLocal>()) {
        Index newIndex;
        auto iter = localMap.find(get->index);
        if (iter == localMap.end()) {
          newIndex = nextLocal++;
          localMap[get->index] = newIndex;
        } else {
          newIndex = iter->second;
        }
        return builder.makeGetLocal(newIndex, get->type);
      }
      if (curr->is<SetLocal>()) {
        assert(curr->type != none); // this is a tee
        // look through the tee
        return parentCopy(curr->cast<SetLocal>()->value);
      }
      if (curr->is<Load>()) {
        // consider the general case of an arbitrary expression here
        return builder.makeGetLocal(nextLocal++, curr->type);
      }
      // Loads returned above, so they are not tested again here.
      if (curr->is<Host>() || curr->is<Call>() || curr->is<CallImport>() || curr->is<CallIndirect>() || curr->is<GetGlobal>() || curr->is<Return>() || curr->is<Break>() || curr->is<Switch>()) {
        return builder.makeUnreachable();
      }
      return nullptr; // allow the default copy to proceed
    }
  } normalizer(wasm);
  auto* ret = ExpressionManipulator::flexibleCopy(expr, wasm, [&](Expression* curr) { return normalizer.copy(curr); });
  if (!isConcreteWasmType(ret->type) || normalizer.nextLocal >= MAX_LOCAL || Measurer::measure(ret) > MAX_EXPRESSION_SIZE) {
    return nullptr;
  }
  return ret;
}
// Scan an expression for local types.
//
// After construction, localTypes[i] holds the type of the get_locals of index
// i, or `none` if index i is never read. Every index must be below MAX_LOCAL
// (asserted).
struct ScanLocals : public WalkerPass<PostWalker<ScanLocals, Visitor<ScanLocals>>> {
  WasmType localTypes[MAX_LOCAL];

  ScanLocals(Expression* expr) {
    for (auto& slot : localTypes) {
      slot = none;
    }
    walk(expr);
  }

  void visitGetLocal(GetLocal* curr) {
    const Index slotIndex = curr->index;
    assert(slotIndex < MAX_LOCAL);
    localTypes[slotIndex] = curr->type;
  }
};
// Remap locals.
//
// Rewrites, in place, the index of every get_local and set_local in the walked
// expression: index i becomes mapping[i]. Remapped indices must stay below
// MAX_LOCAL (asserted).
struct RemapLocals : public WalkerPass<PostWalker<RemapLocals, Visitor<RemapLocals>>> {
  std::vector<Index>& mapping;

  RemapLocals(Expression* expr, std::vector<Index>& mapping) : mapping(mapping) {
    walk(expr);
  }

  void visitGetLocal(GetLocal* curr) {
    remapIndex(curr->index);
  }

  void visitSetLocal(SetLocal* curr) {
    remapIndex(curr->index);
  }

  // Translates a single local index through |mapping|.
  void remapIndex(Index& localIndex) {
    localIndex = mapping[localIndex];
    assert(localIndex < MAX_LOCAL);
  }
};
// Configuration passed to each Scan pass: where to accumulate the number of
// relevant expressions, and whether the current input is advice-only (Scan
// then counts nothing and records new expressions with frequency 0).
struct ScanSettings {
  Index* totalExpressions;
  bool adviseOnly;

  ScanSettings(Index* counter, bool adviceInput)
    : totalExpressions(counter),
      adviseOnly(adviceInput) {}
};
// Scan a module for expressions.
//
// Each visited expression is normalized; relevant ones are recorded in the
// global |freqs| table. Unless the input is advice-only, the running total and
// the expression's frequency are incremented. The first time an expression is
// seen, a persistent copy is made in |global| and every permutation of its
// locals is also registered, with frequency 0, as a possible optimization
// target.
struct Scan : public WalkerPass<PostWalker<Scan, UnifiedExpressionVisitor<Scan>>> {
  ScanSettings settings;

  Scan(ScanSettings settings) : settings(settings) {}

  void doWalkFunction(Function* func) {
    //std::cout << " [" << func->name << ']' << '\n';
    walk(func->body);
  }

  void visitExpression(Expression* curr) {
    // normalize the expression, creating a temporary copy in this module,
    // which is ephemeral TODO: avoid keeping them alive til the end of
    // module processing to reduce peak mem usage?
    auto* normalized = normalize(curr, *getModule());
    if (normalized == nullptr) {
      return;
    }
    const bool counting = !settings.adviseOnly;
    if (counting) {
      (*settings.totalExpressions)++; // this is relevant, count it
    }
    auto existing = freqs.find(HashedExpression(normalized));
    if (existing != freqs.end()) {
      if (counting) {
        existing->second++; // just increment it
      }
      return;
    }
    // create a persistent copy in the global module TODO: avoid the rehash here
    auto* copy = ExpressionManipulator::copy(normalized, global);
    freqs[HashedExpression(copy)] = counting ? 1 : 0;
#if 1
    addLocalPermutations(copy);
#endif
  }

  // Registers every permutation of the locals of |copy| with frequency 0: they
  // are only used as optimization targets, we optimize the canonical first form.
  // A permutation that reproduces an already-known expression (such as |copy|
  // itself) finds the existing entry and leaves its frequency untouched.
  void addLocalPermutations(Expression* copy) {
    ScanLocals scanner(copy);
    if (scanner.localTypes[1] == none) {
      return; // local 1 is unused, so there is nothing to permute
    }
    struct PermutationsLister {
      std::vector<std::vector<std::vector<Index>>> list; // index => list of permutations of that size
      PermutationsLister() {
        list.resize(MAX_LOCAL + 1);
        for (size_t size = 1; size < MAX_LOCAL + 1; size++) {
          list[size] = Permutation::makeAllPermutations(size);
        }
      }
    };
    static PermutationsLister permutationsLister;
    Index numLocals = 2;
    while (numLocals < MAX_LOCAL && scanner.localTypes[numLocals] != none) {
      numLocals++;
    }
    assert(numLocals <= MAX_LOCAL);
    for (auto& permutation : permutationsLister.list.at(numLocals)) {
      auto* remapped = ExpressionManipulator::copy(copy, global);
      RemapLocals remapper(remapped, permutation);
      freqs.emplace(HashedExpression(remapped), 0);
    }
  }
};
// Generate local values deterministically, using a seed.
//
// For a seed (the execution number) and a local index, get() returns:
//  * for seeds below NUM_SPECIALS, the same special value for every local
//    (first the NUM_LIMITS boundary values, then -MAX_SMALL..MAX_SMALL);
//  * otherwise the seed is shifted by NUM_SPECIALS * (index + 1), so that each
//    local in turn walks through the special values on its own range of seeds;
//  * in all remaining cases, a pseudo-random value derived with rehash() from
//    the seed and the index.
class LocalGenerator {
  Index seed;
public:
  LocalGenerator(Index seed) : seed(seed) {}

  // Returns the value of local |index|, of type |type|, for this seed.
  Literal get(Index index, WasmType type) {
    // use low indexes to ensure we get representation of a few special values
    // TODO: get each of the MAX_LOCALS to all of its NUM_SPECIALS values
    int64_t special = seed; // start with 0-NS having them all taking the same value
    if (special >= NUM_SPECIALS) { // then give each a range for itself
      special = int64_t(seed) - int64_t(NUM_SPECIALS * (index + 1));
    }
    if (special >= 0 && special < NUM_SPECIALS) {
      if (special < NUM_LIMITS) {
        switch (type) {
          case i32: return Literal(LIMIT_I32S[special]);
          case i64: return Literal(LIMIT_I64S[special]);
          case f32: return Literal(LIMIT_F32S[special]);
          case f64: return Literal(LIMIT_F64S[special]);
          default: WASM_UNREACHABLE();
        }
      } else {
        special -= NUM_LIMITS;
        assert(special >= 0 && special < NUM_SMALLS);
        special -= MAX_SMALL;
        assert(special >= -MAX_SMALL && special <= MAX_SMALL);
        switch (type) {
          case i32: return Literal(int32_t(special));
          case i64: return Literal(int64_t(special));
          case f32: return Literal(float(special));
          case f64: return Literal(double(special));
          default: WASM_UNREACHABLE();
        }
      }
    }
    // a general "random"/deterministic value
    auto base = rehash(seed, index);
    switch (type) {
      case i32:
      case f32: {
        auto ret = Literal(rehash(base, Index(type)));
        if (type == f32) ret = ret.castToF32();
        return ret;
      }
      case i64:
      case f64: {
        // Assemble the 64 bits in unsigned arithmetic: left-shifting a signed
        // int64_t so that set bits reach or pass the sign bit is undefined
        // behavior before C++20. The final conversion to int64_t keeps the bit
        // pattern (modular conversion; guaranteed since C++20).
        uint64_t low = uint64_t(rehash(base, Index(type)));
        uint64_t high = uint64_t(rehash(base, Index(type + 1000)));
        auto ret = Literal(int64_t(low | (high << 32)));
        if (type == f64) ret = ret.castToF64();
        return ret;
      }
      default: WASM_UNREACHABLE();
    }
  }
};
// Thrown by Runner::trap() to abandon an execution that traps;
// ExecutionHasher::note() catches it and skips the expression.
struct TrapException {}; // TODO: use a flow label for optimization?
// Execute the expression over a set of local values.
//
// get_local values come from the LocalGenerator. Loops are reported as traps,
// and every trap throws TrapException. Node kinds that should not survive
// normalization (calls, globals, memory, host ops, set_local) abort the process.
class Runner : public ExpressionRunner<Runner> {
  LocalGenerator& localGenerator;

  // Reached only for node kinds we never expect to execute here.
  [[noreturn]] static void shouldNotSee() {
    abort(); // we should not see this
  }

public:
  Runner(LocalGenerator& localGenerator) : localGenerator(localGenerator) {}

  Flow visitGetLocal(GetLocal* curr) {
    return Flow(localGenerator.get(curr->index, curr->type));
  }

  Flow visitLoop(Loop* curr) {
    // loops might be infinite, so must be careful
    // but we can't tell if non-infinite, since we don't have state, so loops are just impossible to optimize for now
    trap("loop");
    WASM_UNREACHABLE();
  }

  Flow visitCall(Call* curr) { shouldNotSee(); }
  Flow visitCallImport(CallImport* curr) { shouldNotSee(); }
  Flow visitCallIndirect(CallIndirect* curr) { shouldNotSee(); }
  Flow visitSetLocal(SetLocal* curr) { shouldNotSee(); }
  Flow visitGetGlobal(GetGlobal* curr) { shouldNotSee(); }
  Flow visitSetGlobal(SetGlobal* curr) { shouldNotSee(); }
  Flow visitLoad(Load* curr) { shouldNotSee(); }
  Flow visitStore(Store* curr) { shouldNotSee(); }
  Flow visitHost(Host* curr) { shouldNotSee(); }

  void trap(const char* why) override {
    throw TrapException();
  }
};
// Calculate a hash value based on executing an expression
struct ExecutionHasher {
std::unordered_map<size_t, std::vector<Expression*>> hashClasses; // hash value => list of expressions that have it, so they may be equal
void note(Expression* expr) {
size_t hash;
try {
hash = doHash(expr);
} catch (TrapException& e) {
// we don't bother trying to handle things that trap TODO: maybe abort the whole thing, move try out, for speed?
return;
}
hashClasses[hash].push_back(expr); // we depend on expr being unique, so the classes are mathematical sets
}
size_t doHash(Expression* expr) {
// combine the result of multiple executions into the final hash
size_t hash = 0;
for (Index i = 0; i < NUM_EXECUTIONS; i++) {
LocalGenerator localGenerator(i);
Flow flow = Runner(localGenerator).visit(expr);
if (flow.breaking()) {
hash = rehash(hash, 1);
hash = rehash(hash, 2);
hash = rehash(hash, 3);
hash = rehash(hash, size_t(flow.breakTo.str));
} else {
hash = rehash(hash, 4);
hash = rehash(hash, flow.value.type);
switch (flow.value.type) {
case f32: flow.value = flow.value.castToI32(); break;
case f64: flow.value = flow.value.castToI64(); break;
default: {}
}
switch (flow.value.type) {
case none: hash = rehash(hash, 5); hash = rehash(hash, 6); break;
case i32: hash = rehash(hash, flow.value.geti32()); hash = rehash(hash, 7); break;
case i64: hash = rehash(hash, flow.value.geti64()); hash = rehash(hash, flow.value.geti64() >> 32); break;
default: WASM_UNREACHABLE();
}
}
}
return hash;
}
};
// Calculate the weight of an expression: the value rule search tries to minimize.
// Currently this is only the size reported by Measurer::measure; adding
// CostAnalyzer(expr).cost has been considered but is not done.
Index calcWeight(Expression* expr) {
  Index weight = Measurer::measure(expr);
  return weight;
}
// Can our optimizer already turn |input| into something at least as good as
// |output|?
//
// Builds a throwaway module with one exported function whose parameters are
// the expression's locals and whose body is a copy of |input|, runs the default
// optimization passes on it, and returns true if the optimized body weighs no
// more (by calcWeight) than |output|.
//
// localTypes[i] is the type of local i, or `none` if local i is unused.
static bool alreadyOptimizable(Expression* input, WasmType localTypes[MAX_LOCAL], Expression* output) {
  Module temp;
  // make a single function that receives the expression's locals and returns its output
  auto* func = new Function();
  func->name = Name("temp");
  func->result = input->type;
  for (Index i = 0; i < MAX_LOCAL; i++) {
    // Unused slots are `none`, which is not a valid parameter type. Use an i32
    // placeholder instead, so the function is well-formed for the optimizer and
    // the used locals keep their indices (the body never reads the placeholder).
    func->params.push_back(localTypes[i] != none ? localTypes[i] : i32);
  }
  func->body = ExpressionManipulator::copy(input, temp);
  temp.addFunction(func);
  // export the function, so optimizations don't kill it!
  auto* export_ = new Export();
  export_->name = Name("export");
  export_->value = func->name;
  export_->kind = ExternalKind::Function;
  temp.addExport(export_);
  // run the optimizer
  PassRunner passRunner(&temp);
  passRunner.addDefaultOptimizationPasses();
  passRunner.run();
  // evaluate the optimized input vs the output
  return calcWeight(func->body) <= calcWeight(output);
}
// Given two expressions that hashing suggests might be the same, try
// harder directly on the two to prove or disprove equivalence.
//
// Returns true only if |a| and |b|:
//  * have the same type and identical local types;
//  * have the same outcome for every one of the NUM_EXECUTIONS local
//    assignments used by hashing, i.e. both break to the same target, or both
//    complete with equal values;
//  * are not already handled by our optimizer (see alreadyOptimizable).
// Callers pass members of ExecutionHasher classes, which already ran these same
// executions without trapping.
bool looksValid(Expression* a, Expression* b) {
  if (a->type != b->type) return false; // hash collision, these are not even the same type
  // local types must be identical, otherwise the rule isn't even valid to apply
  ScanLocals aScanner(a), bScanner(b);
  for (Index i = 0; i < MAX_LOCAL; i++) {
    if (aScanner.localTypes[i] != bScanner.localTypes[i]) {
      return false; // mismatching local types
    }
  }
  // Let's use brute force: we'll run the same checks we run for hashing,
  // but instead of a single hash summarizing it all, we'll check each
  // case on the two expressions.
  for (Index i = 0; i < NUM_EXECUTIONS; i++) {
    LocalGenerator localGenerator(i);
    Flow aFlow = Runner(localGenerator).visit(a);
    Flow bFlow = Runner(localGenerator).visit(b);
    // Both must either complete normally, or break to the same (interned) target.
    if (aFlow.breaking() != bFlow.breaking()) return false;
    if (aFlow.breaking() && aFlow.breakTo.str != bFlow.breakTo.str) return false;
    if (aFlow.value != bFlow.value) return false;
  }
  // let's see if this possible optimization is already something our
  // optimizer can do: if we optimize the input, do we get something
  // as good or better than the output?
  if (alreadyOptimizable(a, aScanner.localTypes, b)) return false;
  // we see no reason these two should not be joined together in holy optimony
  return true;
}
// Generalize an expression. Currently just generalizes away
// constant values, but we should do more, e.g. maybe fold away
// differences in shifts? TODO
//
// Returns a copy of |expr|, allocated in |wasm|, in which every Const has been
// replaced by an `unreachable` placeholder. In this file the result serves as
// the key that groups specific rules into generalized ones.
Expression* generalize(Expression* expr, Module& wasm) {
  Builder builder(wasm);
  auto replaceConstants = [&](Expression* curr) -> Expression* {
    if (!curr->is<Const>()) {
      return nullptr; // allow the default copy to proceed
    }
    return builder.makeUnreachable();
  };
  return ExpressionManipulator::flexibleCopy(expr, wasm, replaceConstants);
}
// Entry point.
//
// Reads every input module and scans its expressions both before and after the
// default optimization passes. Inputs after an "advice:" (or "advise:")
// argument only contribute optimization targets, not frequencies. All scanned
// expressions are then grouped by execution hash. The tool prints statistics,
// followed by candidate optimization rules, generalized and sorted by
// estimated benefit.
int main(int argc, const char *argv[]) {
  // receive arguments
  std::vector<std::string> filenames;
  Options options("wasm-analyze", "Analyze a set of wasm modules. Provide a set of input files, optionally split by 'advice:' (in which case files afterwards are just advice, used to find optimization outputs but not inputs we focus on optimizing)");
  options.add_positional("INFILES", Options::Arguments::N,
                         [&](Options *o, const std::string &argument) {
                           filenames.push_back(argument);
                         });
  options.parse(argc, argv);
  Index totalExpressions = 0;
  bool adviseOnly = false;
  // read inputs
  for (auto& filename : filenames) {
    if (filename == "advice:" || filename == "advise:") {
      adviseOnly = true;
      std::cerr << "[advice-only from here]\n";
      continue;
    }
    // ModuleReader loads the file itself, so its contents are not read
    // separately here.
    Module wasm;
    std::cerr << "[processing: " << filename << ']' << '\n';
    try {
      ModuleReader reader;
      reader.read(filename, wasm);
    } catch (ParseException& p) {
      p.dump(std::cerr);
      Fatal() << "error in parsing input " << filename;
    }
    // scan all expressions in all functions, optimized and not
    PassRunner passRunner(&wasm);
    passRunner.add<Scan>(ScanSettings(&totalExpressions, adviseOnly));
    passRunner.addDefaultOptimizationPasses();
    passRunner.add<Scan>(ScanSettings(&totalExpressions, adviseOnly));
    passRunner.run();
  }
  // print frequencies
#if 0
  std::cout << "Frequencies:\n";
  std::vector<HashedExpression> sorted;
  for (auto& iter : freqs) {
    sorted.push_back(iter.first);
  }
  std::sort(sorted.begin(), sorted.end(), [&](const HashedExpression& a, const HashedExpression& b) {
    auto diff = int64_t(freqs[a]) - int64_t(freqs[b]);
    if (diff > 0) return true;
    if (diff < 0) return false;
    return size_t(a.expr) < size_t(b.expr);
  });
  for (auto& item : sorted) {
    std::cout << freqs[item] << " : " << item.expr << '\n';
  }
#endif
  // perform execution hashing, looking for expressions that are functionally equivalent,
  // so one can be optimized to the other
  std::cerr << "[hashing executions]\n";
  ExecutionHasher executionHasher;
  for (auto& iter : freqs) {
    auto* expr = iter.first.expr;
    executionHasher.note(expr);
  }
  // Basic statistics
  std::cerr << "[writing basic output]\n";
  std::cout << "Execution hashing info:\n";
  std::cout << " num expression nodes in total: " << totalExpressions << '\n';
  std::cout << " num unique expressions: " << freqs.size() << '\n';
  {
    size_t total = 0;
    for (auto& pair : executionHasher.hashClasses) {
      total += pair.second.size();
    }
    std::cout << " num relevant expressions: " << total << '\n';
  }
  std::cout << " num execution classes: " << executionHasher.hashClasses.size() << '\n';
  {
    size_t max = 0;
    for (auto& pair : executionHasher.hashClasses) {
      max = std::max(max, pair.second.size());
    }
    std::cout << " max class size: " << max << '\n';
  }
  // Detailed output
  {
    // a rule is a connection between one pattern and another, which we think may be equivalent to it,
    // and which may provide a measured benefit
    // TODO: test rules on more random inputs, trying to prove they are not equivalent?
    struct Rule {
      Expression* from;
      Expression* to;
      size_t benefit;
      Rule(Expression* from, Expression* to, size_t benefit) : from(from), to(to), benefit(benefit) {}
    };
    std::vector<Rule> rules;
    std::cerr << "[finding rules]\n";
    for (auto& pair : executionHasher.hashClasses) {
      auto& clazz = pair.second;
      Index size = clazz.size();
      if (size < 2) continue;
      // consider all pairs, since some may be spurious hash collisions
      for (Index i = 0; i < size; i++) {
        auto* iExpr = clazz[i];
        auto iFreq = freqs[iExpr];
        if (iFreq == 0) continue; // no frequency means no benefit to optimize it; this expression is just a target of optimization, not an origin
        Index iSize = calcWeight(iExpr);
        Expression* best = nullptr;
        Index bestSize = -1;
        for (Index j = 0; j < size; j++) {
          if (i == j) continue;
          auto* jExpr = clazz[j];
          Index jSize = calcWeight(jExpr);
          // we are looking for a rule where i => j, so we need j to be smaller
          if (iSize <= jSize) continue; // TODO: for equality, look not just at size, but cost etc.
          // a likely candidate, if direct attempts to prove they differ fail, this is worth reporting to the user
          if (best && jSize >= bestSize) continue; // we can't do better
          if (looksValid(iExpr, jExpr)) {
            best = jExpr;
            bestSize = jSize;
          }
        }
        if (best) {
          rules.emplace_back(iExpr, best, (iSize - bestSize) * iFreq);
        }
      }
    }
    // Many rules are part of a more general pattern, for example x + 1 + 1 === x + 2 is
    // closely related to x + 1 + 2 === x + 3. The generalized rule is what the human would
    // write in the optimizer, so to assess the benefit of rules, we must generalize in
    // our output.
    std::cerr << "[generalizing]\n";
    struct GeneralizedRule : public Rule {
      std::vector<Rule*> rules; // the specific rules underlying this generalization
      GeneralizedRule(Expression* from, Rule* rule) : Rule(from, nullptr, 0) {
        addRule(rule);
      }
      void addRule(Rule* rule) {
        benefit += rule->benefit;
        rules.push_back(rule);
      }
    };
    // hashed from expression => the generalized rules for that expression
    // (|rules| is not modified below, so pointers into it stay valid)
    std::unordered_map<HashedExpression, GeneralizedRule, ExpressionHasher, ExpressionComparer> generalizedRules;
    for (auto& rule : rules) {
      auto generalizedFrom = HashedExpression(generalize(rule.from, global)); // TODO: save memory, don't use global unless needed?
      auto iter = generalizedRules.find(generalizedFrom);
      if (iter == generalizedRules.end()) {
        generalizedRules.emplace(generalizedFrom, GeneralizedRule(generalizedFrom.expr, &rule));
      } else {
        iter->second.addRule(&rule);
      }
    }
    // final sorting and output
    std::cerr << "[sorting generalized rules]\n";
    std::vector<GeneralizedRule*> sortedGeneralizedRules;
    for (auto& pair : generalizedRules) {
      sortedGeneralizedRules.push_back(&pair.second);
    }
    auto ruleSorter = [](const Rule* a, const Rule* b) {
      // primary sorting criteria is the size benefit
      auto diff = int64_t(a->benefit) - int64_t(b->benefit);
      if (diff > 0) return true;
      if (diff < 0) return false;
      return size_t(a->from) < size_t(b->from);
    };
    std::sort(sortedGeneralizedRules.begin(), sortedGeneralizedRules.end(), [&ruleSorter](const GeneralizedRule* a, const GeneralizedRule* b) {
      return ruleSorter(a, b);
    });
    std::cout << "sorted possible optimization rules:\n";
    Index totalWeight = totalExpressions * 2; // Just an estimate FIXME
    size_t i = 0;
    for (auto* item : sortedGeneralizedRules) {
      std::cout << "\n[generalized rule " << (i++) << ": benefit: " << item->benefit << ", (" << (100*double(item->benefit)/totalWeight) << "%)], input pattern:\n" << item->from << '\n';
      // show the specific rules underlying the generalized one
      std::sort(item->rules.begin(), item->rules.end(), ruleSorter);
      for (auto* rule : item->rules) {
        std::cout << "\n[child specific rule benefit: " << rule->benefit << ", (" << (100*double(rule->benefit)/totalWeight) << "%)], possible rule:\n" << rule->from << "\n =->\n" << rule->to << '\n';
      }
    }
  }
  // TODO TODO: if all execution hashes of expr are the same, it might be constant (avoids needing to have all constants hashed)
}