diff --git a/src/ir/effects.cpp b/src/ir/effects.cpp index 2f9dbddad7e..a3ce1a95aaf 100644 --- a/src/ir/effects.cpp +++ b/src/ir/effects.cpp @@ -19,7 +19,7 @@ namespace std { -std::ostream& operator<<(std::ostream& o, wasm::EffectAnalyzer& effects) { +std::ostream& operator<<(std::ostream& o, const wasm::EffectAnalyzer& effects) { o << "EffectAnalyzer {\n"; if (effects.branchesOut) { o << "branchesOut\n"; diff --git a/src/ir/effects.h b/src/ir/effects.h index d1d38af595a..7841ad13536 100644 --- a/src/ir/effects.h +++ b/src/ir/effects.h @@ -57,6 +57,9 @@ class EffectAnalyzer { walk(func); } + // EffectAnalyzer(const EffectAnalyzer&) = default; + // EffectAnalyzer& operator=(const EffectAnalyzer&) = default; + bool ignoreImplicitTraps : 1; bool trapsNeverHappen : 1; @@ -1428,6 +1431,97 @@ class EffectAnalyzer { assert(!transfersControlFlow()); } + int count() const { + size_t count = 0; + auto& effects = *this; + + if (effects.branchesOut) + count++; + if (effects.calls) + count++; + + // Size-based effects (counted as 1 if not empty) + if (!effects.localsRead.empty()) + count++; + if (!effects.localsWritten.empty()) + count++; + if (!effects.mutableGlobalsRead.empty()) + count++; + if (!effects.globalsWritten.empty()) + count++; + + if (effects.readsMemory) + count++; + if (effects.writesMemory) + count++; + if (effects.readsSharedMemory) + count++; + if (effects.writesSharedMemory) + count++; + if (effects.readsTable) + count++; + if (effects.writesTable) + count++; + if (effects.readsMutableStruct) + count++; + if (effects.writesStruct) + count++; + if (effects.readsSharedMutableStruct) + count++; + if (effects.writesSharedStruct) + count++; + if (effects.readsMutableArray) + count++; + if (effects.writesArray) + count++; + if (effects.readsSharedMutableArray) + count++; + if (effects.writesSharedArray) + count++; + if (effects.trap) + count++; + if (effects.implicitTrap) + count++; + + // Order-based effects + if (effects.readOrder != wasm::MemoryOrder::Unordered) + count++; + if (effects.writeOrder != wasm::MemoryOrder::Unordered) + count++; + + if (effects.throws_) + count++; + if (effects.tryDepth) + count++; + if (effects.catchDepth) + count++; + if (effects.danglingPop) + count++; + if (effects.mayNotReturn) + count++; + if (effects.hasReturnCallThrow) + count++; + + // Method-based checks + // if (effects.accessesLocal()) count++; + // if (effects.accessesMutableGlobal()) count++; + // if (effects.accessesMemory()) count++; + // if (effects.accessesTable()) count++; + // if (effects.accessesMutableStruct()) count++; + // if (effects.accessesArray()) count++; + // if (effects.throws()) count++; + // if (effects.transfersControlFlow()) count++; + // if (effects.writesGlobalState()) count++; + // if (effects.readsMutableGlobalState()) count++; + // if (effects.hasNonTrapSideEffects()) count++; + // if (effects.hasSideEffects()) count++; + // if (effects.hasUnremovableSideEffects()) count++; + // if (effects.hasAnything()) count++; + // if (effects.hasExternalBreakTargets()) count++; + + return count; + } + private: void post() { assert(tryDepth == 0); @@ -1457,7 +1551,7 @@ class ShallowEffectAnalyzer : public EffectAnalyzer { } // namespace wasm namespace std { -std::ostream& operator<<(std::ostream& o, wasm::EffectAnalyzer& effects); +std::ostream& operator<<(std::ostream& o, const wasm::EffectAnalyzer& effects); } // namespace std #endif // wasm_ir_effects_h diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt index c2952e174b8..0260c1823f4 100644 --- a/src/passes/CMakeLists.txt +++ b/src/passes/CMakeLists.txt @@ -30,6 +30,7 @@ set(passes_SOURCES DeadArgumentElimination.cpp DeadArgumentElimination2.cpp DeadCodeElimination.cpp + DeadStoreElimination.cpp DeAlign.cpp DebugLocationPropagation.cpp DeNaN.cpp diff --git a/src/passes/DeadStoreElimination.cpp b/src/passes/DeadStoreElimination.cpp new file mode 100644 index 00000000000..ae8e14f8703 --- /dev/null +++ b/src/passes/DeadStoreElimination.cpp @@ -0,0 +1,196 @@ +#include "analysis/cfg.h" +#include "ir/local-graph.h" +#include "ir/module-utils.h" +#include "ir/properties.h" +#include "pass.h" +#include "passes/passes.h" + +namespace wasm { +namespace { + +std::mutex m; + +struct ComparingLocalGraph : public LocalGraph { + PassOptions& passOptions; + Module& wasm; + + ComparingLocalGraph(Function* func, PassOptions& passOptions, Module& wasm) + : LocalGraph(func), passOptions(passOptions), wasm(wasm) {} + + // Check whether the values of two expressions will definitely be equal at + // runtime. + // TODO: move to LocalGraph if we find more users? + bool equalValues(Expression* a, Expression* b) { + a = Properties::getFallthrough(a, passOptions, wasm); + b = Properties::getFallthrough(b, passOptions, wasm); + if (auto* aGet = a->dynCast()) { + if (auto* bGet = b->dynCast()) { + if (LocalGraph::equivalent(aGet, bGet)) { + return true; + } + } + } + + // not relevant + // if (auto* aConst = a->dynCast()) { + // if (auto* bConst = b->dynCast()) { + // return aConst->value == bConst->value; + // } + // } + return false; + } +}; + +struct StoreInfo { + const StructSet* store = nullptr; + int duplicateStores = 0; + + // A struct.get observed this store. It may or may not be dead besides that + // (check `duplicateStores` for that) but it is definitely not dead because of + // the get. + int conflictingGets = 0; + + // This is counted differently. If this is set, then there is definitely a + // duplicate store, and this would definitely be dead if it weren't for these + // effects (AND it's possible that there are conflicting gets as well). + std::optional conflictingEffects = std::nullopt; + + friend std::ostream& operator<<(std::ostream& os, const StoreInfo& info) { + os << "StoreInfo { "; + + // Handle the pointer to StructSet + os << "store: "; + if (info.store) { + os << *info.store; + } else { + os << "nullptr"; + } + + os << ", duplicateStores: " << info.duplicateStores; + os << ", conflictingGets: " << info.conflictingGets; + + // Handle the std::optional EffectAnalyzer + os << ", conflictingEffects: "; + if (info.conflictingEffects.has_value()) { + os << *info.conflictingEffects; + } else { + os << "none"; + } + + os << " }"; + return os; + } + + // } +}; + +using Info = std::variant; + +class DeadStoreEliminationPass : public Pass { + virtual std::unique_ptr create() { + return std::make_unique(); + } + + bool isFunctionParallel() override { return true; } + + void runOnFunction(Module* module, Function* function) override { + + ComparingLocalGraph localGraph(function, getPassOptions(), *module); + + auto cfg = analysis::CFG::fromFunction(function); + + // todo might want to use a map here + // keyed by the ref expression + std::vector storeInfos; + for (auto& block : cfg) { + for (const auto* inst : block) { + if (const StructSet* structSet = inst->dynCast()) { + // std::vector barriers; + EffectAnalyzer barriers(getPassOptions(), *module); + assert(!barriers.hasAnything()); + + for (auto it = storeInfos.rbegin(); it != storeInfos.rend(); ++it) { + if (auto* storeInfo = std::get_if(&*it)) { + if (localGraph.equalValues(structSet->ref, + storeInfo->store->ref) && + structSet->index == storeInfo->store->index) { + storeInfo->duplicateStores++; + + if (barriers.hasAnything()) { + storeInfo->conflictingEffects.emplace(barriers); + } + break; + } + } else if (auto* barrier = std::get_if(&*it)) { + barriers.mergeIn(*barrier); + } + } + std::cout << "Should have got here twice\n"; + storeInfos.push_back(StoreInfo{structSet}); + } else if (const StructGet* structGet = inst->dynCast()) { + for (auto it = storeInfos.rbegin(); it != storeInfos.rend(); ++it) { + // Don't care about barriers here. + if (!std::holds_alternative(*it)) { + continue; + } + + auto& storeInfo = std::get(*it); + + if (localGraph.equalValues(structGet->ref, storeInfo.store->ref) && + structGet->index == storeInfo.store->index) { + storeInfo.conflictingGets++; + break; + } + } + } else { + ShallowEffectAnalyzer effects( + getPassOptions(), *module, const_cast(inst)); + // Add all the possible effects here + // Maybe prune the ones that matter from effects + if (effects.branchesOut || effects.calls || effects.throws() || + (!getPassOptions().trapsNeverHappen && effects.trap)) { + ShallowEffectAnalyzer prunedEffects(getPassOptions(), *module); + prunedEffects.branchesOut = effects.branchesOut; + prunedEffects.calls = effects.calls; + prunedEffects.throws_ = effects.throws_; + prunedEffects.delegateTargets = effects.delegateTargets; + // trap left out because we're using TNH in practice for now + storeInfos.push_back(prunedEffects); + } + } + } + } + + for (const auto& info : storeInfos) { + if (!std::holds_alternative(info)) { + continue; + } + + auto& storeInfo = std::get(info); + + std::cout << storeInfo << "\n"; + + // When running on the small binary and adding a throw, we can tell that + // the store is not dead but the effects don't print out for some reason. + if (storeInfo.conflictingEffects) { + std::lock_guard _(m); + std::cout << "??\n"; + // std::cout<<*const_cast(&*storeInfo.conflictingEffects); + std::cout << *storeInfo.conflictingEffects; + } + // if (storeInfo.duplicateStores && !storeInfo.conflictingGets && + // !storeInfo.conflictingEffects) { + // std::lock_guard _(m); + // std::cout << storeInfo.duplicateStores << "\n"; + // } + } + } +}; + +} // namespace + +Pass* createDeadStoreEliminationPass() { + return new DeadStoreEliminationPass(); +} + +} // namespace wasm \ No newline at end of file diff --git a/src/passes/GlobalEffects.cpp b/src/passes/GlobalEffects.cpp index fbfb278e93f..cf716a50233 100644 --- a/src/passes/GlobalEffects.cpp +++ b/src/passes/GlobalEffects.cpp @@ -84,7 +84,7 @@ struct GenerateGlobalEffects : public Pass { // worst. To do so, clear the effects, which indicates nothing // is known (so anything is possible). // TODO: We could group effects by function type etc. - funcInfo.effects.reset(); + // funcInfo.effects.reset(); } else { // No call here, but update throwing if we see it. (Only do so, // however, if we have effects; if we cleared it - see before - diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index f887075fc2d..6f7668ffb86 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -3971,8 +3971,8 @@ std::ostream& operator<<(std::ostream& o, wasm::Function& func) { return o; } -std::ostream& operator<<(std::ostream& o, wasm::Expression& expression) { - return wasm::printExpression(&expression, o); +std::ostream& operator<<(std::ostream& o, const wasm::Expression& expression) { + return wasm::printExpression(const_cast(&expression), o); } std::ostream& operator<<(std::ostream& o, wasm::Expression* expression) { diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp index 0e6e28267c2..f6ff1233524 100644 --- a/src/passes/pass.cpp +++ b/src/passes/pass.cpp @@ -90,6 +90,7 @@ bool PassRegistry::isPassHidden(std::string name) { // PassRunner void PassRegistry::registerPasses() { + registerPass("ldse", "removes dead stores", createDeadStoreEliminationPass); registerPass("alignment-lowering", "lower unaligned loads and stores to smaller aligned ones", createAlignmentLoweringPass); diff --git a/src/passes/passes.h b/src/passes/passes.h index be06369a9f8..c03a81838cd 100644 --- a/src/passes/passes.h +++ b/src/passes/passes.h @@ -22,6 +22,7 @@ namespace wasm { class Pass; // Normal passes: +Pass* createDeadStoreEliminationPass(); Pass* createAbstractTypeRefiningPass(); Pass* createAlignmentLoweringPass(); Pass* createAsyncifyPass(); diff --git a/src/wasm.h b/src/wasm.h index 716cf27a2cb..e32acbdc08a 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -818,7 +818,7 @@ class Expression { } // Print the expression to stderr. Meant for use while debugging. - void dump(); + void dump(std::ostream& o = std::cout) const; }; const char* getExpressionName(Expression* curr); @@ -2766,7 +2766,7 @@ struct ShallowExpression { std::ostream& operator<<(std::ostream& o, wasm::Module& module); std::ostream& operator<<(std::ostream& o, wasm::Function& func); -std::ostream& operator<<(std::ostream& o, wasm::Expression& expression); +std::ostream& operator<<(std::ostream& o, const wasm::Expression& expression); std::ostream& operator<<(std::ostream& o, wasm::ModuleExpression pair); std::ostream& operator<<(std::ostream& o, wasm::ShallowExpression expression); std::ostream& operator<<(std::ostream& o, wasm::ModuleType pair); diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp index b9df7a232f6..85b9d50e750 100644 --- a/src/wasm/wasm.cpp +++ b/src/wasm/wasm.cpp @@ -124,7 +124,7 @@ Name TUPLE("tuple"); // Expressions -void Expression::dump() { std::cout << *this << '\n'; } +void Expression::dump(std::ostream& o) const { o << *this << '\n'; } const char* getExpressionName(Expression* curr) { switch (curr->_id) { diff --git a/test-dse-out.wat b/test-dse-out.wat new file mode 100644 index 00000000000..95271d0f264 --- /dev/null +++ b/test-dse-out.wat @@ -0,0 +1,14 @@ +(module + (type $s (struct (field (mut i32)))) + (type $1 (func (param (ref $s)))) + (func $asdf (type $1) (param $ref (ref $s)) + (struct.set $s 0 + (local.get $ref) + (i32.const 1) + ) + (struct.set $s 0 + (local.get $ref) + (i32.const 2) + ) + ) +) diff --git a/test-dse.wat b/test-dse.wat new file mode 100644 index 00000000000..6f36f09bc38 --- /dev/null +++ b/test-dse.wat @@ -0,0 +1,11 @@ +(module + (type $s (struct (field (mut i32)))) + (tag $t) + (func $asdf (param $ref (ref $s)) (param $b i32) + (struct.set $s 0 (local.get $ref) (i32.const 1)) + + (if (local.get $b) (then (throw $t))) + + (struct.set $s 0 (local.get $ref) (i32.const 2)) + ) +) diff --git a/test-out.wat b/test-out.wat new file mode 100644 index 00000000000..c246539be88 --- /dev/null +++ b/test-out.wat @@ -0,0 +1,22 @@ +(module + (type $s (struct (field (mut i32)))) + (type $1 (func)) + (type $2 (func (param (ref $s) i32))) + (tag $t (type $1)) + (func $asdf (type $2) (param $ref (ref $s)) (param $b i32) + (struct.set $s 0 + (local.get $ref) + (i32.const 1) + ) + (if + (local.get $b) + (then + (throw $t) + ) + ) + (struct.set $s 0 + (local.get $ref) + (i32.const 2) + ) + ) +) diff --git a/test.wat b/test.wat new file mode 100644 index 00000000000..1237097b8d7 --- /dev/null +++ b/test.wat @@ -0,0 +1,8 @@ +(module + (type $s (struct (field (mut i32)))) + (func $a (param $ref (ref $s)) + (struct.set $s 0 (local.get $ref) (i32.const 5)) + (struct.set $s 0 (local.get $ref) (i32.const 4)) + (struct.set $s 0 (local.get $ref) (i32.const 4)) + ) +) \ No newline at end of file diff --git a/test/lit/dse.wast b/test/lit/dse.wast new file mode 100644 index 00000000000..2bb45098cb9 --- /dev/null +++ b/test/lit/dse.wast @@ -0,0 +1,23 @@ +;; NOTE: Assertions have been generated by update_lit_checks.py and should not be edited. +;; RUN: wasm-opt -all --ldse %s -S -o - | filecheck %s + +(module + (type $struct (struct (field (mut i32)))) + (global $g (ref $struct) (struct.new $struct (i32.const 1))) + + ;; CHECK: (func $a (type $0) + ;; CHECK-NEXT: ) + (func $a + (local $a (ref $struct)) + (local.set $a (global.get $g)) + (struct.set $struct 0 (local.get $a) (i32.const 2)) + (struct.set $struct 0 (local.get $a) (i32.const 3)) + + ;; doesn't work with localGraph + ;; (struct.set $struct 0 (global.get $g) (i32.const 2)) + ;; (struct.set $struct 0 (global.get $g) (i32.const 3)) + ) + ;; CHECK: (func $b (type $0) + ;; CHECK-NEXT: ) + (func $b) +)