Skip to content

Commit f6b5c1e

Browse files
committed
move function parallelism to pass and pass runner, which allows more efficient parallel execution (WebAssembly#564)
1 parent b76818e commit f6b5c1e

17 files changed

Lines changed: 145 additions & 139 deletions

src/pass.h

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ struct PassRunner {
8989
}
9090

9191
template<class P, class Arg>
92-
void add(Arg& arg){
92+
void add(Arg arg){
9393
passes.push_back(new P(arg));
9494
}
9595

@@ -116,6 +116,9 @@ struct PassRunner {
116116
P* getLast();
117117

118118
~PassRunner();
119+
120+
private:
121+
void runPassOnFunction(Pass* pass, Function* func);
119122
};
120123

121124
//
@@ -136,6 +139,25 @@ class Pass {
136139
WASM_UNREACHABLE(); // by default, passes cannot be run this way
137140
}
138141

142+
// Function parallelism. By default, passes are not run in parallel, but you
143+
// can override this method to say that functions are parallelizable. This
144+
// should always be safe *unless* you do something in the pass that makes it
145+
// not thread-safe; in other words, the Module and Function objects and
146+
// so forth are set up so that Functions can be processed in parallel, so
147+
// if you do not ad global state that could be raced on, your pass could be
148+
// function-parallel.
149+
//
150+
// Function-parallel passes create an instance of the Walker class per function.
151+
// That means that you can't rely on Walker object properties to persist across
152+
// your functions, and you can't expect a new object to be created for each
153+
// function either (which could be very inefficient).
154+
virtual bool isFunctionParallel() { return false; }
155+
156+
// This method is used to create instances per function for a function-parallel
157+
// pass. You may need to override this if you subclass a Walker, as otherwise
158+
// this will create the parent class.
159+
virtual Pass* create() { WASM_UNREACHABLE(); }
160+
139161
std::string name;
140162

141163
protected:
@@ -197,7 +219,7 @@ class Printer : public Pass {
197219

198220
public:
199221
Printer() : o(std::cout) {}
200-
Printer(std::ostream& o) : o(o) {}
222+
Printer(std::ostream* o) : o(*o) {}
201223

202224
void run(PassRunner* runner, Module* module) override;
203225
};

src/passes/CoalesceLocals.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,9 @@ struct Liveness {
159159
};
160160

161161
struct CoalesceLocals : public WalkerPass<CFGWalker<CoalesceLocals, Visitor<CoalesceLocals>, Liveness>> {
162-
bool isFunctionParallel() { return true; }
162+
bool isFunctionParallel() override { return true; }
163+
164+
Pass* create() override { return new CoalesceLocals; }
163165

164166
Index numLocals;
165167

@@ -462,7 +464,7 @@ void CoalesceLocals::applyIndices(std::vector<Index>& indices, Expression* root)
462464
}
463465

464466
struct CoalesceLocalsWithLearning : public CoalesceLocals {
465-
virtual CoalesceLocals* create() override { return new CoalesceLocalsWithLearning; }
467+
virtual Pass* create() override { return new CoalesceLocalsWithLearning; }
466468

467469
virtual void pickIndices(std::vector<Index>& indices) override;
468470
};

src/passes/DeadCodeElimination.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@
3535
namespace wasm {
3636

3737
struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, Visitor<DeadCodeElimination>>> {
38-
bool isFunctionParallel() { return true; }
38+
bool isFunctionParallel() override { return true; }
39+
40+
Pass* create() override { return new DeadCodeElimination; }
3941

4042
// whether the current code is actually reachable
4143
bool reachable = true;

src/passes/DropReturnValues.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@
2626
namespace wasm {
2727

2828
struct DropReturnValues : public WalkerPass<PostWalker<DropReturnValues, Visitor<DropReturnValues>>> {
29-
bool isFunctionParallel() { return true; }
29+
bool isFunctionParallel() override { return true; }
30+
31+
Pass* create() override { return new DropReturnValues; }
3032

3133
std::vector<Expression*> expressionStack;
3234

src/passes/DuplicateFunctionElimination.cpp

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,13 @@
2626

2727
namespace wasm {
2828

29-
struct FunctionHasher : public PostWalker<FunctionHasher, Visitor<FunctionHasher>> {
30-
bool isFunctionParallel() { return true; }
29+
struct FunctionHasher : public WalkerPass<PostWalker<FunctionHasher, Visitor<FunctionHasher>>> {
30+
bool isFunctionParallel() override { return true; }
3131

32-
FunctionHasher* create() override {
33-
auto* ret = new FunctionHasher;
34-
ret->setOutput(output);
35-
return ret;
36-
}
32+
FunctionHasher(std::map<Function*, uint32_t>* output) : output(output) {}
3733

38-
void setOutput(std::map<Function*, uint32_t>* output_) {
39-
output = output_;
34+
FunctionHasher* create() override {
35+
return new FunctionHasher(output);
4036
}
4137

4238
void doWalkFunction(Function* func) {
@@ -63,17 +59,13 @@ struct FunctionHasher : public PostWalker<FunctionHasher, Visitor<FunctionHasher
6359
};
6460
};
6561

66-
struct FunctionReplacer : public PostWalker<FunctionReplacer, Visitor<FunctionReplacer>> {
67-
bool isFunctionParallel() { return true; }
62+
struct FunctionReplacer : public WalkerPass<PostWalker<FunctionReplacer, Visitor<FunctionReplacer>>> {
63+
bool isFunctionParallel() override { return true; }
6864

69-
FunctionReplacer* create() override {
70-
auto* ret = new FunctionReplacer;
71-
ret->setReplacements(replacements);
72-
return ret;
73-
}
65+
FunctionReplacer(std::map<Name, Name>* replacements) : replacements(replacements) {}
7466

75-
void setReplacements(std::map<Name, Name>* replacements_) {
76-
replacements = replacements_;
67+
FunctionReplacer* create() override {
68+
return new FunctionReplacer(replacements);
7769
}
7870

7971
void visitCall(Call* curr) {
@@ -95,9 +87,9 @@ struct DuplicateFunctionElimination : public Pass {
9587
for (auto& func : module->functions) {
9688
hashes[func.get()] = 0; // ensure an entry for each function - we must not modify the map shape in parallel, just the values
9789
}
98-
FunctionHasher hasher;
99-
hasher.setOutput(&hashes);
100-
hasher.walkModule(module);
90+
PassRunner hasherRunner(module);
91+
hasherRunner.add<FunctionHasher>(&hashes);
92+
hasherRunner.run();
10193
// Find hash-equal groups
10294
std::map<uint32_t, std::vector<Function*>> hashGroups;
10395
for (auto& func : module->functions) {
@@ -127,9 +119,9 @@ struct DuplicateFunctionElimination : public Pass {
127119
}), v.end());
128120
module->updateFunctionsMap();
129121
// replace direct calls
130-
FunctionReplacer replacer;
131-
replacer.setReplacements(&replacements);
132-
replacer.walkModule(module);
122+
PassRunner replacerRunner(module);
123+
replacerRunner.add<FunctionReplacer>(&replacements);
124+
replacerRunner.run();
133125
// replace in table
134126
for (auto& name : module->table.names) {
135127
auto iter = replacements.find(name);

src/passes/MergeBlocks.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,9 @@
6868
namespace wasm {
6969

7070
struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks, Visitor<MergeBlocks>>> {
71-
bool isFunctionParallel() { return true; }
71+
bool isFunctionParallel() override { return true; }
72+
73+
Pass* create() override { return new MergeBlocks; }
7274

7375
void visitBlock(Block *curr) {
7476
bool more = true;

src/passes/OptimizeInstructions.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@
2626
namespace wasm {
2727

2828
struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, Visitor<OptimizeInstructions>>> {
29-
bool isFunctionParallel() { return true; }
29+
bool isFunctionParallel() override { return true; }
30+
31+
Pass* create() override { return new OptimizeInstructions; }
3032

3133
void visitIf(If* curr) {
3234
// flip branches to get rid of an i32.eqz

src/passes/PostEmscripten.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
namespace wasm {
2626

2727
struct PostEmscripten : public WalkerPass<PostWalker<PostEmscripten, Visitor<PostEmscripten>>> {
28-
bool isFunctionParallel() { return true; }
28+
bool isFunctionParallel() override { return true; }
29+
30+
Pass* create() override { return new PostEmscripten; }
2931

3032
// When we have a Load from a local value (typically a GetLocal) plus a constant offset,
3133
// we may be able to fold it in.

src/passes/Print.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -628,9 +628,9 @@ static RegisterPass<Printer> registerPass("print", "print in s-expression format
628628
// Prints out a minified module
629629

630630
class MinifiedPrinter : public Printer {
631-
public:
631+
public:
632632
MinifiedPrinter() : Printer() {}
633-
MinifiedPrinter(std::ostream& o) : Printer(o) {}
633+
MinifiedPrinter(std::ostream* o) : Printer(o) {}
634634

635635
void run(PassRunner* runner, Module* module) override {
636636
PrintSExpression print(o);
@@ -644,9 +644,9 @@ static RegisterPass<MinifiedPrinter> registerMinifyPass("print-minified", "print
644644
// Prints out a module withough elision, i.e., the full ast
645645

646646
class FullPrinter : public Printer {
647-
public:
647+
public:
648648
FullPrinter() : Printer() {}
649-
FullPrinter(std::ostream& o) : Printer(o) {}
649+
FullPrinter(std::ostream* o) : Printer(o) {}
650650

651651
void run(PassRunner* runner, Module* module) override {
652652
PrintSExpression print(o);

src/passes/RemoveUnusedBrs.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
namespace wasm {
2626

2727
struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs, Visitor<RemoveUnusedBrs>>> {
28-
bool isFunctionParallel() { return true; }
28+
bool isFunctionParallel() override { return true; }
29+
30+
Pass* create() override { return new RemoveUnusedBrs; }
2931

3032
bool anotherCycle;
3133

0 commit comments

Comments
 (0)