File tree Expand file tree Collapse file tree 4 files changed +66
-4
lines changed
include/boost/python/numpy Expand file tree Collapse file tree 4 files changed +66
-4
lines changed Original file line number Diff line number Diff line change @@ -86,11 +86,11 @@ class BOOST_NUMPY_DECL ndarray : public object
8686 // / @brief Copy the scalar (deep for all non-object fields).
8787 ndarray copy () const ;
8888
89- // / @brief Return the size of the nth dimension.
90- Py_intptr_t shape (int n) const { return get_shape ()[n]; }
89+ // / @brief Return the size of the nth dimension. raises IndexError if k not in [-get_nd() : get_nd()-1 ]
90+ Py_intptr_t shape (int n) const ;
9191
92- // / @brief Return the stride of the nth dimension.
93- Py_intptr_t strides (int n) const { return get_strides ()[n]; }
92+ // / @brief Return the stride of the nth dimension. raises IndexError if k not in [-get_nd() : get_nd()-1]
93+ Py_intptr_t strides (int n) const ;
9494
9595 /* *
9696 * @brief Return the array's raw data pointer.
Original file line number Diff line number Diff line change @@ -138,6 +138,30 @@ ndarray from_data_impl(void * data,
138138
139139} // namespace detail
140140
141+ namespace {
142+ int normalize_index (int n,int nlim) // wraps [-nlim:nlim) into [0:nlim), throw IndexError otherwise
143+ {
144+ if (n<0 )
145+ n += nlim; // negative indices work backwards from end
146+ if (n < 0 || n >= nlim)
147+ {
148+ PyErr_SetObject (PyExc_IndexError, Py_None);
149+ throw_error_already_set ();
150+ }
151+ return n;
152+ }
153+ }
154+
155+ Py_intptr_t ndarray::shape (int n) const
156+ {
157+ return get_shape ()[normalize_index (n,get_nd ())];
158+ }
159+
160+ Py_intptr_t ndarray::strides (int n) const
161+ {
162+ return get_strides ()[normalize_index (n,get_nd ())];
163+ }
164+
141165ndarray ndarray::view (dtype const & dt) const
142166{
143167 return ndarray (python::detail::new_reference
Original file line number Diff line number Diff line change @@ -31,6 +31,9 @@ np::ndarray transpose(np::ndarray arr) { return arr.transpose();}
3131np::ndarray squeeze (np::ndarray arr) { return arr.squeeze ();}
3232np::ndarray reshape (np::ndarray arr,p::tuple tup) { return arr.reshape (tup);}
3333
34+ Py_intptr_t shape_index (np::ndarray arr,int k) { return arr.shape (k); }
35+ Py_intptr_t strides_index (np::ndarray arr,int k) { return arr.strides (k); }
36+
3437BOOST_PYTHON_MODULE (ndarray_ext)
3538{
3639 np::initialize ();
@@ -43,4 +46,6 @@ BOOST_PYTHON_MODULE(ndarray_ext)
4346 p::def (" transpose" , transpose);
4447 p::def (" squeeze" , squeeze);
4548 p::def (" reshape" , reshape);
49+ p::def (" shape_index" , shape_index);
50+ p::def (" strides_index" , strides_index);
4651}
Original file line number Diff line number Diff line change @@ -75,5 +75,38 @@ def testReshape(self):
7575 a2 = ndarray_ext .reshape (a1 ,(1 ,4 ))
7676 self .assertEqual (a2 .shape ,(1 ,4 ))
7777
78+ def testShapeIndex (self ):
79+ a = numpy .arange (24 )
80+ a .shape = (1 ,2 ,3 ,4 )
81+ def shape_check (i ):
82+ print (i )
83+ self .assertEqual (ndarray_ext .shape_index (a ,i ) ,a .shape [i ] )
84+ for i in range (4 ):
85+ shape_check (i )
86+ for i in range (- 1 ,- 5 ,- 1 ):
87+ shape_check (i )
88+ try :
89+ ndarray_ext .shape_index (a ,4 ) # out of bounds -- should raise IndexError
90+ self .assertTrue (False )
91+ except IndexError :
92+ pass
93+
94+ def testStridesIndex (self ):
95+ a = numpy .arange (24 )
96+ a .shape = (1 ,2 ,3 ,4 )
97+ def strides_check (i ):
98+ print (i )
99+ self .assertEqual (ndarray_ext .strides_index (a ,i ) ,a .strides [i ] )
100+ for i in range (4 ):
101+ strides_check (i )
102+ for i in range (- 1 ,- 5 ,- 1 ):
103+ strides_check (i )
104+ try :
105+ ndarray_ext .strides_index (a ,4 ) # out of bounds -- should raise IndexError
106+ self .assertTrue (False )
107+ except IndexError :
108+ pass
109+
110+
78111if __name__ == "__main__" :
79112 unittest .main ()
You can’t perform that action at this time.
0 commit comments