forked from arrayfire/arrayfire
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsolve.cpp
More file actions
146 lines (118 loc) · 4.8 KB
/
solve.cpp
File metadata and controls
146 lines (118 loc) · 4.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
/*******************************************************
* Copyright (c) 2014, ArrayFire
* All rights reserved.
*
* This file is distributed under 3-clause BSD license.
* The complete license agreement can be obtained at:
* http://arrayfire.com/licenses/BSD-3-Clause
********************************************************/
#include <backend.hpp>
#include <common/ArrayInfo.hpp>
#include <common/err_common.hpp>
#include <handle.hpp>
#include <solve.hpp>
#include <af/array.h>
#include <af/defines.h>
#include <af/lapack.h>
using af::dim4;
using detail::Array;
using detail::cdouble;
using detail::cfloat;
using detail::solveLU;
// Typed helper behind af_solve: views both handles as Array<T>, forwards them
// and the matrix-property flags to the backend solve, and wraps the resulting
// array in a freshly created af_array handle.
template<typename T>
static inline af_array solve(const af_array a, const af_array b,
                             const af_mat_prop options) {
    const auto& lhs = getArray<T>(a);
    const auto& rhs = getArray<T>(b);
    return getHandle(solve<T>(lhs, rhs, options));
}
/**
 * C API: solve the linear system A * X = B for X.
 *
 * @param[out] out     receives the handle of the solution. When a or b has
 *                     zero dimensions, the function instead returns the
 *                     result of af_create_handle(out, 0, nullptr, <a's type>).
 * @param[in]  a       system matrix: floating-point (real or complex) with
 *                     at most 2 dimensions.
 * @param[in]  b       right-hand side: same type as a, same number of rows
 *                     as a, at most 2 dimensions.
 * @param[in]  options AF_MAT_NONE for a general solve, or a value containing
 *                     AF_MAT_LOWER or AF_MAT_UPPER for a triangular solve
 *                     (a must then be square). Any other non-NONE value, or
 *                     a triangular flag combined with AF_MAT_TRANS or
 *                     AF_MAT_CTRANS, is rejected with AF_ERR_NOT_SUPPORTED.
 * @return AF_SUCCESS on success. Validation failures inside the try block
 *         are presumably thrown and mapped to an af_err by CATCHALL (defined
 *         outside this file — confirm in common/err_common.hpp).
 */
af_err af_solve(af_array* out, const af_array a, const af_array b,
const af_mat_prop options) {
try {
const ArrayInfo& a_info = getInfo(a);
const ArrayInfo& b_info = getInfo(b);
// Batched (more than 2-D) inputs are not supported.
if (a_info.ndims() > 2 || b_info.ndims() > 2) {
AF_ERROR("solve can not be used in batch mode", AF_ERR_BATCH);
}
af_dtype a_type = a_info.getType();
af_dtype b_type = b_info.getType();
dim4 adims = a_info.dims();
dim4 bdims = b_info.dims();
ARG_ASSERT(1, a_info.isFloating()); // Only floating and complex types
ARG_ASSERT(2, b_info.isFloating()); // Only floating and complex types
TYPE_ASSERT(a_type == b_type);
// b must match a in the row count and in dimensions 2 and 3.
DIM_ASSERT(1, bdims[0] == adims[0]);
DIM_ASSERT(1, bdims[2] == adims[2]);
DIM_ASSERT(1, bdims[3] == adims[3]);
// The checks above run even for empty inputs; an empty a or b then
// yields an empty result of a's type.
if (a_info.ndims() == 0 || b_info.ndims() == 0) {
return af_create_handle(out, 0, nullptr, a_type);
}
// A lower/upper flag selects the triangular solve path; no other
// matrix property is supported yet.
bool is_triangle_solve =
(options & AF_MAT_LOWER) || (options & AF_MAT_UPPER);
if (options != AF_MAT_NONE && !is_triangle_solve) {
AF_ERROR("Using this property is not yet supported in solve",
AF_ERR_NOT_SUPPORTED);
}
// Triangular solves require a square a and reject transposed forms.
if (is_triangle_solve) {
DIM_ASSERT(1, adims[0] == adims[1]);
if ((options & AF_MAT_TRANS || options & AF_MAT_CTRANS)) {
AF_ERROR("Using AF_MAT_TRANS is not yet supported in solve",
AF_ERR_NOT_SUPPORTED);
}
}
// Dispatch on the (already matched) element type to the typed helper;
// any other dtype is reported through TYPE_ERROR.
af_array output;
switch (a_type) {
case f32: output = solve<float>(a, b, options); break;
case f64: output = solve<double>(a, b, options); break;
case c32: output = solve<cfloat>(a, b, options); break;
case c64: output = solve<cdouble>(a, b, options); break;
default: TYPE_ERROR(1, a_type);
}
// Publish the new handle through *out; the previous *out value lands in
// the local and is not released here.
std::swap(*out, output);
}
CATCHALL;
return AF_SUCCESS;
}
// Typed helper behind af_solve_lu: views the factored matrix and right-hand
// side as Array<T> and the pivot handle as Array<int>, forwards them with the
// property flags to the backend solveLU, and wraps the result in a new handle.
template<typename T>
static inline af_array solve_lu(const af_array a, const af_array pivot,
                                const af_array b, const af_mat_prop options) {
    const auto& factored = getArray<T>(a);
    const auto& pivots   = getArray<int>(pivot);
    const auto& rhs      = getArray<T>(b);
    return getHandle(solveLU<T>(factored, pivots, rhs, options));
}
/**
 * C API: solve A * X = B for X using a previously computed LU factorization.
 *
 * @param[out] out     receives the handle of the solution. When a or b has
 *                     zero dimensions, the function instead returns the
 *                     result of af_create_handle(out, 0, nullptr, <a's type>).
 * @param[in]  a       square factored system matrix (presumably the packed
 *                     output of af_lu_inplace — confirm); floating-point,
 *                     real or complex, at most 2 dimensions.
 * @param[in]  piv     row-pivot indices of the factorization; must be s32
 *                     because the typed helper reads it as Array<int>.
 * @param[in]  b       right-hand side: same type as a, same number of rows
 *                     as a, at most 2 dimensions.
 * @param[in]  options only AF_MAT_NONE is supported; any other value yields
 *                     AF_ERR_NOT_SUPPORTED.
 * @return AF_SUCCESS on success, otherwise the error code produced by the
 *         failing check (converted by CATCHALL).
 */
af_err af_solve_lu(af_array* out, const af_array a, const af_array piv,
                   const af_array b, const af_mat_prop options) {
    try {
        const ArrayInfo& a_info = getInfo(a);
        const ArrayInfo& b_info = getInfo(b);

        // Batched (more than 2-D) inputs are not supported.
        if (a_info.ndims() > 2 || b_info.ndims() > 2) {
            AF_ERROR("solveLU can not be used in batch mode", AF_ERR_BATCH);
        }

        af_dtype a_type = a_info.getType();
        af_dtype b_type = b_info.getType();

        dim4 adims = a_info.dims();
        dim4 bdims = b_info.dims();

        // Type checks run before the empty-input shortcut (as in af_solve) so
        // that empty integer or mismatched-type inputs are rejected instead of
        // silently producing an empty array of an unsupported type.
        ARG_ASSERT(1, a_info.isFloating());  // Only floating and complex types
        ARG_ASSERT(3, b_info.isFloating());  // b is argument 3 of af_solve_lu
        TYPE_ASSERT(a_type == b_type);

        if (a_info.ndims() == 0 || b_info.ndims() == 0) {
            return af_create_handle(out, 0, nullptr, a_type);
        }

        // a must be square; b must match a in the row count and in
        // dimensions 2 and 3.
        DIM_ASSERT(1, adims[0] == adims[1]);
        DIM_ASSERT(1, bdims[0] == adims[0]);
        DIM_ASSERT(1, bdims[2] == adims[2]);
        DIM_ASSERT(1, bdims[3] == adims[3]);

        if (options != AF_MAT_NONE) {
            AF_ERROR("Using this property is not yet supported in solveLU",
                     AF_ERR_NOT_SUPPORTED);
        }

        // The typed helper reads the pivots as Array<int>; reject any other
        // dtype up front with an argument error naming piv (argument 2)
        // rather than letting its storage be reinterpreted.
        const ArrayInfo& piv_info = getInfo(piv);
        ARG_ASSERT(2, piv_info.getType() == s32);

        af_array output = nullptr;
        switch (a_type) {
            case f32: output = solve_lu<float>(a, piv, b, options); break;
            case f64: output = solve_lu<double>(a, piv, b, options); break;
            case c32: output = solve_lu<cfloat>(a, piv, b, options); break;
            case c64: output = solve_lu<cdouble>(a, piv, b, options); break;
            default: TYPE_ERROR(1, a_type);
        }
        // Publish the new handle through *out; the previous *out value lands
        // in the local and is not released here.
        std::swap(*out, output);
    }
    CATCHALL;
    return AF_SUCCESS;
}