diff --git a/src/api/c/mean.cpp b/src/api/c/mean.cpp index d7908c7f37..59e71d02a6 100644 --- a/src/api/c/mean.cpp +++ b/src/api/c/mean.cpp @@ -33,7 +33,7 @@ static outType mean(const af_array &in) template static outType mean(const af_array &in, const af_array &weights) { - typedef baseOutType bType; + typedef typename baseOutType::type bType; Array input = cast(getArray(in)); Array wts = cast(getArray(weights)); @@ -55,7 +55,7 @@ static af_array mean(const af_array &in, dim_type dim) template static af_array mean(const af_array &in, const af_array &weights, dim_type dim) { - typedef baseOutType bType; + typedef typename baseOutType::type bType; Array input = cast(getArray(in)); Array wts = cast(getArray(weights)); diff --git a/src/api/c/stats.h b/src/api/c/stats.h index 5c4a889e55..e41cefd179 100644 --- a/src/api/c/stats.h +++ b/src/api/c/stats.h @@ -9,11 +9,37 @@ #pragma once +template +struct is_same{ + static const bool value = false; +}; + template -using baseOutType = typename std::conditional< std::is_same::value || - std::is_same::value, - double, - float>::type; +struct is_same { + static const bool value = true; +}; + +template +struct cond_type; + +template +struct cond_type { + typedef T type; +}; + +template +struct cond_type { + typedef Other type; +}; + +template +struct baseOutType { + typedef typename cond_type< is_same::value || + is_same::value, + double, + float>::type type; +}; + template inline T mean(const Array& in) { diff --git a/src/api/c/var.cpp b/src/api/c/var.cpp index 10b5f54632..9cee7c32b4 100644 --- a/src/api/c/var.cpp +++ b/src/api/c/var.cpp @@ -43,7 +43,7 @@ static outType varAll(const af_array& in, bool isbiased) template static outType varAll(const af_array& in, const af_array weights) { - typedef baseOutType bType; + typedef typename baseOutType::type bType; Array input = cast(getArray(in)); Array wts = cast(getArray(weights)); @@ -91,7 +91,7 @@ static af_array var(const af_array& in, bool isbiased, int dim) template static af_array var(const af_array& in, const af_array& weights, dim_type dim) { - typedef baseOutType bType; + typedef typename baseOutType::type bType; Array input = cast(getArray(in)); Array wts = cast(getArray(weights));