Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Adding airthmetic operations
  • Loading branch information
pavanky committed Feb 27, 2017
commit c9232386d3b574f5d9266f9b5c6c21e12874f157
1 change: 1 addition & 0 deletions arrayfire.lua
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ require('arrayfire.defines')
require('arrayfire.dim4')
require('arrayfire.util')
require('arrayfire.array')
require('arrayfire.arith')
require('arrayfire.device')

return af
188 changes: 188 additions & 0 deletions arrayfire/arith.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
require('arrayfire.lib')
require('arrayfire.defines')
require('arrayfire.array')
local ffi = require( "ffi" )

local funcs = {}

funcs[30] = [[
af_err af_add (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_sub (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_mul (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_div (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_lt (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_gt (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_le (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_ge (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_eq (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_neq (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_and (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_or (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_not (af_array *out, const af_array in);
af_err af_bitand (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_bitor (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_bitxor (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_bitshiftl(af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_bitshiftr(af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_cast (af_array *out, const af_array in, const af_dtype type);
af_err af_minof (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_maxof (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_rem (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_mod (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_abs (af_array *out, const af_array in);
af_err af_arg (af_array *out, const af_array in);
af_err af_sign (af_array *out, const af_array in);
af_err af_round (af_array *out, const af_array in);
af_err af_trunc (af_array *out, const af_array in);
af_err af_floor (af_array *out, const af_array in);
af_err af_ceil (af_array *out, const af_array in);
af_err af_hypot (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_sin (af_array *out, const af_array in);
af_err af_cos (af_array *out, const af_array in);
af_err af_tan (af_array *out, const af_array in);
af_err af_asin (af_array *out, const af_array in);
af_err af_acos (af_array *out, const af_array in);
af_err af_atan (af_array *out, const af_array in);
af_err af_atan2 (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_cplx2 (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_cplx (af_array *out, const af_array in);
af_err af_real (af_array *out, const af_array in);
af_err af_imag (af_array *out, const af_array in);
af_err af_conjg (af_array *out, const af_array in);
af_err af_sinh (af_array *out, const af_array in);
af_err af_cosh (af_array *out, const af_array in);
af_err af_tanh (af_array *out, const af_array in);
af_err af_asinh (af_array *out, const af_array in);
af_err af_acosh (af_array *out, const af_array in);
af_err af_atanh (af_array *out, const af_array in);
af_err af_root (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_pow (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_pow2 (af_array *out, const af_array in);
af_err af_exp (af_array *out, const af_array in);

af_err af_expm1 (af_array *out, const af_array in);
af_err af_erf (af_array *out, const af_array in);
af_err af_erfc (af_array *out, const af_array in);
af_err af_log (af_array *out, const af_array in);
af_err af_log1p (af_array *out, const af_array in);
af_err af_log10 (af_array *out, const af_array in);
af_err af_log2 (af_array *out, const af_array in);
af_err af_sqrt (af_array *out, const af_array in);
af_err af_cbrt (af_array *out, const af_array in);
af_err af_factorial (af_array *out, const af_array in);
af_err af_tgamma (af_array *out, const af_array in);
af_err af_lgamma (af_array *out, const af_array in);
af_err af_iszero (af_array *out, const af_array in);
af_err af_isinf (af_array *out, const af_array in);
af_err af_isnan (af_array *out, const af_array in);
]]

funcs[31] = [[
af_err af_sigmoid (af_array *out, const af_array in);
]]

funcs[34] = [[
af_err af_clamp(af_array *out, const af_array in,
const af_array lo, const af_array hi, const bool batch);
]]

af.lib.cdef(funcs)
local c_array_p = af.ffi.c_array_p
local init = af.Array.init

local binaryFuncs = {
'add',
'sub',
'mul',
'div',
'lt',
'gt',
'le',
'ge',
'eq',
'neq',
'and',
'or',
'bitand',
'bitor',
'bitxor',
'bitshiftl',
'bitshiftr',
'minof',
'maxof',
'rem',
'mod',
'hypot',
'atan2',
'cplx2',
'root',
'pow',
}


for _, func in ipairs(binaryFuncs) do
af[func] = function(lhs, rhs, batch)
-- TODO: add support for numbers
-- TODO: add support for batch mode
local res = c_array_p()
af.clib['af_' .. func](res, lhs:get(), rhs:get(), batch and true or false)
return init(res[0])
end
end

local unaryFuncs = {
'abs',
'arg',
'sign',
'round',
'trunc',
'floor',
'ceil',
'sin',
'cos',
'tan',
'asin',
'acos',
'atan',
'cplx',
'real',
'imag',
'conjg',
'sinh',
'cosh',
'tanh',
'asinh',
'acosh',
'atanh',
'pow2',
'exp',
'expm1',
'erf',
'erfc',
'log',
'log1p',
'log10',
'log2',
'sqrt',
'cbrt',
'factorial',
'tgamma',
'lgamma',
'iszero',
'isinf',
'isnan'
}

for _, func in ipairs(unaryFuncs) do
af[func] = function(input)
local res = c_array_p()
af.clib['af_' .. func](res, input:get())
return init(res[0])
end
end

af.cast = function(input, rtype)
local res = c_array_p()
af.clib.af_cast(res, input:get(), rtype)
return init(res[0])
end
28 changes: 16 additions & 12 deletions arrayfire/array.lua
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,19 @@ local c_uint_t = af.ffi.c_uint_t
local c_ptr_t = af.ffi.c_ptr_t
local Dim4 = af.Dim4

function release_array(ptr)
local res = af.clib.af_release_array(ptr)
-- TODO: Error handling logic
end

local c_array_p = function(ptr)
local arr_ptr = ffi.new('void *[1]', ptr)
arr_ptr[0] = ffi.gc(arr_ptr[0], af.clib.af_release_array)
return arr_ptr
end

local init = function(ptr)
local self = setmetatable({}, Array)
self._array = ptr
self._ptr = ffi.gc(ptr, release_array)
return self
end

Expand Down Expand Up @@ -117,51 +121,51 @@ Array.__tostring = function(self)
end

Array.get = function(self)
return self._array
return self._ptr
end

-- TODO: implement Array.write

Array.copy = function(self)
local res = c_array_p()
af.clib.af_copy_array(res, self._array)
af.clib.af_copy_array(res, self:get())
return Array.init(res[0])
end

Array.softCopy = function(self)
local res = c_array_p()
af.clib.af_copy_array(res, self._array)
af.clib.af_copy_array(res, self:get())
return Array.init(res[0])
end

Array.elements = function(self)
local res = c_ptr_t('dim_t')
af.clib.af_get_elements(res, self._array)
af.clib.af_get_elements(res, self:get())
return tonumber(res[0])
end

Array.type = function(self)
local res = c_ptr_t('af_dtype')
af.clib.af_get_type(res, self._array)
af.clib.af_get_type(res, self:get())
return tonumber(res[0])
end

Array.typeName = function(self)
local res = c_ptr_t('af_dtype')
af.clib.af_get_type(res, self._array)
af.clib.af_get_type(res, self:get())
return af.dtype_names[tonumber(res[0])]
end

Array.dims = function(self)
local res = c_dim4_t()
af.clib.af_get_dims(res + 0, res + 1, res + 2, res + 3, self._array)
af.clib.af_get_dims(res + 0, res + 1, res + 2, res + 3, self:get())
return Dim4(tonumber(res[0]), tonumber(res[1]),
tonumber(res[2]), tonumber(res[3]))
end

Array.numdims = function(self)
local res = c_ptr_t('unsigned int')
af.clib.af_get_numdims(res, self._array)
af.clib.af_get_numdims(res, self:get())
return tonumber(res[0])
end

Expand All @@ -184,13 +188,13 @@ local funcs = {
for name, cname in pairs(funcs) do
Array[name] = function(self)
local res = c_ptr_t('bool')
af.clib['af_' .. cname](res, self._array)
af.clib['af_' .. cname](res, self:get())
return res[0]
end
end

Array.eval = function(self)
af.clib.af_eval(self._array)
af.clib.af_eval(self:get())
end

-- Useful aliases
Expand Down
2 changes: 1 addition & 1 deletion arrayfire/util.lua
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@ funcs[34] = [[
af.lib.cdef(funcs)

af.print = function(arr)
af.clib.af_print_array_gen(ffi.cast("char *", "ArrayFire Array"), arr._array, 4)
af.clib.af_print_array_gen(ffi.cast("char *", "ArrayFire Array"), arr:get(), 4)
end
1 change: 1 addition & 0 deletions rocks/arrayfire-scm-1.rockspec
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,6 @@ build = {
["arrayfire.defines"] = "arrayfire/defines.lua",
["arrayfire.device"] = "arrayfire/device.lua",
["arrayfire.dim4"] = "arrayfire/dim4.lua",
["arrayfire.arith"] = "arrayfire/arith.lua",
},
}