-
Notifications
You must be signed in to change notification settings - Fork 250
Expand file tree
/
Copy pathRsqrt.hpp
More file actions
32 lines (28 loc) · 816 Bytes
/
Rsqrt.hpp
File metadata and controls
32 lines (28 loc) · 816 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
#ifndef UTENSOR_RSQRT_H
#define UTENSOR_RSQRT_H
#include <cmath>
#include "uTensor/core/operatorBase.hpp"
#include "uTensor/core/tensor.hpp"
#include "uTensor/core/types.hpp"
namespace uTensor {
namespace ReferenceOperators {
/**
 * Reference element-wise reciprocal square root operator.
 *
 * Reads the tensor bound to the `input` slot and writes
 * output[i] = 1 / sqrt(input[i]) into the tensor bound to the `output` slot.
 *
 * @tparam Tin element type used both to read the input and to store the result.
 */
template <typename Tin>
class RsqrtOperator : public OperatorInterface<1, 1> {
 public:
  enum names_in : uint8_t { input };
  enum names_out : uint8_t { output };

 protected:
  /**
   * Computes output[i] = 1 / sqrt(input[i]) for every element.
   *
   * The loop runs over the OUTPUT tensor's element count and reads the input at
   * the same linear index.
   * NOTE(review): assumes the input holds at least as many elements as the
   * output (e.g. identical shapes) -- TODO confirm where shapes are set up.
   */
  void compute() {
    Tensor &inputT = inputs[input].tensor();
    Tensor &outputT = outputs[output].tensor();
    // Loop-invariant: query the element count once rather than per iteration.
    const auto num_elems = outputT->num_elems();
    for (uint32_t i = 0; i < num_elems; ++i) {
      const Tin v = static_cast<Tin>(inputT(i));
      // Keep the root in the type std::sqrt returns: the same floating type for
      // float/double Tin (so results are unchanged there), and double for
      // integral Tin. Storing it back into an integral Tin first would truncate
      // the root (sqrt(2) -> 1) and turn an input of 0 into an integer division
      // by zero.
      const auto root = std::sqrt(v);
      // NOTE(review): for integral Tin, an input of 0 (or a negative value) still
      // has no representable result; converting inf/NaN back to Tin is undefined.
      outputT(i) = static_cast<Tin>(1 / root);
    }
  }
};
} // namespace ReferenceOperators
} // namespace uTensor
#endif // UTENSOR_RSQRT_H