Skip to content

Commit 08bd46b

Browse files
committed
Enabling iterator pairs on pyarray backstrides
1 parent 75e68b0 commit 08bd46b

2 files changed

Lines changed: 253 additions & 14 deletions

File tree

include/xtensor-python/pyarray.hpp

Lines changed: 215 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,21 +75,188 @@ namespace pybind11
7575
namespace xt
7676
{
7777

78+
/**************************
79+
* pybackstrides_iterator *
80+
**************************/
81+
82+
template <class A>
83+
class pybackstrides_iterator
84+
{
85+
public:
86+
87+
using self_type = pybackstrides_iterator<A>;
88+
89+
using value_type = typename A::size_type;
90+
using pointer = const value_type*;
91+
using reference = value_type;
92+
using difference_type = std::ptrdiff_t;
93+
using iterator_category = std::random_access_iterator_tag;
94+
95+
inline pybackstrides_iterator(const A* a, std::size_t offset)
96+
: p_a(a), m_offset(offset)
97+
{
98+
}
99+
100+
inline reference operator*() const
101+
{
102+
value_type sh = p_a->shape()[m_offset];
103+
value_type res = sh == 1 ? 0 : (sh - 1) * p_a->strides()[m_offset];
104+
return res;
105+
}
106+
107+
inline pointer operator->() const
108+
{
109+
// Returning the address of a temporary
110+
value_type sh = p_a->shape()[m_offset];
111+
value_type res = sh == 1 ? 0 : (sh - 1) * p_a->strides()[m_offset];
112+
return &res;
113+
}
114+
115+
inline reference operator[](difference_type n) const
116+
{
117+
auto index = m_offset + n;
118+
value_type sh = p_a->shape()[index];
119+
value_type res = sh == 1 ? 0 : (sh - 1) * p_a->strides()[index];
120+
return res;
121+
}
122+
123+
inline self_type& operator++()
124+
{
125+
++m_offset;
126+
return *this;
127+
}
128+
129+
inline self_type& operator--()
130+
{
131+
--m_offset;
132+
return *this;
133+
}
134+
135+
inline self_type operator++(int)
136+
{
137+
self_type tmp(*this);
138+
++m_offset;
139+
return tmp;
140+
}
141+
142+
inline self_type operator--(int)
143+
{
144+
self_type tmp(*this);
145+
--m_offset;
146+
return tmp;
147+
}
148+
149+
inline self_type& operator+=(difference_type n)
150+
{
151+
m_offset += n;
152+
return *this;
153+
}
154+
155+
inline self_type& operator-=(difference_type n)
156+
{
157+
m_offset -= n;
158+
return *this;
159+
}
160+
161+
inline self_type operator+(difference_type n) const
162+
{
163+
return self_type(p_a, m_offset + n);
164+
}
165+
166+
inline self_type operator-(difference_type n) const
167+
{
168+
return self_type(p_a, m_offset - n);
169+
}
170+
171+
inline self_type operator-(const self_type& rhs) const
172+
{
173+
self_type tmp(*this);
174+
tmp -= (m_offset - rhs.m_offset);
175+
return tmp;
176+
}
177+
178+
inline std::size_t offset() const
179+
{
180+
return m_offset;
181+
}
182+
183+
private:
184+
185+
const A* p_a;
186+
std::size_t m_offset;
187+
};
188+
189+
template <class A>
190+
inline bool operator==(const pybackstrides_iterator<A>& lhs,
191+
const pybackstrides_iterator<A>& rhs)
192+
{
193+
return lhs.offset() == rhs.offset();
194+
}
195+
196+
template <class A>
197+
inline bool operator!=(const pybackstrides_iterator<A>& lhs,
198+
const pybackstrides_iterator<A>& rhs)
199+
{
200+
return !(lhs == rhs);
201+
}
202+
203+
template <class A>
204+
inline bool operator<(const pybackstrides_iterator<A>& lhs,
205+
const pybackstrides_iterator<A>& rhs)
206+
{
207+
return lhs.offset() < rhs.offset();
208+
}
209+
210+
template <class A>
211+
inline bool operator<=(const pybackstrides_iterator<A>& lhs,
212+
const pybackstrides_iterator<A>& rhs)
213+
{
214+
return (lhs < rhs) || (lhs == rhs);
215+
}
216+
217+
template <class A>
218+
inline bool operator>(const pybackstrides_iterator<A>& lhs,
219+
const pybackstrides_iterator<A>& rhs)
220+
{
221+
return !(lhs <= rhs);
222+
}
223+
224+
template <class A>
225+
inline bool operator>=(const pybackstrides_iterator<A>& lhs,
226+
const pybackstrides_iterator<A>& rhs)
227+
{
228+
return !(lhs < rhs);
229+
}
230+
78231
template <class A>
79232
class pyarray_backstrides
80233
{
81234
public:
82235

83236
using array_type = A;
84237
using value_type = typename array_type::size_type;
238+
using const_reference = value_type;
239+
using const_pointer = const value_type*;
85240
using size_type = typename array_type::size_type;
241+
using difference_type = typename array_type::difference_type;
242+
243+
using const_iterator = pybackstrides_iterator<A>;
86244

87245
pyarray_backstrides() = default;
88246
pyarray_backstrides(const array_type& a);
89247

248+
bool empty() const;
249+
size_type size() const;
250+
90251
value_type operator[](size_type i) const;
91252

92-
size_type size() const;
253+
const_reference front() const;
254+
const_reference back() const;
255+
256+
const_iterator begin() const;
257+
const_iterator end() const;
258+
const_iterator cbegin() const;
259+
const_iterator cend() const;
93260

94261
private:
95262

@@ -213,6 +380,12 @@ namespace xt
213380
{
214381
}
215382

383+
template <class A>
384+
inline bool pyarray_backstrides<A>::empty() const
385+
{
386+
return p_a->dimension() == 0;
387+
}
388+
216389
template <class A>
217390
inline auto pyarray_backstrides<A>::size() const -> size_type
218391
{
@@ -227,6 +400,47 @@ namespace xt
227400
return res;
228401
}
229402

403+
template <class A>
404+
inline auto pyarray_backstrides<A>::front() const -> const_reference
405+
{
406+
value_type sh = p_a->shape()[0];
407+
value_type res = sh == 1 ? 0 : (sh - 1) * p_a->strides()[0];
408+
return res;
409+
}
410+
411+
template <class A>
412+
inline auto pyarray_backstrides<A>::back() const -> const_reference
413+
{
414+
auto index = p_a->size() - 1;
415+
value_type sh = p_a->shape()[index];
416+
value_type res = sh == 1 ? 0 : (sh - 1) * p_a->strides()[index];
417+
return res;
418+
}
419+
420+
template <class A>
421+
inline auto pyarray_backstrides<A>::begin() const -> const_iterator
422+
{
423+
return cbegin();
424+
}
425+
426+
template <class A>
427+
inline auto pyarray_backstrides<A>::end() const -> const_iterator
428+
{
429+
return cend();
430+
}
431+
432+
template <class A>
433+
inline auto pyarray_backstrides<A>::cbegin() const -> const_iterator
434+
{
435+
const_iterator(p_a, 0);
436+
}
437+
438+
template <class A>
439+
inline auto pyarray_backstrides<A>::cend() const -> const_iterator
440+
{
441+
const_iterator(p_a, size());
442+
}
443+
230444
/**************************
231445
* pyarray implementation *
232446
**************************/

include/xtensor-python/pystrides_adaptor.hpp

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ namespace xt
1818
template <std::size_t N>
1919
class pystrides_iterator;
2020

21+
/*********************************
22+
* pystrides_adaptor declaration *
23+
*********************************/
24+
2125
template <std::size_t N>
2226
class pystrides_adaptor
2327
{
@@ -53,14 +57,13 @@ namespace xt
5357
size_type m_size;
5458
};
5559

56-
/*************************************
57-
* pystrides_iterator implementation *
58-
*************************************/
60+
/**********************************
61+
* pystrides_iterator declaration *
62+
**********************************/
5963

6064
template <std::size_t N>
6165
class pystrides_iterator
6266
{
63-
6467
public:
6568

6669
using self_type = pystrides_iterator<N>;
@@ -76,16 +79,29 @@ namespace xt
7679
{
7780
}
7881

79-
inline reference operator*() const { return *p_current / N; }
80-
inline pointer operator->() const { return p_current; }
82+
inline reference operator*() const
83+
{
84+
return *p_current / N;
85+
}
8186

82-
inline reference operator[](difference_type n) const { return *(p_current + n) / N; }
87+
inline pointer operator->() const
88+
{
89+
// Returning the address of a temporary
90+
value_type res = *p_current / N;
91+
return &res;
92+
}
93+
94+
inline reference operator[](difference_type n) const
95+
{
96+
return *(p_current + n) / N;
97+
}
8398

8499
inline self_type& operator++()
85100
{
86101
++p_current;
87102
return *this;
88103
}
104+
89105
inline self_type& operator--()
90106
{
91107
--p_current;
@@ -110,14 +126,23 @@ namespace xt
110126
p_current += n;
111127
return *this;
112128
}
129+
113130
inline self_type& operator-=(difference_type n)
114131
{
115132
p_current -= n;
116133
return *this;
117134
}
118135

119-
inline self_type operator+(difference_type n) const { return self_type(p_current + n); }
120-
inline self_type operator-(difference_type n) const { return self_type(p_current - n); }
136+
inline self_type operator+(difference_type n) const
137+
{
138+
return self_type(p_current + n);
139+
}
140+
141+
inline self_type operator-(difference_type n) const
142+
{
143+
return self_type(p_current - n);
144+
}
145+
121146
inline self_type operator-(const self_type& rhs) const
122147
{
123148
self_type tmp(*this);
@@ -217,25 +242,25 @@ namespace xt
217242
template <std::size_t N>
218243
inline auto pystrides_adaptor<N>::begin() const -> const_iterator
219244
{
220-
return const_iterator(p_data);
245+
return cbegin();
221246
}
222247

223248
template <std::size_t N>
224249
inline auto pystrides_adaptor<N>::end() const -> const_iterator
225250
{
226-
return const_iterator(p_data + m_size);
251+
return cend();
227252
}
228253

229254
template <std::size_t N>
230255
inline auto pystrides_adaptor<N>::cbegin() const -> const_iterator
231256
{
232-
return begin();
257+
return const_iterator(p_data);
233258
}
234259

235260
template <std::size_t N>
236261
inline auto pystrides_adaptor<N>::cend() const -> const_iterator
237262
{
238-
return end();
263+
return const_iterator(p_data + m_size);
239264
}
240265
}
241266

0 commit comments

Comments
 (0)