Skip to content

Commit 3c7bd0f

Browse files
committed
Add T dot<T>() function to return scalar from dot operation
1 parent 2fa5311 commit 3c7bd0f

File tree

4 files changed

+156
-8
lines changed

4 files changed

+156
-8
lines changed

include/af/blas.h

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,39 @@ namespace af
166166
const matProp optLhs = AF_MAT_NONE,
167167
const matProp optRhs = AF_MAT_NONE);
168168

169+
#if AF_API_VERSION >= 35
170+
/**
171+
\brief Return the dot product of two vectors as a host scalar
172+
173+
Scalar dot product between two vectors. Also referred to as the inner
174+
product.
175+
176+
\code
177+
// compute scalar dot product
178+
array x = randu(100), y = randu(100);
179+
float h_dot = dot<float>(x,y);
180+
\endcode
181+
182+
\param[in] lhs The array object on the left hand side
183+
\param[in] rhs The array object on the right hand side
184+
\param[in] optLhs Options for lhs. Currently only \ref AF_MAT_NONE and
185+
AF_MAT_CONJ are supported.
186+
\param[in] optRhs Options for rhs. Currently only \ref AF_MAT_NONE and AF_MAT_CONJ are supported.
187+
\return The result of the dot product of lhs, rhs as a host scalar. When T is a real type only the real part of the result is returned.
188+
189+
\note optLhs and optRhs can only be one of \ref AF_MAT_NONE or \ref AF_MAT_CONJ
190+
\note optLhs = AF_MAT_CONJ and optRhs = AF_MAT_NONE will run the conjugate dot operation.
191+
\note This function is not supported in GFOR
192+
193+
\tparam T Host scalar type of the result. Specializations are provided for float, double, cfloat and cdouble.
194+
195+
\ingroup blas_func_dot
196+
*/
197+
template<typename T> T dot(const array &lhs, const array &rhs,
198+
const matProp optLhs = AF_MAT_NONE,
199+
const matProp optRhs = AF_MAT_NONE);
200+
#endif
201+
169202
/**
170203
\brief Transposes a matrix
171204
@@ -235,11 +268,41 @@ extern "C" {
235268
print(dot<float>(x,y));
236269
\endcode
237270
271+
\param[out] out The array object with the result of the dot operation
272+
\param[in] lhs The array object on the left hand side
273+
\param[in] rhs The array object on the right hand side
274+
\param[in] optLhs Options for lhs. Currently only \ref AF_MAT_NONE and
275+
AF_MAT_CONJ are supported.
276+
\param[in] optRhs Options for rhs. Currently only \ref AF_MAT_NONE and AF_MAT_CONJ are supported.
277+
\return AF_SUCCESS if the process is successful.
278+
238279
\ingroup blas_func_dot
239280
*/
240-
AFAPI af_err af_dot( af_array *out,
281+
AFAPI af_err af_dot(af_array *out,
282+
const af_array lhs, const af_array rhs,
283+
const af_mat_prop optLhs, const af_mat_prop optRhs);
284+
285+
#if AF_API_VERSION >= 35
286+
/**
287+
\brief Scalar dot product between two vectors. Also referred to as the inner
288+
product. Returns the result as a host scalar.
289+
290+
\param[out] real The real component of the result of the dot operation. Must not be NULL; it is written unconditionally.
291+
\param[out] imag The imaginary component of the result of the dot operation. May be NULL, in which case it is not written; it is set to 0 for real-valued inputs.
292+
\param[in] lhs The array object on the left hand side
293+
\param[in] rhs The array object on the right hand side
294+
\param[in] optLhs Options for lhs. Currently only \ref AF_MAT_NONE and
295+
AF_MAT_CONJ are supported.
296+
\param[in] optRhs Options for rhs. Currently only \ref AF_MAT_NONE and AF_MAT_CONJ are supported.
297+
298+
\return AF_SUCCESS if the process is successful.
299+
300+
\ingroup blas_func_dot
301+
*/
302+
AFAPI af_err af_dot_all(double *real, double *imag,
241303
const af_array lhs, const af_array rhs,
242304
const af_mat_prop optLhs, const af_mat_prop optRhs);
305+
#endif
243306

244307
/**
245308
\brief Transposes a matrix

src/api/c/blas.cpp

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,9 @@ af_err af_matmul(af_array *out,
151151
return AF_SUCCESS;
152152
}
153153

154-
af_err af_dot( af_array *out,
155-
const af_array lhs, const af_array rhs,
156-
const af_mat_prop optLhs, const af_mat_prop optRhs)
154+
af_err af_dot(af_array *out,
155+
const af_array lhs, const af_array rhs,
156+
const af_mat_prop optLhs, const af_mat_prop optRhs)
157157
{
158158
using namespace detail;
159159

@@ -195,5 +195,55 @@ af_err af_dot( af_array *out,
195195
std::swap(*out, output);
196196
}
197197
CATCHALL
198-
return AF_SUCCESS;
198+
return AF_SUCCESS;
199+
}
200+
201+
template<typename T>
202+
static inline
203+
T dotAll(af_array out)
204+
{
205+
T res;
206+
AF_CHECK(af_eval(out));
207+
AF_CHECK(af_get_data_ptr((void *)&res, out));
208+
return res;
209+
}
210+
211+
// Computes the dot product of lhs and rhs and returns it as host scalars.
//
// rval   [out] real part of the result; written unconditionally.
// ival   [out] imaginary part of the result; may be NULL, in which case it
//              is simply not written. Set to 0 for real-valued inputs.
// lhs    [in]  left hand side array; its type selects how the result is
//              read back (f32, f64, c32 and c64 are handled, any other type
//              goes through TYPE_ERROR).
// rhs    [in]  right hand side array.
// optLhs, optRhs [in] forwarded unchanged to af_dot.
//
// Returns AF_SUCCESS on success; failures are translated by CATCHALL.
//
// NOTE(review): rval is dereferenced without a NULL check - callers must
// pass a valid pointer.
af_err af_dot_all(double *rval, double *ival,
                  const af_array lhs, const af_array rhs,
                  const af_mat_prop optLhs, const af_mat_prop optRhs)
{
    using namespace detail;

    try {
        *rval = 0;
        if (ival) *ival = 0;

        af_array out = 0;
        AF_CHECK(af_dot(&out, lhs, rhs, optLhs, optRhs));

        // From here on `out` is a live handle owned by this function. The
        // read-back below can bail out early (a failing AF_CHECK inside
        // dotAll, or TYPE_ERROR for an unsupported type), so release the
        // handle on that path too instead of leaking it.
        try {
            ArrayInfo lhsInfo = getInfo(lhs);
            af_dtype lhs_type = lhsInfo.getType();

            switch(lhs_type) {
            case f32: *rval = dotAll<float >(out); break;
            case f64: *rval = dotAll<double>(out); break;
            case c32:
                {
                    cfloat temp = dotAll<cfloat>(out);
                    *rval = real(temp);
                    if (ival) *ival = imag(temp);
                } break;
            case c64:
                {
                    cdouble temp = dotAll<cdouble>(out);
                    *rval = real(temp);
                    if (ival) *ival = imag(temp);
                } break;
            default: TYPE_ERROR(1, lhs_type);
            }
        } catch (...) {
            // Best-effort cleanup: the original failure is the one worth
            // reporting, so the release status is deliberately ignored.
            if (out != 0) af_release_array(out);
            throw;
        }

        if(out != 0) AF_CHECK(af_release_array(out));
    }
    CATCHALL
    return AF_SUCCESS;
}

src/api/cpp/blas.cpp

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,38 @@ namespace af
6969
}
7070
}
7171

72-
array dot (const array &lhs, const array &rhs,
73-
const matProp optLhs, const matProp optRhs)
72+
array dot(const array &lhs, const array &rhs,
73+
const matProp optLhs, const matProp optRhs)
7474
{
7575
af_array out = 0;
7676
AF_THROW(af_dot(&out, lhs.get(), rhs.get(), optLhs, optRhs));
7777
return array(out);
7878
}
79+
80+
#define INSTANTIATE_REAL(TYPE) \
81+
template<> AFAPI \
82+
TYPE dot(const array &lhs, const array &rhs, \
83+
const matProp optLhs, const matProp optRhs) \
84+
{ \
85+
double rval = 0, ival = 0; \
86+
AF_THROW(af_dot_all(&rval, &ival, lhs.get(), rhs.get(), optLhs, optRhs)); \
87+
return (TYPE)(rval); \
88+
}
89+
90+
#define INSTANTIATE_CPLX(TYPE, REAL) \
91+
template<> AFAPI \
92+
TYPE dot(const array &lhs, const array &rhs, \
93+
const matProp optLhs, const matProp optRhs) \
94+
{ \
95+
double rval = 0, ival = 0; \
96+
AF_THROW(af_dot_all(&rval, &ival, lhs.get(), rhs.get(), optLhs, optRhs)); \
97+
TYPE out((REAL)rval, (REAL)ival); \
98+
return out; \
99+
}
100+
101+
INSTANTIATE_REAL(float)
102+
INSTANTIATE_REAL(double)
103+
INSTANTIATE_CPLX(cfloat, float)
104+
INSTANTIATE_CPLX(cdouble, double)
105+
79106
}

src/api/unified/blas.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,22 @@ af_err af_matmul( af_array *out ,
1919
}
2020

2121

22-
af_err af_dot( af_array *out,
22+
// Unified-API entry point for af_dot: applies CHECK_ARRAYS to both input
// arrays, then returns whatever CALL yields for the unchanged argument list.
// NOTE(review): CHECK_ARRAYS and CALL are macros defined outside this view;
// presumably they verify the arrays belong to the active backend and forward
// to that backend's af_dot - confirm against their definitions.
af_err af_dot(af_array *out,
2323
const af_array lhs, const af_array rhs,
2424
const af_mat_prop optLhs, const af_mat_prop optRhs)
2525
{
2626
CHECK_ARRAYS(lhs, rhs);
2727
return CALL(out, lhs, rhs, optLhs, optRhs);
2828
}
2929

30+
// Unified-API entry point for af_dot_all: applies CHECK_ARRAYS to both input
// arrays, then returns whatever CALL yields; rval and ival are passed through
// untouched, so any NULL handling is left to the callee.
// NOTE(review): CHECK_ARRAYS and CALL are macros defined outside this view;
// presumably they verify the arrays belong to the active backend and forward
// to that backend's af_dot_all - confirm against their definitions.
af_err af_dot_all(double *rval, double *ival,
31+
const af_array lhs, const af_array rhs,
32+
const af_mat_prop optLhs, const af_mat_prop optRhs)
33+
{
34+
CHECK_ARRAYS(lhs, rhs);
35+
return CALL(rval, ival, lhs, rhs, optLhs, optRhs);
36+
}
37+
3038
af_err af_transpose(af_array *out, af_array in, const bool conjugate)
3139
{
3240
CHECK_ARRAYS(in);

0 commit comments

Comments
 (0)