|
| 1 | +/* Copyright 2019 Google Inc. All Rights Reserved. |
| 2 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 3 | + * you may not use this file except in compliance with the License. |
| 4 | + * You may obtain a copy of the License at |
| 5 | + * |
| 6 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 7 | + * |
| 8 | + * Unless required by applicable law or agreed to in writing, software |
| 9 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 10 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 11 | + * See the License for the specific language governing permissions and |
| 12 | + * limitations under the License. |
| 13 | + * ===========================================================================*/ |
| 14 | + |
| 15 | +#ifdef __EMSCRIPTEN__ |
| 16 | +#include <emscripten.h> |
| 17 | +#endif |
| 18 | + |
| 19 | +#include "src/cc/conv2d_impl.h" |
| 20 | + |
| 21 | +#include <xnnpack.h> |
| 22 | +#include <array> |
| 23 | +#include <cmath> |
| 24 | +#include <limits> |
| 25 | +#include <map> |
| 26 | +#include <memory> |
| 27 | +#include <unordered_map> |
| 28 | +#include <utility> |
| 29 | +#include <vector> |
| 30 | + |
| 31 | +#include "src/cc/backend.h" |
| 32 | +#include "src/cc/transpose_impl.h" |
| 33 | +#include "src/cc/util.h" |
| 34 | + |
| 35 | +namespace { |
| 36 | +// These integer values are keys to creating the conv2d operator. We use |
| 37 | +// std::array instead of a vanilla array as it implements the compare operator |
| 38 | +// needed for std::map. |
| 39 | +typedef std::array<int, 16> OperatorCacheKey; |
| 40 | + |
| 41 | +// The operator cache maps the cache key to the xnn_operator_t instantiated for |
| 42 | +// this set of arguments to the xnn_operator. |
| 43 | +std::map<OperatorCacheKey, xnn_operator_t> operator_cache; |
| 44 | + |
| 45 | +// Maps a filter id to a list of operator cache keys that this filter belongs |
| 46 | +// to. |
| 47 | +std::unordered_map<int, std::vector<OperatorCacheKey>> |
| 48 | + filter_operator_cache_key_map; |
| 49 | + |
| 50 | +// Maps a bias id to a list of operator cache keys that this filter belongs |
| 51 | +// to. |
| 52 | +std::unordered_map<int, std::vector<OperatorCacheKey>> |
| 53 | + bias_operator_cache_key_map; |
| 54 | + |
| 55 | +void erase_from_cache(const int tensor_id, |
| 56 | + std::unordered_map<int, std::vector<OperatorCacheKey>>& |
| 57 | + operator_cache_key_map) { |
| 58 | + auto operator_cache_keys_idx = operator_cache_key_map.find(tensor_id); |
| 59 | + if (operator_cache_keys_idx != operator_cache_key_map.end()) { |
| 60 | + std::vector<OperatorCacheKey> operator_cache_keys = |
| 61 | + operator_cache_keys_idx->second; |
| 62 | + for (auto& operator_cache_key : operator_cache_keys) { |
| 63 | + auto operator_cache_key_idx = operator_cache.find(operator_cache_key); |
| 64 | + if (operator_cache_key_idx != operator_cache.end()) { |
| 65 | + auto& conv2d_op = operator_cache_key_idx->second; |
| 66 | + |
| 67 | + xnn_delete_operator(conv2d_op); |
| 68 | + tfjs::backend::xnn_operator_count--; |
| 69 | + |
| 70 | + operator_cache.erase(operator_cache_key); |
| 71 | + } |
| 72 | + } |
| 73 | + operator_cache_key_map.erase(tensor_id); |
| 74 | + } |
| 75 | +} |
| 76 | + |
| 77 | +void delete_xnn_operators(int tensor_id) { |
| 78 | + erase_from_cache(tensor_id, filter_operator_cache_key_map); |
| 79 | + erase_from_cache(tensor_id, bias_operator_cache_key_map); |
| 80 | +} |
| 81 | + |
| 82 | +void associate_tensor_with_key( |
| 83 | + const int tensor_id, const OperatorCacheKey& cache_key, |
| 84 | + std::unordered_map<int, std::vector<OperatorCacheKey>>& |
| 85 | + operator_cache_key_map) { |
| 86 | + auto cache_keys_idx = operator_cache_key_map.find(tensor_id); |
| 87 | + if (cache_keys_idx == operator_cache_key_map.end()) { |
| 88 | + std::vector<OperatorCacheKey> cache_keys = {cache_key}; |
| 89 | + operator_cache_key_map.emplace(tensor_id, std::move(cache_keys)); |
| 90 | + tfjs::backend::register_disposal_callback(tensor_id, *delete_xnn_operators); |
| 91 | + |
| 92 | + } else { |
| 93 | + auto& cache_keys = operator_cache_key_map.at(tensor_id); |
| 94 | + cache_keys.emplace_back(cache_key); |
| 95 | + } |
| 96 | +} |
| 97 | + |
| 98 | +} // namespace |
| 99 | + |
| 100 | +namespace tfjs { |
| 101 | +namespace wasm { |
| 102 | + |
| 103 | +void conv2d(const int x_id, const int batch_size, const int input_height, |
| 104 | + const int input_width, const int filter_id, const int filter_height, |
| 105 | + const int filter_width, const int bias_id, int pad_top, |
| 106 | + int pad_right, int pad_bottom, int pad_left, const int is_same_pad, |
| 107 | + const int dilation_height, const int dilation_width, |
| 108 | + const int stride_height, const int stride_width, |
| 109 | + const int input_channels, const int output_channels, |
| 110 | + const int out_id) { |
| 111 | + auto& x_info = backend::get_tensor_info(x_id); |
| 112 | + auto& filter_info = backend::get_tensor_info(filter_id); |
| 113 | + auto& out_info = backend::get_tensor_info_out(out_id); |
| 114 | + |
| 115 | + const float* x_buf = x_info.f32(); |
| 116 | + const float* filter_buf = filter_info.f32(); |
| 117 | + const float* bias_buf = nullptr; |
| 118 | + if (bias_id != -1) { |
| 119 | + bias_buf = backend::get_tensor_info_out(bias_id).f32(); |
| 120 | + } |
| 121 | + float* out_buf = out_info.f32_write(); |
| 122 | + |
| 123 | + xnn_operator_t conv2d_op = nullptr; |
| 124 | + |
| 125 | + int flags = 0; |
| 126 | + if (is_same_pad) { |
| 127 | + pad_top = 0, pad_right = 0, pad_bottom = 0, pad_left = 0; |
| 128 | + flags = XNN_FLAG_TENSORFLOW_SAME_PADDING; |
| 129 | + } |
| 130 | + |
| 131 | + const int groups = 1; |
| 132 | + |
| 133 | + OperatorCacheKey cache_key = { |
| 134 | + pad_top, pad_right, pad_bottom, pad_left, |
| 135 | + filter_height, filter_width, stride_height, stride_width, |
| 136 | + dilation_height, dilation_width, groups, input_channels, |
| 137 | + output_channels, filter_id, bias_id, flags}; |
| 138 | + |
| 139 | + auto operator_cache_idx = operator_cache.find(cache_key); |
| 140 | + if (operator_cache_idx == operator_cache.end()) { |
| 141 | + float output_min = -std::numeric_limits<float>::infinity(); |
| 142 | + float output_max = std::numeric_limits<float>::infinity(); |
| 143 | + |
| 144 | + // xnn pack expects weights layed out like: |
| 145 | + // [output_channels, filter_height, filter_width, input_channels] |
| 146 | + // TensorFlow has weights layed out like: |
| 147 | + // [filter_height, filter_width, input_channels, output_channels] |
| 148 | + // This can be transposed with a 2d transpose to move output_channels to the |
| 149 | + // outer most dimension. |
| 150 | + std::vector<float> transposed_filter(filter_info.size); |
| 151 | + |
| 152 | + const std::vector<int> filter_shape = { |
| 153 | + filter_height * filter_width * input_channels, output_channels}; |
| 154 | + const std::vector<int> perm = {1, 0}; |
| 155 | + tfjs::wasm::transpose(filter_buf, filter_shape, perm, |
| 156 | + transposed_filter.data()); |
| 157 | + |
| 158 | + xnn_status status = xnn_create_convolution2d_nhwc_f32( |
| 159 | + pad_top, pad_right, pad_bottom, pad_left, filter_height, filter_width, |
| 160 | + stride_height, stride_width, dilation_height, dilation_width, groups, |
| 161 | + input_channels /* group_input_channels */, |
| 162 | + output_channels /* group_output_channels */, |
| 163 | + input_channels /* input_pixel_stride */, |
| 164 | + output_channels /* output_pixel_stride */, transposed_filter.data(), |
| 165 | + bias_buf, output_min, output_max, flags, &conv2d_op); |
| 166 | + if (status != xnn_status_success) { |
| 167 | + util::warn( |
| 168 | + "XNN status for xnn_create_convolution2d_nhwc_f32 is not successful. " |
| 169 | + "Got status %d. Use -c dbg to see XNN logs.", |
| 170 | + status); |
| 171 | + } |
| 172 | + |
| 173 | + operator_cache.emplace(cache_key, conv2d_op); |
| 174 | + |
| 175 | + associate_tensor_with_key(filter_id, cache_key, |
| 176 | + filter_operator_cache_key_map); |
| 177 | + if (bias_id != -1) { |
| 178 | + associate_tensor_with_key(bias_id, cache_key, |
| 179 | + bias_operator_cache_key_map); |
| 180 | + } |
| 181 | + |
| 182 | + tfjs::backend::xnn_operator_count++; |
| 183 | + } else { |
| 184 | + conv2d_op = operator_cache_idx->second; |
| 185 | + } |
| 186 | + |
| 187 | + xnn_status status = xnn_setup_convolution2d_nhwc_f32( |
| 188 | + conv2d_op, batch_size, input_height, input_width, x_buf, out_buf, |
| 189 | + nullptr /* thread pool */); |
| 190 | + if (status != xnn_status_success) { |
| 191 | + util::warn( |
| 192 | + "XNN status for xnn_setup_convolution2d_nhwc_f32 is not successful. " |
| 193 | + "Got status %d. Use -c dbg to see XNN logs.", |
| 194 | + status); |
| 195 | + } |
| 196 | + |
| 197 | + xnn_run_operator(conv2d_op, nullptr /* thread pool */); |
| 198 | +} |
| 199 | + |
| 200 | +} // namespace wasm |
| 201 | +} // namespace tfjs |
0 commit comments