Skip to content

Commit ecf05c4

Browse files
mborgerdingstefanseefeld
authored andcommitted
ndarray.shape(k),strides(k) act more like their python counterparts (negative indexing, bounds checking) (issue boostorg#157)
1 parent 00b7ed0 commit ecf05c4

File tree

4 files changed

+66
-4
lines changed

4 files changed

+66
-4
lines changed

include/boost/python/numpy/ndarray.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,11 @@ class BOOST_NUMPY_DECL ndarray : public object
8686
/// @brief Copy the scalar (deep for all non-object fields).
8787
ndarray copy() const;
8888

89-
/// @brief Return the size of the nth dimension.
90-
Py_intptr_t shape(int n) const { return get_shape()[n]; }
89+
/// @brief Return the size of the nth dimension. raises IndexError if k not in [-get_nd() : get_nd()-1 ]
90+
Py_intptr_t shape(int n) const;
9191

92-
/// @brief Return the stride of the nth dimension.
93-
Py_intptr_t strides(int n) const { return get_strides()[n]; }
92+
/// @brief Return the stride of the nth dimension. raises IndexError if k not in [-get_nd() : get_nd()-1]
93+
Py_intptr_t strides(int n) const;
9494

9595
/**
9696
* @brief Return the array's raw data pointer.

src/numpy/ndarray.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,30 @@ ndarray from_data_impl(void * data,
138138

139139
} // namespace detail
140140

141+
namespace {
142+
int normalize_index(int n,int nlim) // wraps [-nlim:nlim) into [0:nlim), throw IndexError otherwise
143+
{
144+
if (n<0)
145+
n += nlim; // negative indices work backwards from end
146+
if (n < 0 || n >= nlim)
147+
{
148+
PyErr_SetObject(PyExc_IndexError, Py_None);
149+
throw_error_already_set();
150+
}
151+
return n;
152+
}
153+
}
154+
155+
Py_intptr_t ndarray::shape(int n) const
156+
{
157+
return get_shape()[normalize_index(n,get_nd())];
158+
}
159+
160+
Py_intptr_t ndarray::strides(int n) const
161+
{
162+
return get_strides()[normalize_index(n,get_nd())];
163+
}
164+
141165
ndarray ndarray::view(dtype const & dt) const
142166
{
143167
return ndarray(python::detail::new_reference

test/numpy/ndarray.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ np::ndarray transpose(np::ndarray arr) { return arr.transpose();}
3131
np::ndarray squeeze(np::ndarray arr) { return arr.squeeze();}
3232
np::ndarray reshape(np::ndarray arr,p::tuple tup) { return arr.reshape(tup);}
3333

34+
Py_intptr_t shape_index(np::ndarray arr,int k) { return arr.shape(k); }
35+
Py_intptr_t strides_index(np::ndarray arr,int k) { return arr.strides(k); }
36+
3437
BOOST_PYTHON_MODULE(ndarray_ext)
3538
{
3639
np::initialize();
@@ -43,4 +46,6 @@ BOOST_PYTHON_MODULE(ndarray_ext)
4346
p::def("transpose", transpose);
4447
p::def("squeeze", squeeze);
4548
p::def("reshape", reshape);
49+
p::def("shape_index", shape_index);
50+
p::def("strides_index", strides_index);
4651
}

test/numpy/ndarray.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,5 +75,38 @@ def testReshape(self):
7575
a2 = ndarray_ext.reshape(a1,(1,4))
7676
self.assertEqual(a2.shape,(1,4))
7777

78+
def testShapeIndex(self):
79+
a = numpy.arange(24)
80+
a.shape = (1,2,3,4)
81+
def shape_check(i):
82+
print(i)
83+
self.assertEqual(ndarray_ext.shape_index(a,i) ,a.shape[i] )
84+
for i in range(4):
85+
shape_check(i)
86+
for i in range(-1,-5,-1):
87+
shape_check(i)
88+
try:
89+
ndarray_ext.shape_index(a,4) # out of bounds -- should raise IndexError
90+
self.assertTrue(False)
91+
except IndexError:
92+
pass
93+
94+
def testStridesIndex(self):
95+
a = numpy.arange(24)
96+
a.shape = (1,2,3,4)
97+
def strides_check(i):
98+
print(i)
99+
self.assertEqual(ndarray_ext.strides_index(a,i) ,a.strides[i] )
100+
for i in range(4):
101+
strides_check(i)
102+
for i in range(-1,-5,-1):
103+
strides_check(i)
104+
try:
105+
ndarray_ext.strides_index(a,4) # out of bounds -- should raise IndexError
106+
self.assertTrue(False)
107+
except IndexError:
108+
pass
109+
110+
78111
if __name__=="__main__":
79112
unittest.main()

0 commit comments

Comments
 (0)