Skip to content

Commit 13cbff8

Browse files
committed
intl/uintl versions of select/replace to fix accuracy issues
Without specific intl/uintl versions of these functions, when a succifiently large 64 bit integer value is passed to select/replace the output is incorrect or getting transformed to zero. Signed-off-by: Pradeep Garigipati <pradeep.garigipati@gmail.com>
1 parent f58b849 commit 13cbff8

File tree

18 files changed

+406
-158
lines changed

18 files changed

+406
-158
lines changed

include/af/data.h

Lines changed: 160 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ namespace af
409409
\param[in] cond is the conditional array.
410410
\param[in] b is the replacement value.
411411
412-
\note Values of \p a are replaced with corresponding values of \p b, when \p cond is false.
412+
\note Values of \p a are replaced with value \p b, when \p cond is false.
413413
414414
\ingroup data_func_replace
415415
*/
@@ -432,6 +432,81 @@ namespace af
432432
AFAPI array pad(const array &in, const dim4 &beginPadding,
433433
const dim4 &endPadding, const borderType padFillType);
434434
#endif
435+
436+
#if AF_API_VERSION >= 39
437+
/**
438+
\param[inout] a is the input array
439+
\param[in] cond is the conditional array.
440+
\param[in] b is the replacement scalar value.
441+
442+
\note Values of \p a are replaced with value \p b, when \p cond is false.
443+
444+
\ingroup data_func_replace
445+
*/
446+
AFAPI void replace(array &a, const array &cond, const long long b);
447+
448+
/**
449+
\param[inout] a is the input array
450+
\param[in] cond is the conditional array.
451+
\param[in] b is the replacement scalar value.
452+
453+
\note Values of \p a are replaced with value \p b, when \p cond is false.
454+
455+
\ingroup data_func_replace
456+
*/
457+
AFAPI void replace(array &a, const array &cond,
458+
const unsigned long long b);
459+
460+
/**
461+
\param[in] cond is the conditional array
462+
\param[in] a is the array containing elements from the true part of the
463+
condition
464+
\param[in] b is a scalar assigned to \p out when \p cond is false
465+
\return the output containing elements of \p a when \p cond is true
466+
else the value \p b
467+
468+
\ingroup data_func_select
469+
*/
470+
AFAPI array select(const array &cond, const array &a, const long long b);
471+
472+
/**
473+
\param[in] cond is the conditional array
474+
\param[in] a is the array containing elements from the true part of the
475+
condition
476+
\param[in] b is a scalar assigned to \p out when \p cond is false
477+
\return the output containing elements of \p a when \p cond is true
478+
else the value \p b
479+
480+
\ingroup data_func_select
481+
*/
482+
AFAPI array select(const array &cond, const array &a,
483+
const unsigned long long b);
484+
485+
/**
486+
\param[in] cond is the conditional array
487+
\param[in] a is a scalar assigned to \p out when \p cond is true
488+
\param[in] b is the array containing elements from the false part of the
489+
condition
490+
\return the output containing the value \p a when \p cond is true else
491+
elements from \p b
492+
493+
\ingroup data_func_select
494+
*/
495+
AFAPI array select(const array &cond, const long long a, const array &b);
496+
497+
/**
498+
\param[in] cond is the conditional array
499+
\param[in] a is a scalar assigned to \p out when \p cond is true
500+
\param[in] b is the array containing elements from the false part of the
501+
condition
502+
\return the output containing the value \p a when \p cond is true else
503+
elements from \p b
504+
505+
\ingroup data_func_select
506+
*/
507+
AFAPI array select(const array &cond, const unsigned long long a,
508+
const array &b);
509+
#endif
435510
}
436511
#endif
437512

@@ -735,6 +810,90 @@ extern "C" {
735810
const af_border_type pad_fill_type);
736811
#endif
737812

813+
#if AF_API_VERSION >= 39
814+
/**
815+
\param[inout] a is the input array
816+
\param[in] cond is the conditional array.
817+
\param[in] b is the replacement array.
818+
819+
\note Values of \p a are replaced with corresponding values of \p b, when
820+
\p cond is false.
821+
822+
\ingroup data_func_replace
823+
*/
824+
AFAPI af_err af_replace_scalar_long(af_array a, const af_array cond,
825+
const long long b);
826+
827+
/**
828+
\param[inout] a is the input array
829+
\param[in] cond is the conditional array.
830+
\param[in] b is the replacement array.
831+
832+
\note Values of \p a are replaced with corresponding values of \p b, when
833+
\p cond is false.
834+
835+
\ingroup data_func_replace
836+
*/
837+
AFAPI af_err af_replace_scalar_ulong(af_array a, const af_array cond,
838+
const unsigned long long b);
839+
840+
/**
841+
\param[out] out is the output containing elements of \p a when \p cond is
842+
true else elements from \p b
843+
\param[in] cond is the conditional array
844+
\param[in] a is the array containing elements from the true part of the
845+
condition
846+
\param[in] b is a scalar assigned to \p out when \p cond is
847+
false
848+
849+
\ingroup data_func_select
850+
*/
851+
AFAPI af_err af_select_scalar_r_long(af_array *out, const af_array cond,
852+
const af_array a, const long long b);
853+
854+
/**
855+
\param[out] out is the output containing elements of \p a when \p cond is
856+
true else elements from \p b
857+
\param[in] cond is the conditional array
858+
\param[in] a is the array containing elements from the true part of the
859+
condition
860+
\param[in] b is a scalar assigned to \p out when \p cond is
861+
false
862+
863+
\ingroup data_func_select
864+
*/
865+
AFAPI af_err af_select_scalar_r_ulong(af_array *out, const af_array cond,
866+
const af_array a,
867+
const unsigned long long b);
868+
869+
/**
870+
\param[out] out is the output containing elements of \p a when \p cond is
871+
true else elements from \p b
872+
\param[in] cond is the conditional array
873+
\param[in] a is a scalar assigned to \p out when \p cond is true
874+
\param[in] b is the array containing elements from the false part of the
875+
condition
876+
877+
\ingroup data_func_select
878+
*/
879+
AFAPI af_err af_select_scalar_l_long(af_array *out, const af_array cond,
880+
const long long a, const af_array b);
881+
882+
/**
883+
\param[out] out is the output containing elements of \p a when \p cond is
884+
true else elements from \p b
885+
\param[in] cond is the conditional array
886+
\param[in] a is a scalar assigned to \p out when \p cond is true
887+
\param[in] b is the array containing elements from the false part of the
888+
condition
889+
890+
\ingroup data_func_select
891+
*/
892+
AFAPI af_err af_select_scalar_l_ulong(af_array *out, const af_array cond,
893+
const unsigned long long a,
894+
const af_array b);
895+
#endif
896+
738897
#ifdef __cplusplus
739898
}
740899
#endif

src/api/c/deconvolution.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <fftconvolve.hpp>
2020
#include <handle.hpp>
2121
#include <logic.hpp>
22+
#include <math.hpp>
2223
#include <reduce.hpp>
2324
#include <select.hpp>
2425
#include <shift.hpp>
@@ -294,7 +295,7 @@ af_array invDeconv(const af_array in, const af_array ker, const float gamma,
294295
auto cond = logicOp<T, af_ge_t>(absVal, THRESH, absVal.dims());
295296
auto val = arithOp<CT, af_div_t>(numer, denom, numer.dims());
296297

297-
select_scalar<CT, false>(val, cond, val, 0);
298+
select_scalar<CT, false>(val, cond, val, scalar<CT>(0.0));
298299

299300
auto ival =
300301
fft_c2r<CT, T>(val, 1 / static_cast<double>(nElems), odims, BASE_DIM);

src/api/c/replace.cpp

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,15 @@ af_err af_replace(af_array a, const af_array cond, const af_array b) {
8282
return AF_SUCCESS;
8383
}
8484

85-
template<typename T>
86-
void replace_scalar(af_array a, const af_array cond, const double b) {
87-
select_scalar<T, false>(getCopyOnWriteArray<T>(a), getArray<char>(cond),
88-
getArray<T>(a), b);
85+
template<typename ArrayType, typename ScalarType>
86+
void replace_scalar(af_array a, const af_array cond, const ScalarType& b) {
87+
select_scalar<ArrayType, false>(
88+
getCopyOnWriteArray<ArrayType>(a), getArray<char>(cond),
89+
getArray<ArrayType>(a), detail::scalar<ArrayType>(b));
8990
}
9091

91-
af_err af_replace_scalar(af_array a, const af_array cond, const double b) {
92+
template<typename ScalarType>
93+
af_err replaceScalar(af_array a, const af_array cond, const ScalarType b) {
9294
try {
9395
const ArrayInfo& ainfo = getInfo(a);
9496
const ArrayInfo& cinfo = getInfo(cond);
@@ -121,3 +123,17 @@ af_err af_replace_scalar(af_array a, const af_array cond, const double b) {
121123
CATCHALL;
122124
return AF_SUCCESS;
123125
}
126+
127+
af_err af_replace_scalar(af_array a, const af_array cond, const double b) {
128+
return replaceScalar(a, cond, b);
129+
}
130+
131+
af_err af_replace_scalar_long(af_array a, const af_array cond,
132+
const long long b) {
133+
return replaceScalar(a, cond, b);
134+
}
135+
136+
af_err af_replace_scalar_ulong(af_array a, const af_array cond,
137+
const unsigned long long b) {
138+
return replaceScalar(a, cond, b);
139+
}

0 commit comments

Comments
 (0)