Skip to content

Commit 92afb40

Browse files
authored
Some simple integer math opts (WebAssembly#1504)
Stuff like x + 5 != 2 => x != -3. Also some cleanups of utility functions I noticed while writing this, isTypeFloat => isFloatType. Inspired by https://github.com/golang/go/blob/master/src/cmd/compile/internal/ssa/gen/generic.rules
1 parent 9b5ce47 commit 92afb40

18 files changed

Lines changed: 945 additions & 43 deletions

build-js.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ if [ "$1" == "-g" ]; then
4949
EMCC_ARGS="$EMCC_ARGS -O2" # need emcc js opts to be decently fast
5050
EMCC_ARGS="$EMCC_ARGS --llvm-opts 0 --llvm-lto 0"
5151
EMCC_ARGS="$EMCC_ARGS -profiling"
52+
EMCC_ARGS="$EMCC_ARGS -s ASSERTIONS=1"
5253
else
5354
EMCC_ARGS="$EMCC_ARGS -Oz"
5455
EMCC_ARGS="$EMCC_ARGS --llvm-lto 1"

src/asm2wasm.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1719,7 +1719,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) {
17191719
ret->right = process(ast[3]);
17201720
ret->op = parseAsmBinaryOp(ast[1]->getIString(), ast[2], ast[3], ret->left, ret->right);
17211721
ret->finalize();
1722-
if (ret->op == BinaryOp::RemSInt32 && isTypeFloat(ret->type)) {
1722+
if (ret->op == BinaryOp::RemSInt32 && isFloatType(ret->type)) {
17231723
// WebAssembly does not have floating-point remainder, we have to emit a call to a special import of ours
17241724
CallImport *call = allocator.alloc<CallImport>();
17251725
call->target = F64_REM;

src/ir/abstract.h

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
/*
2+
* Copyright 2018 WebAssembly Community Group participants
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
// Abstracts out operations from specific opcodes.
18+
19+
#ifndef wasm_ir_abstract_h
20+
#define wasm_ir_abstract_h
21+
22+
#include <wasm.h>
23+
24+
namespace wasm {
25+
26+
namespace Abstract {
27+
28+
enum Op {
29+
// Unary
30+
Neg,
31+
// Binary
32+
Add, Sub, Mul, DivU, DivS, Rem, RemU, RemS,
33+
Shl, ShrU, ShrS,
34+
And, Or, Xor,
35+
// Relational
36+
Eq, Ne,
37+
};
38+
39+
// Provide a wasm type and an abstract op and get the concrete one. For example, you can
40+
// provide i32 and Add and receive the specific opcode for a 32-bit addition,
41+
// AddInt32.
42+
// If the op does not exist, it returns Invalid.
43+
inline UnaryOp getUnary(Type type, Op op) {
44+
switch (type) {
45+
case i32: {
46+
switch (op) {
47+
default: return InvalidUnary;
48+
}
49+
break;
50+
}
51+
case i64: {
52+
switch (op) {
53+
default: return InvalidUnary;
54+
}
55+
break;
56+
}
57+
case f32: {
58+
switch (op) {
59+
case Neg: return NegFloat32;
60+
default: return InvalidUnary;
61+
}
62+
break;
63+
}
64+
case f64: {
65+
switch (op) {
66+
case Neg: return NegFloat64;
67+
default: return InvalidUnary;
68+
}
69+
break;
70+
}
71+
default: return InvalidUnary;
72+
}
73+
WASM_UNREACHABLE();
74+
}
75+
76+
inline BinaryOp getBinary(Type type, Op op) {
77+
switch (type) {
78+
case i32: {
79+
switch (op) {
80+
case Add: return AddInt32;
81+
case Sub: return SubInt32;
82+
case Mul: return MulInt32;
83+
case DivU: return DivUInt32;
84+
case DivS: return DivSInt32;
85+
case RemU: return RemUInt32;
86+
case RemS: return RemSInt32;
87+
case Shl: return ShlInt32;
88+
case ShrU: return ShrUInt32;
89+
case ShrS: return ShrSInt32;
90+
case And: return AndInt32;
91+
case Or: return OrInt32;
92+
case Xor: return XorInt32;
93+
case Eq: return EqInt32;
94+
case Ne: return NeInt32;
95+
default: return InvalidBinary;
96+
}
97+
break;
98+
}
99+
case i64: {
100+
switch (op) {
101+
case Add: return AddInt64;
102+
case Sub: return SubInt64;
103+
case Mul: return MulInt64;
104+
case DivU: return DivUInt64;
105+
case DivS: return DivSInt64;
106+
case RemU: return RemUInt64;
107+
case RemS: return RemSInt64;
108+
case Shl: return ShlInt64;
109+
case ShrU: return ShrUInt64;
110+
case ShrS: return ShrSInt64;
111+
case And: return AndInt64;
112+
case Or: return OrInt64;
113+
case Xor: return XorInt64;
114+
case Eq: return EqInt64;
115+
case Ne: return NeInt64;
116+
default: return InvalidBinary;
117+
}
118+
break;
119+
}
120+
case f32: {
121+
switch (op) {
122+
case Add: return AddFloat32;
123+
case Sub: return SubFloat32;
124+
case Mul: return MulFloat32;
125+
case DivU: return DivFloat32;
126+
case DivS: return DivFloat32;
127+
case Eq: return EqFloat32;
128+
case Ne: return NeFloat32;
129+
default: return InvalidBinary;
130+
}
131+
break;
132+
}
133+
case f64: {
134+
switch (op) {
135+
case Add: return AddFloat64;
136+
case Sub: return SubFloat64;
137+
case Mul: return MulFloat64;
138+
case DivU: return DivFloat64;
139+
case DivS: return DivFloat64;
140+
case Eq: return EqFloat64;
141+
case Ne: return NeFloat64;
142+
default: return InvalidBinary;
143+
}
144+
break;
145+
}
146+
default: return InvalidBinary;
147+
}
148+
WASM_UNREACHABLE();
149+
}
150+
151+
} // namespace Abstract
152+
153+
} // namespace wasm
154+
155+
#endif // wasm_ir_abstract_h
156+

src/ir/load-utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ namespace LoadUtils {
2929
inline bool isSignRelevant(Load* load) {
3030
auto type = load->type;
3131
if (load->type == unreachable) return false;
32-
return !isTypeFloat(type) && load->bytes < getTypeSize(type);
32+
return !isFloatType(type) && load->bytes < getTypeSize(type);
3333
}
3434

3535
// check if a load can be signed (which some opts want to do)

src/parsing.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ inline Expression* parseConst(cashew::IString s, Type type, MixedArena& allocato
8080
const char *str = s.str;
8181
auto ret = allocator.alloc<Const>();
8282
ret->type = type;
83-
if (isTypeFloat(type)) {
83+
if (isFloatType(type)) {
8484
if (s == _INFINITY) {
8585
switch (type) {
8686
case f32: ret->value = Literal(std::numeric_limits<float>::infinity()); break;

src/passes/OptimizeInstructions.cpp

Lines changed: 121 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include <pass.h>
2525
#include <wasm-s-parser.h>
2626
#include <support/threads.h>
27+
#include <ir/abstract.h>
2728
#include <ir/utils.h>
2829
#include <ir/cost.h>
2930
#include <ir/effects.h>
@@ -541,9 +542,11 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions,
541542
}
542543
}
543544
}
544-
return optimizeAddedConstants(binary);
545+
auto* ret = optimizeAddedConstants(binary);
546+
if (ret) return ret;
545547
} else if (binary->op == SubInt32) {
546-
return optimizeAddedConstants(binary);
548+
auto* ret = optimizeAddedConstants(binary);
549+
if (ret) return ret;
547550
}
548551
// a bunch of operations on a constant right side can be simplified
549552
if (auto* right = binary->right->dynCast<Const>()) {
@@ -567,19 +570,9 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions,
567570
}
568571
}
569572
}
570-
// some math operations have trivial results TODO: many more
571-
if (right->value == Literal(int32_t(0))) {
572-
if (binary->op == ShlInt32 || binary->op == ShrUInt32 || binary->op == ShrSInt32 || binary->op == OrInt32) {
573-
return binary->left;
574-
} else if ((binary->op == MulInt32 || binary->op == AndInt32) &&
575-
!EffectAnalyzer(getPassOptions(), binary->left).hasSideEffects()) {
576-
return binary->right;
577-
}
578-
} else if (right->value == Literal(int32_t(1))) {
579-
if (binary->op == MulInt32) {
580-
return binary->left;
581-
}
582-
}
573+
// some math operations have trivial results
574+
Expression* ret = optimizeWithConstantOnRight(binary);
575+
if (ret) return ret;
583576
// the square of some operations can be merged
584577
if (auto* left = binary->left->dynCast<Binary>()) {
585578
if (left->op == binary->op) {
@@ -645,6 +638,10 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions,
645638
if (binary->op == AndInt32 || binary->op == OrInt32) {
646639
return conditionalizeExpensiveOnBitwise(binary);
647640
}
641+
// relation/comparisons allow for math optimizations
642+
if (binary->isRelational()) {
643+
return optimizeRelational(binary);
644+
}
648645
} else if (auto* unary = curr->dynCast<Unary>()) {
649646
// de-morgan's laws
650647
if (unary->op == EqZInt32) {
@@ -1098,6 +1095,115 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions,
10981095
}
10991096
return false;
11001097
}
1098+
1099+
// optimize trivial math operations, given that the right side of a binary
1100+
// is a constant
1101+
// TODO: templatize on type?
1102+
Expression* optimizeWithConstantOnRight(Binary* binary) {
1103+
auto type = binary->right->type;
1104+
auto* right = binary->right->cast<Const>();
1105+
if (isIntegerType(type)) {
1106+
// operations on zero
1107+
if (right->value == LiteralUtils::makeLiteralFromInt32(0, type)) {
1108+
if (binary->op == Abstract::getBinary(type, Abstract::Shl) ||
1109+
binary->op == Abstract::getBinary(type, Abstract::ShrU) ||
1110+
binary->op == Abstract::getBinary(type, Abstract::ShrS) ||
1111+
binary->op == Abstract::getBinary(type, Abstract::Or)) {
1112+
return binary->left;
1113+
} else if ((binary->op == Abstract::getBinary(type, Abstract::Mul) ||
1114+
binary->op == Abstract::getBinary(type, Abstract::And)) &&
1115+
!EffectAnalyzer(getPassOptions(), binary->left).hasSideEffects()) {
1116+
return binary->right;
1117+
}
1118+
}
1119+
// wasm binary encoding uses signed LEBs, which slightly favor negative
1120+
// numbers: -64 is more efficient than +64 etc. we therefore prefer
1121+
// x - -64 over x + 64
1122+
if (binary->op == Abstract::getBinary(type, Abstract::Add) ||
1123+
binary->op == Abstract::getBinary(type, Abstract::Sub)) {
1124+
auto value = right->value.getInteger();
1125+
if (value == 0x40 ||
1126+
value == 0x2000 ||
1127+
value == 0x100000 ||
1128+
value == 0x8000000 ||
1129+
value == 0x400000000LL ||
1130+
value == 0x20000000000LL ||
1131+
value == 0x1000000000000LL ||
1132+
value == 0x80000000000000LL ||
1133+
value == 0x4000000000000000LL) {
1134+
right->value = right->value.neg();
1135+
if (binary->op == Abstract::getBinary(type, Abstract::Add)) {
1136+
binary->op = Abstract::getBinary(type, Abstract::Sub);
1137+
} else {
1138+
binary->op = Abstract::getBinary(type, Abstract::Add);
1139+
}
1140+
}
1141+
}
1142+
}
1143+
// note that this is correct even on floats with a NaN on the left,
1144+
// as a NaN would skip the computation and just return the NaN,
1145+
// and that is precisely what we do here. but, the same with -1
1146+
// (change to a negation) would be incorrect for that reason.
1147+
if (right->value == LiteralUtils::makeLiteralFromInt32(1, type)) {
1148+
if (binary->op == Abstract::getBinary(type, Abstract::Mul) ||
1149+
binary->op == Abstract::getBinary(type, Abstract::DivS) ||
1150+
binary->op == Abstract::getBinary(type, Abstract::DivU)) {
1151+
return binary->left;
1152+
}
1153+
}
1154+
return nullptr;
1155+
}
1156+
1157+
// integer math, even on 2s complement, allows stuff like
1158+
// x + 5 == 7
1159+
// =>
1160+
// x == 2
1161+
// TODO: templatize on type?
1162+
Expression* optimizeRelational(Binary* binary) {
1163+
// TODO: inequalities can also work, if the constants do not overflow
1164+
auto type = binary->right->type;
1165+
if (binary->op ==Abstract::getBinary(type, Abstract::Eq) ||
1166+
binary->op ==Abstract::getBinary(type, Abstract::Ne)) {
1167+
if (isIntegerType(binary->left->type)) {
1168+
if (auto* left = binary->left->dynCast<Binary>()) {
1169+
if (left->op == Abstract::getBinary(type, Abstract::Add) ||
1170+
left->op == Abstract::getBinary(type, Abstract::Sub)) {
1171+
if (auto* leftConst = left->right->dynCast<Const>()) {
1172+
if (auto* rightConst = binary->right->dynCast<Const>()) {
1173+
return combineRelationalConstants(binary, left, leftConst, nullptr, rightConst);
1174+
} else if (auto* rightBinary = binary->right->dynCast<Binary>()) {
1175+
if (rightBinary->op == Abstract::getBinary(type, Abstract::Add) ||
1176+
rightBinary->op == Abstract::getBinary(type, Abstract::Sub)) {
1177+
if (auto* rightConst = rightBinary->right->dynCast<Const>()) {
1178+
return combineRelationalConstants(binary, left, leftConst, rightBinary, rightConst);
1179+
}
1180+
}
1181+
}
1182+
}
1183+
}
1184+
}
1185+
}
1186+
}
1187+
return nullptr;
1188+
}
1189+
1190+
// given a relational binary with a const on both sides, combine the constants
1191+
// left is also a binary, and has a constant; right may be just a constant, in which
1192+
// case right is nullptr
1193+
Expression* combineRelationalConstants(Binary* binary, Binary* left, Const* leftConst, Binary* right, Const* rightConst) {
1194+
auto type = binary->right->type;
1195+
// we fold constants to the right
1196+
Literal extra = leftConst->value;
1197+
if (left->op == Abstract::getBinary(type, Abstract::Sub)) {
1198+
extra = extra.neg();
1199+
}
1200+
if (right && right->op == Abstract::getBinary(type, Abstract::Sub)) {
1201+
extra = extra.neg();
1202+
}
1203+
rightConst->value = rightConst->value.sub(extra);
1204+
binary->left = left->left;
1205+
return binary;
1206+
}
11011207
};
11021208

11031209
Pass *createOptimizeInstructionsPass() {

src/passes/SafeHeap.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ static Name getLoadName(Load* curr) {
3838
std::string ret = "SAFE_HEAP_LOAD_";
3939
ret += printType(curr->type);
4040
ret += "_" + std::to_string(curr->bytes) + "_";
41-
if (!isTypeFloat(curr->type) && !curr->signed_) {
41+
if (!isFloatType(curr->type) && !curr->signed_) {
4242
ret += "U_";
4343
}
4444
if (curr->isAtomic) {
@@ -160,7 +160,7 @@ struct SafeHeap : public Pass {
160160
if (bytes > getTypeSize(type)) continue;
161161
for (auto signed_ : { true, false }) {
162162
load.signed_ = signed_;
163-
if (isTypeFloat(type) && signed_) continue;
163+
if (isFloatType(type) && signed_) continue;
164164
for (Index align : { 1, 2, 4, 8 }) {
165165
load.align = align;
166166
if (align > bytes) continue;

0 commit comments

Comments
 (0)