Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions include/af/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,15 @@ class AFAPI index {
///
index(const af::array& idx0);

///
/// \brief Copy constructor
///
/// \param[in] idx0 is index to copy.
///
/// \sa indexing
///
index(const index& idx0);

///
/// \brief Returns true if the \ref af::index represents a af::span object
///
Expand All @@ -116,6 +125,31 @@ class AFAPI index {
/// \returns the af_index_t represented by this object
///
const af_index_t& get() const;

///
/// \brief Assigns idx0 to this index
///
/// \param[in] idx0 is the index to be assigned to the /ref af::index
/// \returns the reference to this
///
///
index & operator=(const index& idx0);

#if __cplusplus > 199711L
///
/// \brief Move constructor
///
/// \param[in] idx0 is index to copy.
///
index(index &&idx0);
///
/// \brief Move assignment operator
///
/// \param[in] idx0 is the index to be assigned to the /ref af::index
/// \returns a reference to this
///
index& operator=(index &&idx0);
#endif
};

///
Expand Down
25 changes: 25 additions & 0 deletions src/api/cpp/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,36 @@ index::index(const af::array& idx0) {
impl.isBatch = false;
}

index::index(const af::index& idx0) {
*this = idx0;
}

index::~index() {
if (!impl.isSeq)
af_release_array(impl.idx.arr);
}

index & index::operator=(const index& idx0) {
impl = idx0.get();
if(impl.isSeq == false){
// increment reference count to avoid double free
// when/if idx0 is destroyed
AF_THROW(af_retain_array(&impl.idx.arr, impl.idx.arr));
}
return *this;
}

#if __cplusplus > 199711L
index::index(index &&idx0) {
impl = idx0.impl;
}

index& index::operator=(index &&idx0) {
impl = idx0.impl;
return *this;
}
#endif


static bool operator==(const af_seq& lhs, const af_seq& rhs) {
return lhs.begin == rhs.begin && lhs.end == rhs.end && lhs.step == rhs.step;
Expand Down
10 changes: 10 additions & 0 deletions test/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1251,3 +1251,13 @@ TEST(Indexing, SNIPPET_indexing_ref)
//! [ex_indexing_ref]
//TODO: Confirm the outputs are correct. see #697
}

TEST(Indexing, SNIPPET_indexing_copy)
{
af::array A = af::constant(0,1, s32);
af::index s1;
s1 = af::index(A);
// At exit both A and s1 will be destroyed
// but the underlying array should only be
// freed once.
}