diff --git a/include/af/index.h b/include/af/index.h index 295739077e..373804dc17 100644 --- a/include/af/index.h +++ b/include/af/index.h @@ -103,6 +103,15 @@ class AFAPI index { /// index(const af::array& idx0); + /// + /// \brief Copy constructor + /// + /// \param[in] idx0 is index to copy. + /// + /// \sa indexing + /// + index(const index& idx0); + /// /// \brief Returns true if the \ref af::index represents a af::span object /// @@ -116,6 +125,31 @@ class AFAPI index { /// \returns the af_index_t represented by this object /// const af_index_t& get() const; + + /// + /// \brief Assigns idx0 to this index + /// + /// \param[in] idx0 is the index to be assigned to the /ref af::index + /// \returns the reference to this + /// + /// + index & operator=(const index& idx0); + +#if __cplusplus > 199711L + /// + /// \brief Move constructor + /// + /// \param[in] idx0 is index to copy. + /// + index(index &&idx0); + /// + /// \brief Move assignment operator + /// + /// \param[in] idx0 is the index to be assigned to the /ref af::index + /// \returns a reference to this + /// + index& operator=(index &&idx0); +#endif }; /// diff --git a/src/api/cpp/index.cpp b/src/api/cpp/index.cpp index 2b66dd62c1..ccc698bc3c 100644 --- a/src/api/cpp/index.cpp +++ b/src/api/cpp/index.cpp @@ -75,11 +75,36 @@ index::index(const af::array& idx0) { impl.isBatch = false; } +index::index(const af::index& idx0) { + *this = idx0; +} + index::~index() { if (!impl.isSeq) af_release_array(impl.idx.arr); } +index & index::operator=(const index& idx0) { + impl = idx0.get(); + if(impl.isSeq == false){ + // increment reference count to avoid double free + // when/if idx0 is destroyed + AF_THROW(af_retain_array(&impl.idx.arr, impl.idx.arr)); + } + return *this; +} + +#if __cplusplus > 199711L +index::index(index &&idx0) { + impl = idx0.impl; +} + +index& index::operator=(index &&idx0) { + impl = idx0.impl; + return *this; +} +#endif + static bool operator==(const af_seq& lhs, const af_seq& rhs) { return lhs.begin == rhs.begin && lhs.end == rhs.end && lhs.step == rhs.step; diff --git a/test/index.cpp b/test/index.cpp index 9f23ae5fe3..28b57459fc 100644 --- a/test/index.cpp +++ b/test/index.cpp @@ -1251,3 +1251,13 @@ TEST(Indexing, SNIPPET_indexing_ref) //! [ex_indexing_ref] //TODO: Confirm the outputs are correct. see #697 } + +TEST(Indexing, SNIPPET_indexing_copy) +{ + af::array A = af::constant(0,1, s32); + af::index s1; + s1 = af::index(A); + // At exit both A and s1 will be destroyed + // but the underlying array should only be + // freed once. +}