forked from arrayfire/arrayfire
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathArray.hpp
More file actions
94 lines (81 loc) · 3.23 KB
/
Array.hpp
File metadata and controls
94 lines (81 loc) · 3.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
/*******************************************************
* Copyright (c) 2015, ArrayFire
* All rights reserved.
*
* This file is distributed under 3-clause BSD license.
* The complete license agreement can be obtained at:
* http://arrayfire.com/licenses/BSD-3-Clause
********************************************************/
#pragma once
#include <Param.hpp>
#include <jit/Node.hpp>
#include <platform.hpp>
#include <vector>
namespace cpu {
namespace kernel {

/// Evaluates several JIT trees in a single pass, writing the value of
/// output_nodes_[i] into the buffer of arrays[i].
///
/// @param arrays        Output buffers. The iteration space (dims) and the
///                      element layout (strides) are taken from arrays[0] and
///                      applied to every output, so all outputs are assumed
///                      to share arrays[0]'s dims and strides. An empty vector
///                      is a no-op.
/// @param output_nodes_ Root nodes of the trees to evaluate; must contain at
///                      least arrays.size() entries (it is indexed in step
///                      with arrays).
///
/// Every node reachable from the roots is collected via getNodesMap. If every
/// collected node reports isLinear() for the output dims, the elements are
/// processed as one flat run; otherwise they are walked per (x, y, z, w)
/// position using the strides of arrays[0]. Work proceeds in chunks of at most
/// jit::VECTOR_LENGTH elements: calc() is invoked on each collected node for
/// the chunk, then the first `lim` entries of each root node's m_val are copied
/// into the matching output.
template<typename T>
void evalMultiple(std::vector<Param<T>> arrays,
                  std::vector<common::Node_ptr> output_nodes_) {
    // Guard: indexing arrays[0] below would be undefined on an empty vector.
    if (arrays.empty()) { return; }

    af::dim4 odims = arrays[0].dims();
    af::dim4 ostrs = arrays[0].strides();

    common::Node_map_t nodes;
    std::vector<T *> ptrs;
    std::vector<TNode<T> *> output_nodes;
    std::vector<common::Node *> full_nodes;
    std::vector<common::Node_ids> ids;

    const int narrays = static_cast<int>(arrays.size());
    ptrs.reserve(arrays.size());
    output_nodes.reserve(arrays.size());
    for (int i = 0; i < narrays; i++) {
        ptrs.push_back(arrays[i].get());
        // Downcast the shared base-class pointer to the typed node that owns
        // the m_val buffer read below; static_cast is the correct named cast
        // for a base->derived conversion (reinterpret_cast is not).
        output_nodes.push_back(
            static_cast<TNode<T> *>(output_nodes_[i].get()));
        // Appends the nodes reachable from this root to full_nodes; `nodes`
        // and `ids` are bookkeeping for getNodesMap. NOTE(review): presumably
        // the map de-duplicates subtrees shared between roots and yields
        // children before parents -- confirm in jit/Node.hpp.
        output_nodes_[i]->getNodesMap(nodes, full_nodes, ids);
    }

    bool is_linear = true;
    for (auto node : full_nodes) { is_linear &= node->isLinear(odims.get()); }

    if (is_linear) {
        // Flat walk: chunk [i, i + lim) maps directly to offset i of each
        // output, which assumes the outputs are densely packed.
        // NOTE(review): indices are int here, as before; element counts above
        // INT_MAX are not supported by this path.
        const int num = static_cast<int>(odims.elements());
        for (int i = 0; i < num; i += jit::VECTOR_LENGTH) {
            const int lim = std::min(jit::VECTOR_LENGTH, num - i);
            for (auto *node : full_nodes) { node->calc(i, lim); }
            for (int n = 0; n < narrays; n++) {
                std::copy(output_nodes[n]->m_val.begin(),
                          output_nodes[n]->m_val.begin() + lim, ptrs[n] + i);
            }
        }
    } else {
        // Strided walk: dimension 0 is processed in chunks and assumed to be
        // unit-stride; dimensions 1..3 use arrays[0]'s strides.
        const int dim0 = static_cast<int>(odims[0]);  // loop-invariant
        const int dim1 = static_cast<int>(odims[1]);
        const int dim2 = static_cast<int>(odims[2]);
        const int dim3 = static_cast<int>(odims[3]);
        for (int w = 0; w < dim3; w++) {
            const dim_t offw = w * ostrs[3];
            for (int z = 0; z < dim2; z++) {
                const dim_t offz = z * ostrs[2] + offw;
                for (int y = 0; y < dim1; y++) {
                    const dim_t offy = y * ostrs[1] + offz;
                    for (int x = 0; x < dim0; x += jit::VECTOR_LENGTH) {
                        const int lim = std::min(jit::VECTOR_LENGTH, dim0 - x);
                        const dim_t id = x + offy;
                        for (auto *node : full_nodes) {
                            node->calc(x, y, z, w, lim);
                        }
                        for (int n = 0; n < narrays; n++) {
                            std::copy(output_nodes[n]->m_val.begin(),
                                      output_nodes[n]->m_val.begin() + lim,
                                      ptrs[n] + id);
                        }
                    }
                }
            }
        }
    }
}

/// Evaluates a single JIT tree `node` into the buffer of `arr`; a thin
/// wrapper that forwards one-element vectors to evalMultiple.
template<typename T>
void evalArray(Param<T> arr, common::Node_ptr node) {
    evalMultiple<T>({arr}, {node});
}

}  // namespace kernel
}  // namespace cpu