-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathmethods.lua
More file actions
141 lines (105 loc) · 2.18 KB
/
methods.lua
File metadata and controls
141 lines (105 loc) · 2.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
--- Array methods.
-- Modules --
local af = require("arrayfire_lib")
-- Exports --
local M = {}
-- See also: https://github.com/arrayfire/arrayfire/blob/devel/src/api/cpp/array.cpp
-- --
-- Scratch table handed to getDims by the `dims` method when a single dimension
-- is requested; being a module-level local, it is shared across all calls
-- instead of being allocated per call.
-- NOTE(review): presumably getDims fills and returns this table — confirm.
local Dims = {}
--
--- Install the array methods into an array metatable.
-- Every entry of the method table below is copied into `meta` by key, so array
-- objects using that metatable can call e.g. `arr:elements()` or `arr:T()`.
-- @tparam table array_module Provides `Call`, `CallWrap`, `GetLib`,
-- `GetHandle` and `SetHandle`.
-- @tparam table meta Table that receives the methods.
function M.Add (array_module, meta)
	local Call = array_module.Call
	local CallWrap = array_module.CallWrap
	local GetLib = array_module.GetLib

	--- Build a method that forwards an array's handle to a C query.
	-- @string name C function name without its "af_" prefix.
	-- @treturn function Method `(arr)` returning whatever `Call` yields.
	local function Get (name)
		name = "af_" .. name

		return function(arr)
			return Call(name, arr:get())
		end
	end

	-- Element sizes in bytes, keyed by dtype constant. Each type name encodes its
	-- width in bits (8 bits to a byte); complex types hold two components of that
	-- width, hence the division by 4 instead of 8. f16 (half precision) is listed
	-- so such arrays are not sized with the 4-byte fallback used in `bytes`.
	local SizeOf = {}

	for prefix, size in ("f32 f64 s32 u32 s64 u64 u8 b8 c32 c64 s16 u16 f16"):gmatch "(%a)(%d+)" do
		local k = af[prefix .. size]

		if k then -- account for earlier versions, which lack some types
			SizeOf[k] = tonumber(size) / (prefix == "c" and 4 or 8)
		end
	end

	for k, v in pairs{
		--- Cast to another type; `atype` is looked up in `af` (presumably a
		-- type name such as "f32").
		as = function(arr, atype)
			return CallWrap("af_cast", arr:get(), af[atype])
		end,

		--- Total size in bytes: element count times element size (4 if the type
		-- is not in the size table).
		bytes = function(arr)
			local ha = arr:get()
			local n, dtype = Call("af_get_elements", ha), Call("af_get_type", ha)

			return n * (SizeOf[dtype] or 4)
		end,

		--- Deep copy of the array.
		copy = function(arr)
			return CallWrap("af_copy_array", arr:get())
		end,

		--- With `i`, the size of dimension `i` (0-based, as in the C++ API);
		-- without it, whatever getDims returns for the whole array.
		dims = function(arr, i)
			if i then
				return GetLib().getDims(arr, Dims)[i + 1]
			else
				return GetLib().getDims(arr)
			end
		end,

		elements = Get("get_elements"),

		--- Force evaluation of the array.
		eval = function(arr)
			Call("af_eval", arr:get())
		end,

		get = array_module.GetHandle,

		isbool = Get("is_bool"),

		iscolumn = Get("is_column"),

		iscomplex = Get("is_complex"),

		isdouble = Get("is_double"),

		isempty = Get("is_empty"),

		isfloating = Get("is_floating"),

		isinteger = Get("is_integer"),

		isreal = Get("is_real"),

		-- The C API spells this one af_is_realfloating (no underscore between
		-- "real" and "floating").
		isrealfloating = Get("is_realfloating"),

		isrow = Get("is_row"),

		isscalar = Get("is_scalar"),

		issingle = Get("is_single"),

		isvector = Get("is_vector"),

		--- Transpose with the second argument `true` (conjugating, mirroring the
		-- C++ `array::H`).
		H = function(arr)
			return GetLib().transpose(arr, true)
		end,

		--- Number of dimensions, via the library's numDims.
		numdims = function(arr)
			return GetLib().numDims(arr)
		end,

		set = array_module.SetHandle,

		--- Plain transpose.
		T = function(arr)
			return GetLib().transpose(arr)
		end,

		type = Get("get_type")
	} do
		meta[k] = v
	end
end
-- Export the module. Callers install the array methods with
-- M.Add(array_module, meta).
return M