Skip to content

Commit de5b1c4

Browse files
committed
Support iter() for types with only __getitem__
1 parent c918e9d commit de5b1c4

1 file changed

Lines changed: 34 additions & 21 deletions

File tree

vm/src/obj/objiter.rs

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
use std::cell::Cell;
66

7-
use crate::pyobject::{PyContext, PyObjectRef, PyRef, PyResult, PyValue};
7+
use crate::pyobject::{PyContext, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol};
88
use crate::vm::VirtualMachine;
99

1010
use super::objtype;
@@ -16,14 +16,19 @@ use super::objtype::PyClassRef;
1616
* function 'iter' is called.
1717
*/
1818
pub fn get_iter(vm: &VirtualMachine, iter_target: &PyObjectRef) -> PyResult {
19-
vm.call_method(iter_target, "__iter__", vec![])
20-
// let type_str = objstr::get_value(&vm.to_str(iter_target.class()).unwrap());
21-
// let type_error = vm.new_type_error(format!("Cannot iterate over {}", type_str));
22-
// return Err(type_error);
23-
24-
// TODO: special case when iter_target only has __getitem__
25-
// see: https://docs.python.org/3/library/functions.html#iter
26-
// also https://docs.python.org/3.8/reference/datamodel.html#special-method-names
19+
if let Ok(method) = vm.get_method(iter_target.clone(), "__iter__") {
20+
vm.invoke(method, vec![])
21+
} else if vm.get_method(iter_target.clone(), "__getitem__").is_ok() {
22+
Ok(PySequenceIterator {
23+
position: Cell::new(0),
24+
obj: iter_target.clone(),
25+
}
26+
.into_ref(vm)
27+
.into_object())
28+
} else {
29+
let message = format!("Cannot iterate over {}", iter_target.class().name);
30+
return Err(vm.new_type_error(message));
31+
}
2732
}
2833

2934
pub fn call_next(vm: &VirtualMachine, iter_obj: &PyObjectRef) -> PyResult {
@@ -70,26 +75,34 @@ pub fn new_stop_iteration(vm: &VirtualMachine) -> PyObjectRef {
7075
vm.new_exception(stop_iteration_type, "End of iterator".to_string())
7176
}
7277

73-
// TODO: This is a workaround and shouldn't exist.
74-
// Each iterable type should have its own distinct iterator type.
75-
// (however, this boilerplate can be reused for "generic iterator" for types with only __getiter__)
7678
#[derive(Debug)]
77-
pub struct PyIteratorValue {
79+
pub struct PySequenceIterator {
7880
pub position: Cell<usize>,
79-
pub iterated_obj: PyObjectRef,
81+
pub obj: PyObjectRef,
8082
}
8183

82-
impl PyValue for PyIteratorValue {
84+
impl PyValue for PySequenceIterator {
8385
fn class(vm: &VirtualMachine) -> PyClassRef {
8486
vm.ctx.iter_type()
8587
}
8688
}
8789

88-
type PyIteratorValueRef = PyRef<PyIteratorValue>;
90+
type PySequenceIteratorRef = PyRef<PySequenceIterator>;
8991

90-
impl PyIteratorValueRef {
91-
fn next(self, _vm: &VirtualMachine) -> PyResult {
92-
unimplemented!()
92+
impl PySequenceIteratorRef {
93+
fn next(self, vm: &VirtualMachine) -> PyResult {
94+
let number = vm.ctx.new_int(self.position.get());
95+
match vm.call_method(&self.obj, "__getitem__", vec![number]) {
96+
Ok(val) => {
97+
self.position.set(self.position.get() + 1);
98+
Ok(val)
99+
}
100+
Err(ref e) if objtype::isinstance(&e, &vm.ctx.exceptions.index_error) => {
101+
Err(new_stop_iteration(vm))
102+
}
103+
// also catches stop_iteration => stop_iteration
104+
Err(e) => Err(e),
105+
}
93106
}
94107

95108
fn iter(self, _vm: &VirtualMachine) -> Self {
@@ -101,7 +114,7 @@ pub fn init(context: &PyContext) {
101114
let iter_type = &context.iter_type;
102115

103116
extend_class!(context, iter_type, {
104-
"__next__" => context.new_rustfunc(PyIteratorValueRef::next),
105-
"__iter__" => context.new_rustfunc(PyIteratorValueRef::iter),
117+
"__next__" => context.new_rustfunc(PySequenceIteratorRef::next),
118+
"__iter__" => context.new_rustfunc(PySequenceIteratorRef::iter),
106119
});
107120
}

0 commit comments

Comments
 (0)