forked from arrayfire/arrayfire-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutil.py
More file actions
80 lines (64 loc) · 2.09 KB
/
util.py
File metadata and controls
80 lines (64 loc) · 2.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#######################################################
# Copyright (c) 2015, ArrayFire
# All rights reserved.
#
# This file is distributed under 3-clause BSD license.
# The complete license agreement can be obtained at:
# http://arrayfire.com/licenses/BSD-3-Clause
########################################################
from .library import *
import numbers
def dim4(d0=1, d1=1, d2=1, d3=1):
    """Pack up to four dimension values into a ctypes ``c_longlong * 4`` array.

    Each argument fills the slot at the same position. An argument that is
    ``None`` leaves its slot at 1.

    Side effect: binds the ``c_longlong * 4`` array type to ``ct.c_dim4`` on
    the ctypes module on every call. This is presumably so other modules can
    name the type; verify before removing.
    """
    ct.c_dim4 = ct.c_longlong * 4
    # None means "unspecified": fall back to the neutral extent 1.
    filled = [1 if extent is None else extent for extent in (d0, d1, d2, d3)]
    return ct.c_dim4(*filled)
def is_number(a):
    """Return True when *a* is an instance of ``numbers.Number``.

    This accepts int, float, complex and bool (bool subclasses int), plus any
    other type registered with the ``numbers`` ABCs. Strings, None and
    containers are rejected.
    """
    number_types = (numbers.Number,)
    return isinstance(a, number_types)
def dim4_tuple(dims, default=1):
    """Pad a tuple of up to four dimension values out to exactly four entries.

    Parameters
    ----------
    dims : tuple
        Leading dimension values, at most four of them. Entries are copied
        as-is, with no conversion.
    default : number or None, optional
        Value used for every position that ``dims`` does not supply. ``None``
        is accepted and pads with ``None``.

    Returns
    -------
    tuple
        A plain 4-tuple: the entries of ``dims`` followed by ``default``
        padding.

    Raises
    ------
    TypeError
        If ``dims`` is not a tuple, or ``default`` is neither None nor a
        number.
    ValueError
        If ``dims`` has more than four entries.
    """
    # Explicit checks instead of ``assert``, so validation survives ``python -O``.
    if not isinstance(dims, tuple):
        raise TypeError("dims must be a tuple, got %s" % type(dims).__name__)
    if len(dims) > 4:
        raise ValueError("at most 4 dimensions are supported, got %d" % len(dims))
    if default is not None and not is_number(default):
        raise TypeError("default must be a number or None, got %r" % (default,))
    # Build the padding unconditionally so default=None also yields a 4-tuple.
    return tuple(dims) + (default,) * (4 - len(dims))
def to_str(c_str):
    """Decode a ctypes C string into a native Python string.

    Parameters
    ----------
    c_str : ctypes.c_char_p or ctypes char array
        Any ctypes object whose ``.value`` is ``bytes``, or ``None`` for a
        NULL ``c_char_p``. The bytes are decoded as UTF-8.

    Returns
    -------
    str
        The decoded text, or ``''`` when the pointer is NULL.
    """
    raw = c_str.value
    # A NULL char* has .value None; calling .decode on it would raise
    # AttributeError (e.g. while safe_call is reporting a library error).
    if raw is None:
        return ''
    # str() is kept from the original: on Python 2 it turns the decoded
    # unicode object into a native str; on Python 3 it is a no-op.
    return str(raw.decode('utf-8'))
def safe_call(af_error):
    """Raise RuntimeError if an ArrayFire C call did not return success.

    Parameters
    ----------
    af_error : int
        The status code returned by a ``clib.af_*`` call. It is compared
        against ``AF_SUCCESS.value``.

    Raises
    ------
    RuntimeError
        With args ``(message, af_error)``. ``message`` is the text retrieved
        via ``clib.af_get_last_error``; if no text is available, a fallback
        naming the numeric code is used instead.
    """
    if af_error == AF_SUCCESS.value:
        return
    err_str = ct.c_char_p(0)
    err_len = ct.c_longlong(0)
    # NOTE(review): the message buffer is filled in by the library; confirm
    # whether the ArrayFire version in use expects it to be released
    # (e.g. af_free_host) to avoid a small leak per error.
    clib.af_get_last_error(ct.pointer(err_str), ct.pointer(err_len))
    # Guard against a NULL/empty message so the real error code is still
    # reported instead of being masked by a decoding failure.
    if err_str.value:
        message = to_str(err_str)
    else:
        message = "ArrayFire call failed with error code %s" % (af_error,)
    raise RuntimeError(message, af_error)
def get_version():
    """Query the ArrayFire library version through ``clib.af_get_version``.

    Returns
    -------
    tuple
        ``(major, minor, patch)`` as ``ctypes.c_int`` objects (not Python
        ints); read ``.value`` on each to get the numbers.

    Raises
    ------
    RuntimeError
        Propagated from ``safe_call`` when the library call reports failure.
    """
    version = [ct.c_int(0) for _ in range(3)]
    safe_call(clib.af_get_version(*[ct.pointer(part) for part in version]))
    return tuple(version)
# Single-letter typecode -> ArrayFire dtype constant. The letters are the same
# ones Python's ``array`` module uses.
# NOTE(review): in the ``array`` module, 'l'/'L' are C long, which is 4 bytes
# on Windows, and 'b' is a signed char rather than a boolean. Confirm the
# s64/u64/b8 mappings match what callers actually pass.
to_dtype = dict(f=f32, d=f64, b=b8, B=u8, i=s32, I=u32, l=s64, L=u64)
# ArrayFire dtype value -> single-letter typecode (the inverse of to_dtype).
to_typecode = {dtype.value: code for dtype, code in (
    (f32, 'f'),
    (f64, 'd'),
    (b8, 'b'),
    (u8, 'B'),
    (s32, 'i'),
    (u32, 'I'),
    (s64, 'l'),
    (u64, 'L'),
)}
# ArrayFire dtype value -> ctypes scalar type for one element.
# b8 maps to ct.c_char, so its elements read back as 1-byte bytes objects
# rather than ints.
to_c_type = {dtype.value: c_type for dtype, c_type in (
    (f32, ct.c_float),
    (f64, ct.c_double),
    (b8, ct.c_char),
    (u8, ct.c_ubyte),
    (s32, ct.c_int),
    (u32, ct.c_uint),
    (s64, ct.c_longlong),
    (u64, ct.c_ulonglong),
)}