forked from arrayfire/arrayfire
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlookup.cpp
More file actions
71 lines (63 loc) · 2.88 KB
/
lookup.cpp
File metadata and controls
71 lines (63 loc) · 2.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
/*******************************************************
* Copyright (c) 2014, ArrayFire
* All rights reserved.
*
* This file is distributed under 3-clause BSD license.
* The complete license agreement can be obtained at:
* http://arrayfire.com/licenses/BSD-3-Clause
********************************************************/
#include <kernel/lookup.hpp>
#include <lookup.hpp>
#include <common/half.hpp>
#include <platform.hpp>
#include <queue.hpp>
#include <cstdlib>
using common::half;
namespace cpu {

/// @brief Gathers values from `input` along axis `dim` using `indices`.
///
/// The result keeps `input`'s extent on every axis except `dim`. The extent
/// along `dim` becomes the total element count of `indices`. Only axes 0..3
/// are considered. If `dim` is 4 or larger, no axis is replaced and the result
/// has `input`'s shape (range checking presumably happens in the caller —
/// TODO confirm).
///
/// The output array is allocated here. Filling it is delegated to
/// `kernel::lookup`, which is passed to the backend queue via `enqueue` rather
/// than invoked directly.
///
/// @param input   array to read values from
/// @param indices selector array (presumably positions along `dim`); its
///                element count sets the output extent along `dim`
/// @param dim     axis along which the lookup is performed
/// @return newly created array that the queued kernel writes into
template<typename in_t, typename idx_t>
Array<in_t> lookup(const Array<in_t> &input, const Array<idx_t> &indices,
                   const unsigned dim) {
    const dim4 &inDims = input.dims();

    // Begin with an all-ones shape, then fill in each of the four axes.
    dim4 outDims(1);
    for (int axis = 0; axis < 4; ++axis) {
        if (axis == static_cast<int>(dim)) {
            outDims[axis] = indices.elements();
        } else {
            outDims[axis] = inDims[axis];
        }
    }

    Array<in_t> result = createEmptyArray<in_t>(outDims);
    // The kernel receives the output array and writes the gathered values.
    getQueue().enqueue(kernel::lookup<in_t, idx_t>, result, input, indices,
                       dim);
    return result;
}

// Explicitly instantiates lookup for value type T and index type IDX.
#define INSTANTIATE_LOOKUP(T, IDX)                                          \
    template Array<T> lookup<T, IDX>(const Array<T> &, const Array<IDX> &, \
                                     const unsigned);

// Explicitly instantiates lookup for value type T with every supported
// index type.
#define INSTANTIATE(T)             \
    INSTANTIATE_LOOKUP(T, float)   \
    INSTANTIATE_LOOKUP(T, double)  \
    INSTANTIATE_LOOKUP(T, int)     \
    INSTANTIATE_LOOKUP(T, unsigned) \
    INSTANTIATE_LOOKUP(T, short)   \
    INSTANTIATE_LOOKUP(T, ushort)  \
    INSTANTIATE_LOOKUP(T, intl)    \
    INSTANTIATE_LOOKUP(T, uintl)   \
    INSTANTIATE_LOOKUP(T, uchar)   \
    INSTANTIATE_LOOKUP(T, half)

INSTANTIATE(float);
INSTANTIATE(cfloat);
INSTANTIATE(double);
INSTANTIATE(cdouble);
INSTANTIATE(int);
INSTANTIATE(unsigned);
INSTANTIATE(intl);
INSTANTIATE(uintl);
INSTANTIATE(uchar);
INSTANTIATE(char);
INSTANTIATE(ushort);
INSTANTIATE(short);
INSTANTIATE(half);

#undef INSTANTIATE
#undef INSTANTIATE_LOOKUP

}  // namespace cpu