@@ -120,6 +120,9 @@ namespace xt
120120 static bool check_ (pybind11::handle h);
121121 static PyObject* raw_array_t (PyObject* ptr);
122122
123+ derived_type& derived_cast ();
124+ const derived_type& derived_cast () const ;
125+
123126 PyArrayObject* python_array () const ;
124127 size_type get_min_stride () const ;
125128 };
@@ -260,6 +263,19 @@ namespace xt
260263 return std::max (size_type (1 ), std::accumulate (this ->strides ().cbegin (), this ->strides ().cend (), std::numeric_limits<size_type>::max (), min));
261264 }
262265
266+ template <class D >
267+ inline auto pycontainer<D>::derived_cast() -> derived_type&
268+ {
269+ return *static_cast <derived_type*>(this );
270+ }
271+
272+ template <class D >
273+ inline auto pycontainer<D>::derived_cast() const -> const derived_type&
274+ {
275+ return *static_cast <const derived_type*>(this );
276+ }
277+
278+
263279 /* *
264280 * resizes the container.
265281 * @param shape the new shape
@@ -330,7 +346,11 @@ namespace xt
330346 }
331347
332348 PyArray_Dims dims ({reinterpret_cast <npy_intp*>(shape.data ()), static_cast <int >(shape.size ())});
333- PyArray_Newshape ((PyArrayObject*) this ->ptr (), &dims, npy_layout);
349+ auto new_ptr = PyArray_Newshape ((PyArrayObject*) this ->ptr (), &dims, npy_layout);
350+ auto old_ptr = this ->ptr ();
351+ this ->ptr () = new_ptr;
352+ Py_XDECREF (old_ptr);
353+ this ->derived_cast ().init_from_python ();
334354 }
335355
336356 /* *
0 commit comments