Skip to content

Commit 80b1f54

Browse files
committed
Add itertools.islice
1 parent d70caf7 commit 80b1f54

2 files changed

Lines changed: 162 additions & 1 deletion

File tree

tests/snippets/stdlib_itertools.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,3 +154,32 @@
154154
next(t)
155155
with assertRaises(StopIteration):
156156
next(t)
157+
158+
159+
# itertools.islice tests
160+
161+
def assert_matches_seq(it, seq):
162+
assert list(it) == list(seq)
163+
164+
i = itertools.islice
165+
166+
it = i([1, 2, 3, 4, 5], 3)
167+
assert_matches_seq(it, [1, 2, 3])
168+
169+
it = i([0.5, 1, 1.5, 2, 2.5, 3, 4, 5], 1, 6, 2)
170+
assert_matches_seq(it, [1, 2, 3])
171+
172+
it = i([1, 2], None)
173+
assert_matches_seq(it, [1, 2])
174+
175+
it = i([1, 2, 3], None, None, None)
176+
assert_matches_seq(it, [1, 2, 3])
177+
178+
it = i([1, 2, 3], 1, None, None)
179+
assert_matches_seq(it, [2, 3])
180+
181+
it = i([1, 2, 3], None, 2, None)
182+
assert_matches_seq(it, [1, 2])
183+
184+
it = i([1, 2, 3], None, None, 3)
185+
assert_matches_seq(it, [1])

vm/src/stdlib/itertools.rs

Lines changed: 133 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,16 @@ use std::cmp::Ordering;
33
use std::ops::{AddAssign, SubAssign};
44

55
use num_bigint::BigInt;
6+
use num_traits::ToPrimitive;
67

78
use crate::function::{OptionalArg, PyFuncArgs};
89
use crate::obj::objbool;
10+
use crate::obj::objint;
911
use crate::obj::objint::{PyInt, PyIntRef};
1012
use crate::obj::objiter::{call_next, get_iter, new_stop_iteration};
1113
use crate::obj::objtype;
1214
use crate::obj::objtype::PyClassRef;
13-
use crate::pyobject::{PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue};
15+
use crate::pyobject::{IdProtocol, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue};
1416
use crate::vm::VirtualMachine;
1517

1618
#[pyclass(name = "chain")]
@@ -284,6 +286,133 @@ impl PyItertoolsTakewhile {
284286
}
285287
}
286288

289+
#[pyclass(name = "islice")]
290+
#[derive(Debug)]
291+
struct PyItertoolsIslice {
292+
iterable: PyObjectRef,
293+
cur: RefCell<usize>,
294+
next: RefCell<usize>,
295+
stop: Option<usize>,
296+
step: usize,
297+
}
298+
299+
impl PyValue for PyItertoolsIslice {
300+
fn class(vm: &VirtualMachine) -> PyClassRef {
301+
vm.class("itertools", "islice")
302+
}
303+
}
304+
305+
fn pyobject_to_opt_usize(obj: PyObjectRef, vm: &VirtualMachine) -> Option<usize> {
306+
let is_int = objtype::isinstance(&obj, &vm.ctx.int_type());
307+
if is_int {
308+
objint::get_value(&obj).to_usize()
309+
} else {
310+
None
311+
}
312+
}
313+
314+
#[pyimpl]
315+
impl PyItertoolsIslice {
316+
#[pymethod(name = "__new__")]
317+
fn new(_cls: PyClassRef, args: PyFuncArgs, vm: &VirtualMachine) -> PyResult {
318+
let (iter, start, stop, step) = match args.args.len() {
319+
0 | 1 => {
320+
return Err(vm.new_type_error(format!(
321+
"islice expected at least 2 arguments, got {}",
322+
args.args.len()
323+
)));
324+
}
325+
326+
2 => {
327+
let (iter, stop): (PyObjectRef, PyObjectRef) = args.bind(vm)?;
328+
329+
(iter, 0usize, stop, 1usize)
330+
}
331+
_ => {
332+
let (iter, start, stop, step): (
333+
PyObjectRef,
334+
PyObjectRef,
335+
PyObjectRef,
336+
PyObjectRef,
337+
) = args.bind(vm)?;
338+
339+
let start = if !start.is(&vm.get_none()) {
340+
pyobject_to_opt_usize(start, &vm).ok_or_else(|| {
341+
vm.new_value_error(
342+
"Indices for islice() must be None or an integer: 0 <= x <= sys.maxsize.".to_string(),
343+
)
344+
})?
345+
} else {
346+
0usize
347+
};
348+
349+
let step = if !step.is(&vm.get_none()) {
350+
pyobject_to_opt_usize(step, &vm).ok_or_else(|| {
351+
vm.new_value_error(
352+
"Step for islice() must be a positive integer or None.".to_string(),
353+
)
354+
})?
355+
} else {
356+
1usize
357+
};
358+
359+
(iter, start, stop, step)
360+
}
361+
};
362+
363+
let stop = if !stop.is(&vm.get_none()) {
364+
Some(pyobject_to_opt_usize(stop, &vm).ok_or_else(|| {
365+
vm.new_value_error(
366+
"Stop argument for islice() must be None or an integer: 0 <= x <= sys.maxsize."
367+
.to_string(),
368+
)
369+
})?)
370+
} else {
371+
None
372+
};
373+
374+
let iter = get_iter(vm, &iter)?;
375+
376+
Ok(PyItertoolsIslice {
377+
iterable: iter,
378+
cur: RefCell::new(0),
379+
next: RefCell::new(start),
380+
stop: stop,
381+
step: step,
382+
}
383+
.into_ref(vm)
384+
.into_object())
385+
}
386+
387+
#[pymethod(name = "__next__")]
388+
fn next(&self, vm: &VirtualMachine) -> PyResult {
389+
while *self.cur.borrow() < *self.next.borrow() {
390+
call_next(vm, &self.iterable)?;
391+
*self.cur.borrow_mut() += 1;
392+
}
393+
394+
if let Some(stop) = self.stop {
395+
if *self.cur.borrow() >= stop {
396+
return Err(new_stop_iteration(vm));
397+
}
398+
}
399+
400+
let obj = call_next(vm, &self.iterable)?;
401+
*self.cur.borrow_mut() += 1;
402+
403+
// TODO is this overflow check required? attempts to copy CPython.
404+
let (next, ovf) = (*self.next.borrow()).overflowing_add(self.step);
405+
*self.next.borrow_mut() = if ovf { self.stop.unwrap() } else { next };
406+
407+
Ok(obj)
408+
}
409+
410+
#[pymethod(name = "__iter__")]
411+
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
412+
zelf
413+
}
414+
}
415+
287416
pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
288417
let ctx = &vm.ctx;
289418

@@ -300,11 +429,14 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
300429
let takewhile = ctx.new_class("takewhile", ctx.object());
301430
PyItertoolsTakewhile::extend_class(ctx, &takewhile);
302431

432+
let islice = PyItertoolsIslice::make_class(ctx);
433+
303434
py_module!(vm, "itertools", {
304435
"chain" => chain,
305436
"count" => count,
306437
"repeat" => repeat,
307438
"starmap" => starmap,
308439
"takewhile" => takewhile,
440+
"islice" => islice,
309441
})
310442
}

0 commit comments

Comments
 (0)