Skip to content

Commit f0b3ca7

Browse files
author
Nikhil Thorat
authored
[WASM] Add FusedConv2D which only supports fusing bias. (tensorflow#2356)
FEATURE - Moves the calls to xnn pack conv2d to a shared utility. - Adds an optional bias. If bias id is -1, we pass a nullptr to xnn pack, otherwise we pass the bias buffer. - Cache operators from bias and filter ids. If a bias **or** a filter dies, the associated xnn operator dies.
1 parent a42cb30 commit f0b3ca7

17 files changed

Lines changed: 823 additions & 194 deletions

File tree

package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
"shelljs": "~0.8.3",
88
"ts-node": "~4.1.0",
99
"tslint": "~5.20.0",
10-
"typescript": "3.5.3"
10+
"typescript": "3.6.3"
1111
},
1212
"scripts": {
1313
"diff": "./scripts/diff.js",

tfjs-backend-wasm/package.json

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,14 @@
3636
"karma-jasmine": "~1.1.1",
3737
"karma-typescript": "~4.0.0",
3838
"rimraf": "~2.6.2",
39-
"rollup": "^1.17.0",
40-
"rollup-plugin-commonjs": "^10.0.1",
41-
"rollup-plugin-node-resolve": "^5.2.0",
39+
"rollup": "~1.26.3",
40+
"rollup-plugin-commonjs": "~10.1.0",
41+
"rollup-plugin-node-resolve": "~5.2.0",
4242
"rollup-plugin-terser": "^5.1.1",
43-
"rollup-plugin-typescript2": "^0.22.1",
43+
"rollup-plugin-typescript2": "~0.25.2",
4444
"tslint": "^5.20.0",
4545
"tslint-no-circular-imports": "^0.7.0",
46-
"typescript": "3.5.3",
46+
"typescript": "3.6.3",
4747
"yalc": "~1.0.0-pre.21"
4848
},
4949
"license": "Apache-2.0"

tfjs-backend-wasm/scripts/test-bundle-size.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ const {exec} = require('../../scripts/test-util');
1919
const {showDiff, getFileSizeBytes} = require('../../scripts/bundle-size-util');
2020

2121
// Get the bundle sizes from this change.
22-
exec(`yarn rollup -c`, {silent: true});
22+
exec(`yarn rollup -c`, {silent: false});
2323

2424
const bundleFilename = 'dist/tf-backend-wasm.min.js';
2525
const minBundleSize = getFileSizeBytes(bundleFilename);
@@ -36,7 +36,7 @@ exec(
3636

3737
shell.cd(dirName);
3838
shell.cd(wasmDirName);
39-
exec(`yarn && ./scripts/build-ci.sh && yarn rollup -c`, {silent: true});
39+
exec(`yarn && ./scripts/build-ci.sh && yarn rollup -c`, {silent: false});
4040

4141
const masterMinBundleSize = getFileSizeBytes(bundleFilename);
4242
const masterWasmSize = getFileSizeBytes(wasmFileName);

tfjs-backend-wasm/src/cc/BUILD

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,17 @@ tfjs_cc_library(
6262
deps = [":util"],
6363
)
6464

65+
tfjs_cc_library(
66+
name = "conv2d_impl",
67+
hdrs = ["conv2d_impl.h"],
68+
srcs = ["conv2d_impl.cc"],
69+
deps = [
70+
":backend",
71+
":transpose_impl",
72+
":util",
73+
],
74+
)
75+
6576
tfjs_cc_library(
6677
name = "all_kernels",
6778
deps = [
@@ -70,6 +81,7 @@ tfjs_cc_library(
7081
":BatchMatMul",
7182
":CropAndResize",
7283
":Conv2D",
84+
":FusedConv2D",
7385
":Div",
7486
":Mul",
7587
":Prelu",
@@ -151,9 +163,16 @@ tfjs_cc_library(
151163
srcs = ["kernels/Conv2D.cc"],
152164
hdrs = ["kernels/Conv2D.h"],
153165
deps = [
154-
":backend",
155-
":transpose_impl",
156-
":util",
166+
":conv2d_impl",
167+
],
168+
)
169+
170+
tfjs_cc_library(
171+
name = "FusedConv2D",
172+
srcs = ["kernels/FusedConv2D.cc"],
173+
hdrs = ["kernels/FusedConv2D.h"],
174+
deps = [
175+
":conv2d_impl",
157176
],
158177
)
159178

@@ -260,3 +279,11 @@ tfjs_unit_test(
260279
":Conv2D",
261280
]
262281
)
282+
283+
tfjs_unit_test(
284+
name = "FusedConv2D_test",
285+
srcs = ["kernels/FusedConv2D_test.cc"],
286+
deps = [
287+
":FusedConv2D",
288+
]
289+
)
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
/* Copyright 2019 Google Inc. All Rights Reserved.
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
* ===========================================================================*/
14+
15+
#ifdef __EMSCRIPTEN__
16+
#include <emscripten.h>
17+
#endif
18+
19+
#include "src/cc/conv2d_impl.h"
20+
21+
#include <xnnpack.h>
22+
#include <array>
23+
#include <cmath>
24+
#include <limits>
25+
#include <map>
26+
#include <memory>
27+
#include <unordered_map>
28+
#include <utility>
29+
#include <vector>
30+
31+
#include "src/cc/backend.h"
32+
#include "src/cc/transpose_impl.h"
33+
#include "src/cc/util.h"
34+
35+
namespace {
36+
// These integer values are keys to creating the conv2d operator. We use
37+
// std::array instead of a vanilla array as it implements the compare operator
38+
// needed for std::map.
39+
typedef std::array<int, 16> OperatorCacheKey;
40+
41+
// The operator cache maps the cache key to the xnn_operator_t instantiated for
42+
// this set of arguments to the xnn_operator.
43+
std::map<OperatorCacheKey, xnn_operator_t> operator_cache;
44+
45+
// Maps a filter id to a list of operator cache keys that this filter belongs
46+
// to.
47+
std::unordered_map<int, std::vector<OperatorCacheKey>>
48+
filter_operator_cache_key_map;
49+
50+
// Maps a bias id to a list of operator cache keys that this filter belongs
51+
// to.
52+
std::unordered_map<int, std::vector<OperatorCacheKey>>
53+
bias_operator_cache_key_map;
54+
55+
void erase_from_cache(const int tensor_id,
56+
std::unordered_map<int, std::vector<OperatorCacheKey>>&
57+
operator_cache_key_map) {
58+
auto operator_cache_keys_idx = operator_cache_key_map.find(tensor_id);
59+
if (operator_cache_keys_idx != operator_cache_key_map.end()) {
60+
std::vector<OperatorCacheKey> operator_cache_keys =
61+
operator_cache_keys_idx->second;
62+
for (auto& operator_cache_key : operator_cache_keys) {
63+
auto operator_cache_key_idx = operator_cache.find(operator_cache_key);
64+
if (operator_cache_key_idx != operator_cache.end()) {
65+
auto& conv2d_op = operator_cache_key_idx->second;
66+
67+
xnn_delete_operator(conv2d_op);
68+
tfjs::backend::xnn_operator_count--;
69+
70+
operator_cache.erase(operator_cache_key);
71+
}
72+
}
73+
operator_cache_key_map.erase(tensor_id);
74+
}
75+
}
76+
77+
void delete_xnn_operators(int tensor_id) {
78+
erase_from_cache(tensor_id, filter_operator_cache_key_map);
79+
erase_from_cache(tensor_id, bias_operator_cache_key_map);
80+
}
81+
82+
void associate_tensor_with_key(
83+
const int tensor_id, const OperatorCacheKey& cache_key,
84+
std::unordered_map<int, std::vector<OperatorCacheKey>>&
85+
operator_cache_key_map) {
86+
auto cache_keys_idx = operator_cache_key_map.find(tensor_id);
87+
if (cache_keys_idx == operator_cache_key_map.end()) {
88+
std::vector<OperatorCacheKey> cache_keys = {cache_key};
89+
operator_cache_key_map.emplace(tensor_id, std::move(cache_keys));
90+
tfjs::backend::register_disposal_callback(tensor_id, *delete_xnn_operators);
91+
92+
} else {
93+
auto& cache_keys = operator_cache_key_map.at(tensor_id);
94+
cache_keys.emplace_back(cache_key);
95+
}
96+
}
97+
98+
} // namespace
99+
100+
namespace tfjs {
101+
namespace wasm {
102+
103+
void conv2d(const int x_id, const int batch_size, const int input_height,
104+
const int input_width, const int filter_id, const int filter_height,
105+
const int filter_width, const int bias_id, int pad_top,
106+
int pad_right, int pad_bottom, int pad_left, const int is_same_pad,
107+
const int dilation_height, const int dilation_width,
108+
const int stride_height, const int stride_width,
109+
const int input_channels, const int output_channels,
110+
const int out_id) {
111+
auto& x_info = backend::get_tensor_info(x_id);
112+
auto& filter_info = backend::get_tensor_info(filter_id);
113+
auto& out_info = backend::get_tensor_info_out(out_id);
114+
115+
const float* x_buf = x_info.f32();
116+
const float* filter_buf = filter_info.f32();
117+
const float* bias_buf = nullptr;
118+
if (bias_id != -1) {
119+
bias_buf = backend::get_tensor_info_out(bias_id).f32();
120+
}
121+
float* out_buf = out_info.f32_write();
122+
123+
xnn_operator_t conv2d_op = nullptr;
124+
125+
int flags = 0;
126+
if (is_same_pad) {
127+
pad_top = 0, pad_right = 0, pad_bottom = 0, pad_left = 0;
128+
flags = XNN_FLAG_TENSORFLOW_SAME_PADDING;
129+
}
130+
131+
const int groups = 1;
132+
133+
OperatorCacheKey cache_key = {
134+
pad_top, pad_right, pad_bottom, pad_left,
135+
filter_height, filter_width, stride_height, stride_width,
136+
dilation_height, dilation_width, groups, input_channels,
137+
output_channels, filter_id, bias_id, flags};
138+
139+
auto operator_cache_idx = operator_cache.find(cache_key);
140+
if (operator_cache_idx == operator_cache.end()) {
141+
float output_min = -std::numeric_limits<float>::infinity();
142+
float output_max = std::numeric_limits<float>::infinity();
143+
144+
// xnn pack expects weights layed out like:
145+
// [output_channels, filter_height, filter_width, input_channels]
146+
// TensorFlow has weights layed out like:
147+
// [filter_height, filter_width, input_channels, output_channels]
148+
// This can be transposed with a 2d transpose to move output_channels to the
149+
// outer most dimension.
150+
std::vector<float> transposed_filter(filter_info.size);
151+
152+
const std::vector<int> filter_shape = {
153+
filter_height * filter_width * input_channels, output_channels};
154+
const std::vector<int> perm = {1, 0};
155+
tfjs::wasm::transpose(filter_buf, filter_shape, perm,
156+
transposed_filter.data());
157+
158+
xnn_status status = xnn_create_convolution2d_nhwc_f32(
159+
pad_top, pad_right, pad_bottom, pad_left, filter_height, filter_width,
160+
stride_height, stride_width, dilation_height, dilation_width, groups,
161+
input_channels /* group_input_channels */,
162+
output_channels /* group_output_channels */,
163+
input_channels /* input_pixel_stride */,
164+
output_channels /* output_pixel_stride */, transposed_filter.data(),
165+
bias_buf, output_min, output_max, flags, &conv2d_op);
166+
if (status != xnn_status_success) {
167+
util::warn(
168+
"XNN status for xnn_create_convolution2d_nhwc_f32 is not successful. "
169+
"Got status %d. Use -c dbg to see XNN logs.",
170+
status);
171+
}
172+
173+
operator_cache.emplace(cache_key, conv2d_op);
174+
175+
associate_tensor_with_key(filter_id, cache_key,
176+
filter_operator_cache_key_map);
177+
if (bias_id != -1) {
178+
associate_tensor_with_key(bias_id, cache_key,
179+
bias_operator_cache_key_map);
180+
}
181+
182+
tfjs::backend::xnn_operator_count++;
183+
} else {
184+
conv2d_op = operator_cache_idx->second;
185+
}
186+
187+
xnn_status status = xnn_setup_convolution2d_nhwc_f32(
188+
conv2d_op, batch_size, input_height, input_width, x_buf, out_buf,
189+
nullptr /* thread pool */);
190+
if (status != xnn_status_success) {
191+
util::warn(
192+
"XNN status for xnn_setup_convolution2d_nhwc_f32 is not successful. "
193+
"Got status %d. Use -c dbg to see XNN logs.",
194+
status);
195+
}
196+
197+
xnn_run_operator(conv2d_op, nullptr /* thread pool */);
198+
}
199+
200+
} // namespace wasm
201+
} // namespace tfjs
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/* Copyright 2019 Google Inc. All Rights Reserved.
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
* ===========================================================================*/
14+
15+
#ifndef CONV2D_IMPL_H_
16+
#define CONV2D_IMPL_H_
17+
18+
namespace tfjs {
19+
namespace wasm {
20+
21+
void conv2d(const int x_id, const int batch_size, const int input_height,
22+
const int input_width, const int filter_id, const int filter_height,
23+
const int filter_width, const int bias_id, int pad_top,
24+
int pad_right, int pad_bottom, int pad_left, const int is_same_pad,
25+
const int dilation_height, const int dilation_width,
26+
const int stride_height, const int stride_width,
27+
const int input_channels, const int output_channels,
28+
const int out_id);
29+
} // namespace wasm
30+
} // namespace tfjs
31+
32+
#endif // CONV2D_IMPL_H_

0 commit comments

Comments
 (0)