Skip to content

Commit 0672f56

Browse files
committed
FEAT: Adding function to get use_count of shared pointers
1 parent 8ac5cb9 commit 0672f56

5 files changed

Lines changed: 55 additions & 0 deletions

File tree

include/af/array.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1269,6 +1269,15 @@ extern "C" {
12691269
*/
12701270
AFAPI af_err af_retain_array(af_array *out, const af_array in);
12711271

1272+
/**
1273+
\ingroup method_mat
1274+
@{
1275+
1276+
Get the use count of `af_array`
1277+
*/
1278+
AFAPI af_err af_get_data_ref_count(int *use_count, const af_array in);
1279+
1280+
12721281
/**
12731282
Evaluate any expressions in the Array
12741283
*/

src/api/c/data.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,34 @@ af_err af_copy_array(af_array *out, const af_array in)
247247
return AF_SUCCESS;
248248
}
249249

250+
//Strong Exception Guarantee
251+
af_err af_get_data_ref_count(int *use_count, const af_array in)
252+
{
253+
try {
254+
ArrayInfo info = getInfo(in);
255+
const af_dtype type = info.getType();
256+
257+
int res;
258+
switch(type) {
259+
case f32: res = getArray<float >(in).useCount(); break;
260+
case c32: res = getArray<cfloat >(in).useCount(); break;
261+
case f64: res = getArray<double >(in).useCount(); break;
262+
case c64: res = getArray<cdouble >(in).useCount(); break;
263+
case b8: res = getArray<char >(in).useCount(); break;
264+
case s32: res = getArray<int >(in).useCount(); break;
265+
case u32: res = getArray<uint >(in).useCount(); break;
266+
case u8: res = getArray<uchar >(in).useCount(); break;
267+
case s64: res = getArray<intl >(in).useCount(); break;
268+
case u64: res = getArray<uintl >(in).useCount(); break;
269+
default: TYPE_ERROR(1, type);
270+
}
271+
std::swap(*use_count, res);
272+
}
273+
CATCHALL
274+
return AF_SUCCESS;
275+
}
276+
277+
250278
template<typename T>
251279
static inline af_array randn_(const af::dim4 &dims)
252280
{

src/backend/cpu/Array.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,12 @@ namespace cpu
170170
return data.get() + (withOffset ? offset : 0);
171171
}
172172

173+
int useCount() const
174+
{
175+
if (!isReady()) eval();
176+
return data.use_count();
177+
}
178+
173179
TNJ::Node_ptr getNode() const;
174180

175181
friend Array<T> createValueArray<T>(const af::dim4 &size, const T& value);

src/backend/cuda/Array.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,12 @@ namespace cuda
181181
return data.get() + (withOffset ? offset : 0);
182182
}
183183

184+
int useCount() const
185+
{
186+
if (!isReady()) eval();
187+
return data.use_count();
188+
}
189+
184190
operator Param<T>()
185191
{
186192
Param<T> out;

src/backend/opencl/Array.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,12 @@ namespace opencl
162162
return data.get();
163163
}
164164

165+
int useCount() const
166+
{
167+
if (!isReady()) eval();
168+
return data.use_count();
169+
}
170+
165171
const dim_t getOffset() const
166172
{
167173
return offset;

0 commit comments

Comments
 (0)