Skip to content

Commit 99a9234

Browse files
committed
Allow users to set precision when using print
1 parent 28d9e71 commit 99a9234

File tree

3 files changed

+68
-16
lines changed

3 files changed

+68
-16
lines changed

include/af/util.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,23 @@ namespace af
3030
*/
3131
AFAPI void print(const char *exp, const array &arr);
3232

33+
/**
34+
\param[in] exp is an expression, generally the name of the array
35+
\param[in] arr is the input array
36+
\param[in] precision is the precision length for display
37+
38+
\ingroup print_func_print
39+
*/
40+
AFAPI void print(const char *exp, const array &arr, const int precision);
41+
3342
// Purpose of Addition: "How to add Function" documentation
3443
AFAPI array exampleFunction(const array& in, const af_someenum_t param);
3544
}
3645

3746
#define af_print(exp) af::print(#exp, exp);
3847

48+
#define af_print_p(exp, p) af::print(#exp, exp, p);
49+
3950
#endif //__cplusplus
4051

4152
#ifdef __cplusplus
@@ -238,6 +249,16 @@ extern "C" {
238249
*/
239250
AFAPI af_err af_print_array(af_array arr);
240251

252+
/**
253+
\param[in] arr is the input array
254+
\param[in] precision precision for the display
255+
256+
\returns error codes
257+
258+
\ingroup print_func_print
259+
*/
260+
AFAPI af_err af_print_array_p(af_array arr, const int precision);
261+
241262
// Purpose of Addition: "How to add Function" documentation
242263
AFAPI af_err af_example_function(af_array* out, const af_array in, const af_someenum_t param);
243264

src/api/c/print.cpp

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ using std::endl;
2828
using std::vector;
2929

3030
template<typename T>
31-
static void printer(ostream &out, const T* ptr, const ArrayInfo &info, unsigned dim)
31+
static void printer(ostream &out, const T* ptr, const ArrayInfo &info, unsigned dim, const int precision)
3232
{
3333

3434
dim_t stride = info.strides()[dim];
@@ -38,22 +38,22 @@ static void printer(ostream &out, const T* ptr, const ArrayInfo &info, unsigned
3838
if(dim == 0) {
3939
for(dim_t i = 0, j = 0; i < d; i++, j+=stride) {
4040
out<< std::fixed <<
41-
std::setw(10) <<
42-
std::setprecision(4) << toNum(ptr[j]) << " ";
41+
std::setw(precision + 6) <<
42+
std::setprecision(precision) << toNum(ptr[j]) << " ";
4343
}
4444
out << endl;
4545
}
4646
else {
4747
for(dim_t i = 0; i < d; i++) {
48-
printer(out, ptr, info, dim - 1);
48+
printer(out, ptr, info, dim - 1, precision);
4949
ptr += stride;
5050
}
5151
out << endl;
5252
}
5353
}
5454

5555
template<typename T>
56-
static void print(af_array arr)
56+
static void print(af_array arr, const int precision)
5757
{
5858
const ArrayInfo info = getInfo(arr);
5959
vector<T> data(info.elements());
@@ -74,7 +74,7 @@ static void print(af_array arr)
7474
std::cout <<" Strides: ["<<info.strides()<<"]"<<std::endl;
7575
#endif
7676

77-
printer(std::cout, &data.front(), infoT, infoT.ndims() - 1);
77+
printer(std::cout, &data.front(), infoT, infoT.ndims() - 1, precision);
7878

7979
std::cout.flags(backup);
8080
}
@@ -86,16 +86,40 @@ af_err af_print_array(af_array arr)
8686
af_dtype type = info.getType();
8787
switch(type)
8888
{
89-
case f32: print<float>(arr); break;
90-
case c32: print<cfloat>(arr); break;
91-
case f64: print<double>(arr); break;
92-
case c64: print<cdouble>(arr); break;
93-
case b8: print<char>(arr); break;
94-
case s32: print<int>(arr); break;
95-
case u32: print<unsigned>(arr); break;
96-
case u8: print<uchar>(arr); break;
97-
case s64: print<intl>(arr); break;
98-
case u64: print<uintl>(arr); break;
89+
case f32: print<float> (arr, 4); break;
90+
case c32: print<cfloat> (arr, 4); break;
91+
case f64: print<double> (arr, 4); break;
92+
case c64: print<cdouble> (arr, 4); break;
93+
case b8: print<char> (arr, 4); break;
94+
case s32: print<int> (arr, 4); break;
95+
case u32: print<unsigned>(arr, 4); break;
96+
case u8: print<uchar> (arr, 4); break;
97+
case s64: print<intl> (arr, 4); break;
98+
case u64: print<uintl> (arr, 4); break;
99+
default: TYPE_ERROR(1, type);
100+
}
101+
}
102+
CATCHALL;
103+
return AF_SUCCESS;
104+
}
105+
106+
af_err af_print_array_p(af_array arr, const int precision)
107+
{
108+
try {
109+
ArrayInfo info = getInfo(arr);
110+
af_dtype type = info.getType();
111+
switch(type)
112+
{
113+
case f32: print<float> (arr, precision); break;
114+
case c32: print<cfloat> (arr, precision); break;
115+
case f64: print<double> (arr, precision); break;
116+
case c64: print<cdouble> (arr, precision); break;
117+
case b8: print<char> (arr, precision); break;
118+
case s32: print<int> (arr, precision); break;
119+
case u32: print<unsigned>(arr, precision); break;
120+
case u8: print<uchar> (arr, precision); break;
121+
case s64: print<intl> (arr, precision); break;
122+
case u64: print<uintl> (arr, precision); break;
99123
default: TYPE_ERROR(1, type);
100124
}
101125
}

src/api/cpp/util.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,11 @@ namespace af
2222
AF_THROW(af_print_array(arr.get()));
2323
return;
2424
}
25+
26+
void print(const char *exp, const array &arr, const int precision)
27+
{
28+
printf("%s ", exp);
29+
AF_THROW(af_print_array_p(arr.get(), precision));
30+
return;
31+
}
2532
}

0 commit comments

Comments
 (0)