Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/api/c/median.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ static af_array median(const af_array& in, const dim_t dim)
af_seq slices[4] = {af_span, af_span, af_span, af_span};
slices[dim] = af_make_seq(mid-1.0, mid-1.0, 1.0);

AF_CHECK(af_index(&left, getHandle<T>(sortedIn), input.ndims(), slices));
af_array sortedIn_handle = getHandle<T>(sortedIn);
AF_CHECK(af_index(&left, sortedIn_handle, input.ndims(), slices));

if (nElems % 2 == 1) {
// mid-1 is our guy
Expand All @@ -82,14 +83,15 @@ static af_array median(const af_array& in, const dim_t dim)
af_array out;
AF_CHECK(af_cast(&out, left, f32));
AF_CHECK(af_release_array(left));
AF_CHECK(af_release_array(sortedIn_handle));
return out;
} else {
// ((mid-1)+mid)/2 is our guy
dim4 dims = input.dims();
af_array right = 0;
slices[dim] = af_make_seq(mid, mid, 1.0);

AF_CHECK(af_index(&right, getHandle<T>(sortedIn), dims.ndims(), slices));
AF_CHECK(af_index(&right, sortedIn_handle, dims.ndims(), slices));

af_array sumarr = 0;
af_array carr = 0;
Expand All @@ -115,6 +117,7 @@ static af_array median(const af_array& in, const dim_t dim)
AF_CHECK(af_release_array(right));
AF_CHECK(af_release_array(sumarr));
AF_CHECK(af_release_array(carr));
AF_CHECK(af_release_array(sortedIn_handle));
return result;
}
}
Expand Down
30 changes: 28 additions & 2 deletions src/backend/cblas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,33 @@
* http://arrayfire.com/licenses/BSD-3-Clause
********************************************************/

#include <blas.hpp>

#ifdef USE_F77_BLAS

#ifdef AF_CPU
#include <blas.hpp>
#else
#ifdef __APPLE__
#include <Accelerate/Accelerate.h>
#else
#ifdef USE_MKL
#include <mkl_cblas.h>
#else
extern "C" {
#include <cblas.h>
}
#endif
#endif

// TODO: Ask upstream for a more official way to detect it
#ifdef OPENBLAS_CONST
#define IS_OPENBLAS
#endif

#ifndef IS_OPENBLAS
typedef int blasint;
#endif
#endif

#define ADD_
#include <cblas_f77.h>

Expand Down Expand Up @@ -80,4 +104,6 @@ GEMM_F77(c, void *, void, float)
GEMM_F77(z, void *, void, double)
#undef ADDR

#else
#include <blas.hpp>
#endif
37 changes: 32 additions & 5 deletions src/backend/cpu/blas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,39 @@ namespace cpu
using std::remove_const;
using std::conditional;

// Some implementations of BLAS require void* for complex pointers while
// others use float*/double*
#if (defined(OS_WIN) && !defined(USE_MKL)) || defined(IS_OPENBLAS)
static const bool cplx_void_ptr = false;
// Some implementations of BLAS require void* for complex pointers while others use float*/double*
//
// Sample cgemm API
// OpenBLAS
// void cblas_cgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB,
// OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K,
// OPENBLAS_CONST float *alpha, OPENBLAS_CONST float *A, OPENBLAS_CONST blasint lda,
// OPENBLAS_CONST float *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float *beta,
// float *C, OPENBLAS_CONST blasint ldc);
//
// MKL
// void cblas_cgemm(const CBLAS_LAYOUT Layout, const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
// const MKL_INT M, const MKL_INT N, const MKL_INT K,
// const void *alpha, const void *A, const MKL_INT lda,
// const void *B, const MKL_INT ldb, const void *beta,
// void *C, const MKL_INT ldc);
// atlas cblas
// void cblas_cgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
// const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
// const void *alpha, const void *A, const int lda,
// const void *B, const int ldb, const void *beta,
// void *C, const int ldc);
//
// LAPACKE
// void cblas_cgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
// const enum CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
// const void *alpha, const void *A, const int lda,
// const void *B, const int ldb, const void *beta,
// void *C, const int ldc);
#if defined(IS_OPENBLAS)
static const bool cplx_void_ptr = false;
#else
static const bool cplx_void_ptr = true;
static const bool cplx_void_ptr = true;
#endif

template<typename T, class Enable = void>
Expand Down