Skip to content

Commit 6099075

Browse files
committed
SOLVE, MATMUL and INVERSE now use af_mat_prop
af_mat_prop values can be used for performance improvements by calling specialized routines
1 parent 2a48e06 commit 6099075

21 files changed

Lines changed: 115 additions & 120 deletions

File tree

include/af/blas.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ namespace af
4646
\ingroup blas_func_matmul
4747
*/
4848
AFAPI array matmul(const array &lhs, const array &rhs,
49-
const af::trans optLhs = AF_NO_TRANS,
50-
const af::trans optRhs = AF_NO_TRANS);
49+
const matProp optLhs = AF_MAT_NONE,
50+
const matProp optRhs = AF_MAT_NONE);
5151

5252
/**
5353
\brief Matrix multiply on two arrays
@@ -102,8 +102,8 @@ namespace af
102102
\ingroup blas_func_dot
103103
*/
104104
AFAPI array dot (const array &lhs, const array &rhs,
105-
const af::trans optLhs = AF_NO_TRANS,
106-
const af::trans optRhs = AF_NO_TRANS);
105+
const matProp optLhs = AF_MAT_NONE,
106+
const matProp optRhs = AF_MAT_NONE);
107107

108108
/**
109109
\brief Transposes a matrix
@@ -144,7 +144,7 @@ extern "C" {
144144
*/
145145
AFAPI af_err af_matmul( af_array *out ,
146146
const af_array lhs, const af_array rhs,
147-
const af_transpose_t optLhs, const af_transpose_t optRhs);
147+
const af_mat_prop optLhs, const af_mat_prop optRhs);
148148

149149

150150
/**
@@ -161,7 +161,7 @@ extern "C" {
161161

162162
AFAPI af_err af_dot( af_array *out,
163163
const af_array lhs, const af_array rhs,
164-
const af_transpose_t optLhs, const af_transpose_t optRhs);
164+
const af_mat_prop optLhs, const af_mat_prop optRhs);
165165

166166
/**
167167
\brief Transposes a matrix

include/af/defines.h

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -151,24 +151,18 @@ typedef enum {
151151
} af_cspace_t;
152152

153153
typedef enum {
154-
AF_SOLVE_NONE = 0, ///< Default
155-
AF_SOLVE_POSDEF = 1, ///< A is positive definite
156-
AF_SOLVE_NONPOSDEF = 2, ///< A is not positive definite
157-
AF_SOLVE_GAUSSIAN = 3, ///< Use Gaussian elimination (fast, cannot be combined with other options)
158-
AF_SOLVE_PSEUDO = 4, ///< Use pseudo inverse (fast, cannot be combined with other options)
159-
AF_SOLVE_CTRANS = 256, ///< Solve A.H() (conjugate transpose)
160-
AF_SOLVE_TRANS = 512, ///< Solve A.T() (non-conjugate transpose)
161-
AF_SOLVE_UPPERTRI = 1024, ///< Solve uppertri(A) (upper triangular system)
162-
AF_SOLVE_LOWERTRI = 2048, ///< Solve lowertri(A) (lower triangular system)
163-
AF_SOLVE_TRIDIAG = 4096,
164-
AF_SOLVE_BLKDIAG = 8192
165-
} af_solve_t;
166-
167-
typedef enum {
168-
AF_NO_TRANS,
169-
AF_TRANS,
170-
AF_CONJ_TRANS
171-
} af_transpose_t;
154+
AF_MAT_NONE = 0, ///< Default
155+
AF_MAT_TRANS = 1, ///< Data needs to be transposed
156+
AF_MAT_CTRANS = 2, ///< Data needs to be conjugate transposed
157+
AF_MAT_UPPER = 32, ///< Matrix is upper triangular
158+
AF_MAT_LOWER = 64, ///< Matrix is lower triangular
159+
AF_MAT_DIAG_UNIT = 128, ///< Matrix diagonal contains unit values
160+
AF_MAT_SYM = 512, ///< Matrix is symmetric
161+
AF_MAT_POSDEF = 1024, ///< Matrix is positive definite
162+
AF_MAT_ORTHOG = 2048, ///< Matrix is orthogonal
163+
AF_MAT_TRI_DIAG = 4096, ///< Matrix is tridiagonal
164+
AF_MAT_BLOCK_DIAG = 8192 ///< Matrix is block diagonal
165+
} af_mat_prop;
172166

173167
// Below enum is purely added for example purposes
174168
// it doesn't and shouldn't be used anywhere in the
@@ -191,8 +185,9 @@ namespace af
191185
typedef af_match_type matchType;
192186
typedef af_cspace_t CSpace;
193187
typedef af_someenum_t SomeEnum; // Purpose of Addition: How to add Function example
194-
typedef af_transpose_t trans;
188+
typedef af_mat_prop trans;
195189
typedef af_conv_mode convMode;
190+
typedef af_mat_prop matProp;
196191

197192
const double NaN = std::numeric_limits<double>::quiet_NaN();
198193
const double Inf = std::numeric_limits<double>::infinity();

include/af/lapack.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ namespace af
3535

3636
AFAPI void choleskyInPlace(array &in, int *info = NULL, const bool is_upper = true);
3737

38-
AFAPI array solve(const array &a, const array &b, const af_solve_t options = AF_SOLVE_NONE);
38+
AFAPI array solve(const array &a, const array &b, const matProp options = AF_MAT_NONE);
3939

40-
AFAPI array inverse(const array &in);
40+
AFAPI array inverse(const array &in, const matProp options = AF_MAT_NONE);
4141

4242
/**
4343
@}
@@ -65,9 +65,9 @@ extern "C" {
6565
AFAPI af_err af_cholesky_inplace(int *info, af_array in, const bool is_upper);
6666

6767
AFAPI af_err af_solve(af_array *out, const af_array a, const af_array b,
68-
const af_solve_t options);
68+
const af_mat_prop options);
6969

70-
AFAPI af_err af_inverse(af_array *out, const af_array in);
70+
AFAPI af_err af_inverse(af_array *out, const af_array in, const af_mat_prop options);
7171

7272
#ifdef __cplusplus
7373
}

src/api/c/blas.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,21 @@
1919

2020
template<typename T>
2121
static inline af_array matmul(const af_array lhs, const af_array rhs,
22-
af_transpose_t optLhs, af_transpose_t optRhs)
22+
af_mat_prop optLhs, af_mat_prop optRhs)
2323
{
2424
return getHandle(detail::matmul<T>(getArray<T>(lhs), getArray<T>(rhs), optLhs, optRhs));
2525
}
2626

2727
template<typename T>
2828
static inline af_array dot(const af_array lhs, const af_array rhs,
29-
af_transpose_t optLhs, af_transpose_t optRhs)
29+
af_mat_prop optLhs, af_mat_prop optRhs)
3030
{
3131
return getHandle(detail::dot<T>(getArray<T>(lhs), getArray<T>(rhs), optLhs, optRhs));
3232
}
3333

3434
af_err af_matmul( af_array *out,
3535
const af_array lhs, const af_array rhs,
36-
const af_transpose_t optLhs, const af_transpose_t optRhs)
36+
const af_mat_prop optLhs, const af_mat_prop optRhs)
3737
{
3838
using namespace detail;
3939

@@ -47,8 +47,8 @@ af_err af_matmul( af_array *out,
4747
TYPE_ASSERT(lhs_type == rhs_type);
4848
af_array output = 0;
4949

50-
int aColDim = (optLhs == AF_NO_TRANS) ? 1 : 0;
51-
int bRowDim = (optRhs == AF_NO_TRANS) ? 0 : 1;
50+
int aColDim = (optLhs == AF_MAT_NONE) ? 1 : 0;
51+
int bRowDim = (optRhs == AF_MAT_NONE) ? 0 : 1;
5252

5353
DIM_ASSERT(1, lhsInfo.dims()[aColDim] == rhsInfo.dims()[bRowDim]);
5454

@@ -67,7 +67,7 @@ af_err af_matmul( af_array *out,
6767

6868
af_err af_dot( af_array *out,
6969
const af_array lhs, const af_array rhs,
70-
const af_transpose_t optLhs, const af_transpose_t optRhs)
70+
const af_mat_prop optLhs, const af_mat_prop optRhs)
7171
{
7272
using namespace detail;
7373

src/api/c/inverse.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ static inline af_array inverse(const af_array in)
2525
return getHandle(inverse<T>(getArray<T>(in)));
2626
}
2727

28-
af_err af_inverse(af_array *out, const af_array in)
28+
af_err af_inverse(af_array *out, const af_array in, const af_mat_prop options)
2929
{
3030
try {
3131
ArrayInfo i_info = getInfo(in);

src/api/c/solve.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@ using af::dim4;
2020
using namespace detail;
2121

2222
template<typename T>
23-
static inline af_array solve(const af_array a, const af_array b, const af_solve_t options)
23+
static inline af_array solve(const af_array a, const af_array b, const af_mat_prop options)
2424
{
2525
return getHandle(solve<T>(getArray<T>(a), getArray<T>(b), options));
2626
}
2727

28-
af_err af_solve(af_array *out, const af_array a, const af_array b, const af_solve_t options)
28+
af_err af_solve(af_array *out, const af_array a, const af_array b, const af_mat_prop options)
2929
{
3030
try {
3131
ArrayInfo a_info = getInfo(a);
@@ -46,7 +46,7 @@ af_err af_solve(af_array *out, const af_array a, const af_array b, const af_solv
4646
DIM_ASSERT(1, bdims[2] == adims[2]);
4747
DIM_ASSERT(1, bdims[3] == adims[3]);
4848

49-
ARG_ASSERT(3, options == AF_SOLVE_NONE);
49+
ARG_ASSERT(3, options == AF_MAT_NONE);
5050

5151
af_array output;
5252

src/api/cpp/blas.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
namespace af
1515
{
1616
array matmul(const array &lhs, const array &rhs,
17-
const af_transpose_t optLhs, const af_transpose_t optRhs)
17+
const matProp optLhs, const matProp optRhs)
1818
{
1919
af_array out = 0;
2020
AF_THROW(af_matmul(&out, lhs.get(), rhs.get(), optLhs, optRhs));
@@ -25,28 +25,28 @@ namespace af
2525
{
2626
af_array out = 0;
2727
AF_THROW(af_matmul(&out, lhs.get(), rhs.get(),
28-
AF_NO_TRANS, AF_TRANS));
28+
AF_MAT_NONE, AF_MAT_TRANS));
2929
return array(out);
3030
}
3131

3232
array matmulTN(const array &lhs, const array &rhs)
3333
{
3434
af_array out = 0;
3535
AF_THROW(af_matmul(&out, lhs.get(), rhs.get(),
36-
AF_TRANS, AF_NO_TRANS));
36+
AF_MAT_TRANS, AF_MAT_NONE));
3737
return array(out);
3838
}
3939

4040
array matmulTT(const array &lhs, const array &rhs)
4141
{
4242
af_array out = 0;
4343
AF_THROW(af_matmul(&out, lhs.get(), rhs.get(),
44-
AF_TRANS, AF_TRANS));
44+
AF_MAT_TRANS, AF_MAT_TRANS));
4545
return array(out);
4646
}
4747

4848
array dot (const array &lhs, const array &rhs,
49-
const af_transpose_t optLhs, const af_transpose_t optRhs)
49+
const matProp optLhs, const matProp optRhs)
5050
{
5151
af_array out = 0;
5252
AF_THROW(af_dot(&out, lhs.get(), rhs.get(), optLhs, optRhs));

src/api/cpp/lapack.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,17 +73,17 @@ namespace af
7373
AF_THROW(af_cholesky_inplace(info, in.get(), is_upper));
7474
}
7575

76-
array solve(const array &a, const array &b, const af_solve_t options)
76+
array solve(const array &a, const array &b, const matProp options)
7777
{
7878
af_array out;
7979
AF_THROW(af_solve(&out, a.get(), b.get(), options));
8080
return array(out);
8181
}
8282

83-
array inverse(const array &in)
83+
array inverse(const array &in, const matProp options)
8484
{
8585
af_array out;
86-
AF_THROW(af_inverse(&out, in.get()));
86+
AF_THROW(af_inverse(&out, in.get(), options));
8787
return array(out);
8888
}
8989

src/backend/cpu/blas.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,14 @@ getScale()
9797
}
9898

9999
CBLAS_TRANSPOSE
100-
toCblasTranspose(af_transpose_t opt)
100+
toCblasTranspose(af_mat_prop opt)
101101
{
102102
CBLAS_TRANSPOSE out = CblasNoTrans;
103103
switch(opt) {
104-
case AF_NO_TRANS : out = CblasNoTrans; break;
105-
case AF_TRANS : out = CblasTrans; break;
106-
case AF_CONJ_TRANS : out = CblasConjTrans; break;
107-
default : AF_ERROR("INVALID af_transpose_t", AF_ERR_INVALID_ARG);
104+
case AF_MAT_NONE : out = CblasNoTrans; break;
105+
case AF_MAT_TRANS : out = CblasTrans; break;
106+
case AF_MAT_CTRANS : out = CblasConjTrans; break;
107+
default : AF_ERROR("INVALID af_mat_prop", AF_ERR_INVALID_ARG);
108108
}
109109
return out;
110110
}
@@ -143,7 +143,7 @@ struct cblas_types<cdouble> {
143143

144144
template<typename T>
145145
Array<T> matmul(const Array<T> &lhs, const Array<T> &rhs,
146-
af_transpose_t optLhs, af_transpose_t optRhs)
146+
af_mat_prop optLhs, af_mat_prop optRhs)
147147
{
148148
CBLAS_TRANSPOSE lOpts = toCblasTranspose(optLhs);
149149
CBLAS_TRANSPOSE rOpts = toCblasTranspose(optRhs);
@@ -187,7 +187,7 @@ Array<T> matmul(const Array<T> &lhs, const Array<T> &rhs,
187187

188188
template<typename T>
189189
Array<T> dot(const Array<T> &lhs, const Array<T> &rhs,
190-
af_transpose_t optLhs, af_transpose_t optRhs)
190+
af_mat_prop optLhs, af_mat_prop optRhs)
191191
{
192192
int N = lhs.dims()[0];
193193

@@ -206,7 +206,7 @@ Array<T> dot(const Array<T> &lhs, const Array<T> &rhs,
206206

207207
#define INSTANTIATE_BLAS(TYPE) \
208208
template Array<TYPE> matmul<TYPE>(const Array<TYPE> &lhs, const Array<TYPE> &rhs, \
209-
af_transpose_t optLhs, af_transpose_t optRhs);
209+
af_mat_prop optLhs, af_mat_prop optRhs);
210210

211211
INSTANTIATE_BLAS(float)
212212
INSTANTIATE_BLAS(cfloat)
@@ -215,7 +215,7 @@ INSTANTIATE_BLAS(cdouble)
215215

216216
#define INSTANTIATE_DOT(TYPE) \
217217
template Array<TYPE> dot<TYPE>(const Array<TYPE> &lhs, const Array<TYPE> &rhs, \
218-
af_transpose_t optLhs, af_transpose_t optRhs);
218+
af_mat_prop optLhs, af_mat_prop optRhs);
219219

220220
INSTANTIATE_DOT(float)
221221
INSTANTIATE_DOT(double)

src/backend/cpu/blas.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ namespace cpu
2828

2929
template<typename T>
3030
Array<T> matmul(const Array<T> &lhs, const Array<T> &rhs,
31-
af_transpose_t optLhs, af_transpose_t optRhs);
31+
af_mat_prop optLhs, af_mat_prop optRhs);
3232
template<typename T>
3333
Array<T> dot(const Array<T> &lhs, const Array<T> &rhs,
34-
af_transpose_t optLhs, af_transpose_t optRhs);
34+
af_mat_prop optLhs, af_mat_prop optRhs);
3535

3636
}

0 commit comments

Comments
 (0)