Skip to content

Commit 0ce15e1

Browse files
Merge pull request RustPython#772 from adrian17/iterables
Refactor iterables, separate types, make behavior more compatible with CPython
2 parents 42768b2 + 2f2a843 commit 0ce15e1

16 files changed

Lines changed: 395 additions & 198 deletions

tests/snippets/iterable.py

Lines changed: 32 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,32 @@
1+
from testutils import assert_raises
2+
3+
def test_container(x):
4+
assert 3 in x
5+
assert 4 not in x
6+
assert list(x) == list(iter(x))
7+
assert list(x) == [0, 1, 2, 3]
8+
assert [*x] == [0, 1, 2, 3]
9+
lst = []
10+
lst.extend(x)
11+
assert lst == [0, 1, 2, 3]
12+
13+
class C:
14+
def __iter__(self):
15+
return iter([0, 1, 2, 3])
16+
test_container(C())
17+
18+
class C:
19+
def __getitem__(self, x):
20+
return (0, 1, 2, 3)[x] # raises IndexError on x==4
21+
test_container(C())
22+
23+
class C:
24+
def __getitem__(self, x):
25+
if x > 3:
26+
raise StopIteration
27+
return x
28+
test_container(C())
29+
30+
class C: pass
31+
assert_raises(TypeError, lambda: 5 in C())
32+
assert_raises(TypeError, lambda: iter(C))

vm/src/frame.rs

Lines changed: 4 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -1075,36 +1075,14 @@ impl Frame {
10751075
a.get_id()
10761076
}
10771077

1078-
// https://docs.python.org/3/reference/expressions.html#membership-test-operations
1079-
fn _membership(
1080-
&self,
1081-
vm: &VirtualMachine,
1082-
needle: PyObjectRef,
1083-
haystack: &PyObjectRef,
1084-
) -> PyResult {
1085-
vm.call_method(&haystack, "__contains__", vec![needle])
1086-
// TODO: implement __iter__ and __getitem__ cases when __contains__ is
1087-
// not implemented.
1088-
}
1089-
10901078
fn _in(&self, vm: &VirtualMachine, needle: PyObjectRef, haystack: PyObjectRef) -> PyResult {
1091-
match self._membership(vm, needle, &haystack) {
1092-
Ok(found) => Ok(found),
1093-
Err(_) => Err(vm.new_type_error(format!(
1094-
"{} has no __contains__ method",
1095-
haystack.class().name
1096-
))),
1097-
}
1079+
let found = vm._membership(haystack.clone(), needle)?;
1080+
Ok(vm.ctx.new_bool(objbool::boolval(vm, found)?))
10981081
}
10991082

11001083
fn _not_in(&self, vm: &VirtualMachine, needle: PyObjectRef, haystack: PyObjectRef) -> PyResult {
1101-
match self._membership(vm, needle, &haystack) {
1102-
Ok(found) => Ok(vm.ctx.new_bool(!objbool::get_value(&found))),
1103-
Err(_) => Err(vm.new_type_error(format!(
1104-
"{} has no __contains__ method",
1105-
haystack.class().name
1106-
))),
1107-
}
1084+
let found = vm._membership(haystack.clone(), needle)?;
1085+
Ok(vm.ctx.new_bool(!objbool::boolval(vm, found)?))
11081086
}
11091087

11101088
fn _is(&self, a: PyObjectRef, b: PyObjectRef) -> bool {

vm/src/obj/objbytearray.rs

Lines changed: 46 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
//! Implementation of the python bytearray object.
22
3-
use std::cell::RefCell;
3+
use std::cell::{Cell, RefCell};
44
use std::fmt::Write;
55
use std::ops::{Deref, DerefMut};
66

@@ -11,6 +11,7 @@ use crate::pyobject::{PyContext, PyObjectRef, PyRef, PyResult, PyValue};
1111
use crate::vm::VirtualMachine;
1212

1313
use super::objint;
14+
use super::objiter;
1415
use super::objtype::PyClassRef;
1516

1617
#[derive(Debug)]
@@ -67,6 +68,7 @@ pub fn init(context: &PyContext) {
6768
"__eq__" => context.new_rustfunc(PyByteArrayRef::eq),
6869
"__len__" => context.new_rustfunc(PyByteArrayRef::len),
6970
"__repr__" => context.new_rustfunc(PyByteArrayRef::repr),
71+
"__iter__" => context.new_rustfunc(PyByteArrayRef::iter),
7072
"clear" => context.new_rustfunc(PyByteArrayRef::clear),
7173
"isalnum" => context.new_rustfunc(PyByteArrayRef::isalnum),
7274
"isalpha" => context.new_rustfunc(PyByteArrayRef::isalpha),
@@ -80,6 +82,12 @@ pub fn init(context: &PyContext) {
8082
"pop" => context.new_rustfunc(PyByteArrayRef::pop),
8183
"upper" => context.new_rustfunc(PyByteArrayRef::upper)
8284
});
85+
86+
let bytearrayiterator_type = &context.bytearrayiterator_type;
87+
extend_class!(context, bytearrayiterator_type, {
88+
"__next__" => context.new_rustfunc(PyByteArrayIteratorRef::next),
89+
"__iter__" => context.new_rustfunc(PyByteArrayIteratorRef::iter),
90+
});
8391
}
8492

8593
fn bytearray_new(
@@ -225,6 +233,13 @@ impl PyByteArrayRef {
225233
value: RefCell::new(bytes),
226234
}
227235
}
236+
237+
fn iter(self, _vm: &VirtualMachine) -> PyByteArrayIterator {
238+
PyByteArrayIterator {
239+
position: Cell::new(0),
240+
bytearray: self,
241+
}
242+
}
228243
}
229244

230245
// helper function for istitle
@@ -266,3 +281,33 @@ mod tests {
266281
assert_eq!(&to_hex(&[11u8, 222u8]), "\\x0b\\xde");
267282
}
268283
}
284+
285+
#[derive(Debug)]
286+
pub struct PyByteArrayIterator {
287+
position: Cell<usize>,
288+
bytearray: PyByteArrayRef,
289+
}
290+
291+
impl PyValue for PyByteArrayIterator {
292+
fn class(vm: &VirtualMachine) -> PyClassRef {
293+
vm.ctx.bytearrayiterator_type()
294+
}
295+
}
296+
297+
type PyByteArrayIteratorRef = PyRef<PyByteArrayIterator>;
298+
299+
impl PyByteArrayIteratorRef {
300+
fn next(self, vm: &VirtualMachine) -> PyResult<u8> {
301+
if self.position.get() < self.bytearray.value.borrow().len() {
302+
let ret = self.bytearray.value.borrow()[self.position.get()];
303+
self.position.set(self.position.get() + 1);
304+
Ok(ret)
305+
} else {
306+
Err(objiter::new_stop_iteration(vm))
307+
}
308+
}
309+
310+
fn iter(self, _vm: &VirtualMachine) -> Self {
311+
self
312+
}
313+
}

vm/src/obj/objbytes.rs

Lines changed: 41 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -6,10 +6,11 @@ use std::ops::Deref;
66
use num_traits::ToPrimitive;
77

88
use crate::function::OptionalArg;
9-
use crate::pyobject::{PyContext, PyIteratorValue, PyObjectRef, PyRef, PyResult, PyValue};
9+
use crate::pyobject::{PyContext, PyObjectRef, PyRef, PyResult, PyValue};
1010
use crate::vm::VirtualMachine;
1111

1212
use super::objint;
13+
use super::objiter;
1314
use super::objtype::PyClassRef;
1415

1516
#[derive(Debug)]
@@ -68,6 +69,12 @@ pub fn init(context: &PyContext) {
6869
"__iter__" => context.new_rustfunc(PyBytesRef::iter),
6970
"__doc__" => context.new_str(bytes_doc.to_string())
7071
});
72+
73+
let bytesiterator_type = &context.bytesiterator_type;
74+
extend_class!(context, bytesiterator_type, {
75+
"__next__" => context.new_rustfunc(PyBytesIteratorRef::next),
76+
"__iter__" => context.new_rustfunc(PyBytesIteratorRef::iter),
77+
});
7178
}
7279

7380
fn bytes_new(
@@ -149,14 +156,44 @@ impl PyBytesRef {
149156
format!("b'{}'", data)
150157
}
151158

152-
fn iter(obj: PyBytesRef, _vm: &VirtualMachine) -> PyIteratorValue {
153-
PyIteratorValue {
159+
fn iter(self, _vm: &VirtualMachine) -> PyBytesIterator {
160+
PyBytesIterator {
154161
position: Cell::new(0),
155-
iterated_obj: obj.into_object(),
162+
bytes: self,
156163
}
157164
}
158165
}
159166

160167
pub fn get_value<'a>(obj: &'a PyObjectRef) -> impl Deref<Target = Vec<u8>> + 'a {
161168
&obj.payload::<PyBytes>().unwrap().value
162169
}
170+
171+
#[derive(Debug)]
172+
pub struct PyBytesIterator {
173+
position: Cell<usize>,
174+
bytes: PyBytesRef,
175+
}
176+
177+
impl PyValue for PyBytesIterator {
178+
fn class(vm: &VirtualMachine) -> PyClassRef {
179+
vm.ctx.bytesiterator_type()
180+
}
181+
}
182+
183+
type PyBytesIteratorRef = PyRef<PyBytesIterator>;
184+
185+
impl PyBytesIteratorRef {
186+
fn next(self, vm: &VirtualMachine) -> PyResult<u8> {
187+
if self.position.get() < self.bytes.value.len() {
188+
let ret = self.bytes[self.position.get()];
189+
self.position.set(self.position.get() + 1);
190+
Ok(ret)
191+
} else {
192+
Err(objiter::new_stop_iteration(vm))
193+
}
194+
}
195+
196+
fn iter(self, _vm: &VirtualMachine) -> Self {
197+
self
198+
}
199+
}

vm/src/obj/objdict.rs

Lines changed: 15 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -5,11 +5,12 @@ use std::ops::{Deref, DerefMut};
55

66
use crate::function::{KwArgs, OptionalArg};
77
use crate::pyobject::{
8-
DictProtocol, PyAttributes, PyContext, PyIteratorValue, PyObjectRef, PyRef, PyResult, PyValue,
8+
DictProtocol, PyAttributes, PyContext, PyObjectRef, PyRef, PyResult, PyValue,
99
};
1010
use crate::vm::{ReprGuard, VirtualMachine};
1111

1212
use super::objiter;
13+
use super::objlist::PyListIterator;
1314
use super::objstr::{self, PyStringRef};
1415
use super::objtype;
1516
use crate::obj::objtype::PyClassRef;
@@ -210,7 +211,8 @@ impl PyDictRef {
210211
}
211212

212213
/// When iterating over a dictionary, we iterate over the keys of it.
213-
fn iter(self, vm: &VirtualMachine) -> PyIteratorValue {
214+
fn iter(self, vm: &VirtualMachine) -> PyListIterator {
215+
// TODO: separate type, not a list iterator
214216
let keys = self
215217
.entries
216218
.borrow()
@@ -219,13 +221,14 @@ impl PyDictRef {
219221
.collect();
220222
let key_list = vm.ctx.new_list(keys);
221223

222-
PyIteratorValue {
224+
PyListIterator {
223225
position: Cell::new(0),
224-
iterated_obj: key_list,
226+
list: key_list.downcast().unwrap(),
225227
}
226228
}
227229

228-
fn values(self, vm: &VirtualMachine) -> PyIteratorValue {
230+
fn values(self, vm: &VirtualMachine) -> PyListIterator {
231+
// TODO: separate type. `values` should be a live view over the collection, not an iterator.
229232
let values = self
230233
.entries
231234
.borrow()
@@ -234,13 +237,14 @@ impl PyDictRef {
234237
.collect();
235238
let values_list = vm.ctx.new_list(values);
236239

237-
PyIteratorValue {
240+
PyListIterator {
238241
position: Cell::new(0),
239-
iterated_obj: values_list,
242+
list: values_list.downcast().unwrap(),
240243
}
241244
}
242245

243-
fn items(self, vm: &VirtualMachine) -> PyIteratorValue {
246+
fn items(self, vm: &VirtualMachine) -> PyListIterator {
247+
// TODO: separate type. `items` should be a live view over the collection, not an iterator.
244248
let items = self
245249
.entries
246250
.borrow()
@@ -249,9 +253,9 @@ impl PyDictRef {
249253
.collect();
250254
let items_list = vm.ctx.new_list(items);
251255

252-
PyIteratorValue {
256+
PyListIterator {
253257
position: Cell::new(0),
254-
iterated_obj: items_list,
258+
list: items_list.downcast().unwrap(),
255259
}
256260
}
257261

@@ -332,6 +336,7 @@ pub fn init(context: &PyContext) {
332336
"clear" => context.new_rustfunc(PyDictRef::clear),
333337
"values" => context.new_rustfunc(PyDictRef::values),
334338
"items" => context.new_rustfunc(PyDictRef::items),
339+
// TODO: separate type. `keys` should be a live view over the collection, not an iterator.
335340
"keys" => context.new_rustfunc(PyDictRef::iter),
336341
"get" => context.new_rustfunc(PyDictRef::get),
337342
});

vm/src/obj/objenumerate.rs

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -57,13 +57,17 @@ impl PyEnumerateRef {
5757

5858
Ok(result)
5959
}
60+
61+
fn iter(self, _vm: &VirtualMachine) -> Self {
62+
self
63+
}
6064
}
6165

6266
pub fn init(context: &PyContext) {
6367
let enumerate_type = &context.enumerate_type;
64-
objiter::iter_type_init(context, enumerate_type);
6568
extend_class!(context, enumerate_type, {
6669
"__new__" => context.new_rustfunc(enumerate_new),
67-
"__next__" => context.new_rustfunc(PyEnumerateRef::next)
70+
"__next__" => context.new_rustfunc(PyEnumerateRef::next),
71+
"__iter__" => context.new_rustfunc(PyEnumerateRef::iter),
6872
});
6973
}

vm/src/obj/objfilter.rs

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -52,13 +52,15 @@ impl PyFilterRef {
5252
}
5353
}
5454
}
55+
56+
fn iter(self, _vm: &VirtualMachine) -> Self {
57+
self
58+
}
5559
}
5660

5761
pub fn init(context: &PyContext) {
5862
let filter_type = &context.filter_type;
5963

60-
objiter::iter_type_init(context, filter_type);
61-
6264
let filter_doc =
6365
"filter(function or None, iterable) --> filter object\n\n\
6466
Return an iterator yielding those items of iterable for which function(item)\n\
@@ -67,6 +69,7 @@ pub fn init(context: &PyContext) {
6769
extend_class!(context, filter_type, {
6870
"__new__" => context.new_rustfunc(filter_new),
6971
"__doc__" => context.new_str(filter_doc.to_string()),
70-
"__next__" => context.new_rustfunc(PyFilterRef::next)
72+
"__next__" => context.new_rustfunc(PyFilterRef::next),
73+
"__iter__" => context.new_rustfunc(PyFilterRef::iter),
7174
});
7275
}

0 commit comments

Comments
 (0)