/******************************************************* * Copyright (c) 2014, ArrayFire * All rights reserved. * * This file is distributed under 3-clause BSD license. * The complete license agreement can be obtained at: * http://arrayfire.com/licenses/BSD-3-Clause ********************************************************/ #include #include #include #include #include #include #include #include using af::dim4; using detail::Array; using detail::cdouble; using detail::cfloat; using detail::createEmptyArray; using detail::isLAPACKAvailable; template static inline void lu(af_array *lower, af_array *upper, af_array *pivot, const af_array in) { Array lowerArray = createEmptyArray(af::dim4()); Array upperArray = createEmptyArray(af::dim4()); Array pivotArray = createEmptyArray(af::dim4()); lu(lowerArray, upperArray, pivotArray, getArray(in)); *lower = getHandle(lowerArray); *upper = getHandle(upperArray); *pivot = getHandle(pivotArray); } template static inline af_array lu_inplace(af_array in, bool is_lapack_piv) { return getHandle(lu_inplace(getArray(in), !is_lapack_piv)); } af_err af_lu(af_array *lower, af_array *upper, af_array *pivot, const af_array in) { try { const ArrayInfo &i_info = getInfo(in); if (i_info.ndims() > 2) { AF_ERROR("lu can not be used in batch mode", AF_ERR_BATCH); } af_dtype type = i_info.getType(); ARG_ASSERT(0, lower != nullptr); ARG_ASSERT(1, upper != nullptr); ARG_ASSERT(2, pivot != nullptr); ARG_ASSERT(3, i_info.isFloating()); // Only floating and complex types if (i_info.ndims() == 0) { AF_CHECK(af_create_handle(lower, 0, nullptr, type)); AF_CHECK(af_create_handle(upper, 0, nullptr, type)); AF_CHECK(af_create_handle(pivot, 0, nullptr, type)); return AF_SUCCESS; } switch (type) { case f32: lu(lower, upper, pivot, in); break; case f64: lu(lower, upper, pivot, in); break; case c32: lu(lower, upper, pivot, in); break; case c64: lu(lower, upper, pivot, in); break; default: TYPE_ERROR(1, type); } } CATCHALL; return AF_SUCCESS; } af_err af_lu_inplace(af_array *pivot, af_array in, const bool is_lapack_piv) { try { const ArrayInfo &i_info = getInfo(in); af_dtype type = i_info.getType(); if (i_info.ndims() > 2) { AF_ERROR("lu can not be used in batch mode", AF_ERR_BATCH); } ARG_ASSERT(1, i_info.isFloating()); // Only floating and complex types ARG_ASSERT(0, pivot != nullptr); if (i_info.ndims() == 0) { return af_create_handle(pivot, 0, nullptr, type); } af_array out; switch (type) { case f32: out = lu_inplace(in, is_lapack_piv); break; case f64: out = lu_inplace(in, is_lapack_piv); break; case c32: out = lu_inplace(in, is_lapack_piv); break; case c64: out = lu_inplace(in, is_lapack_piv); break; default: TYPE_ERROR(1, type); } std::swap(*pivot, out); } CATCHALL; return AF_SUCCESS; } af_err af_is_lapack_available(bool *out) { try { *out = isLAPACKAvailable(); } CATCHALL; return AF_SUCCESS; }