Skip to content

Commit 32bfcf3

Browse files
Merge pull request RustPython#384 from skinny121/iter_lazy_2
Add enumerate and zip types
2 parents c555934 + e1284e3 commit 32bfcf3

11 files changed

Lines changed: 243 additions & 135 deletions

File tree

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
assert list(enumerate(['a', 'b', 'c'])) == [(0, 'a'), (1, 'b'), (2, 'c')]
2+
3+
assert type(enumerate([])) == enumerate
4+
5+
assert list(enumerate(['a', 'b', 'c'], -100)) == [(-100, 'a'), (-99, 'b'), (-98, 'c')]
6+
assert list(enumerate(['a', 'b', 'c'], 2**200)) == [(2**200, 'a'), (2**200 + 1, 'b'), (2**200 + 2, 'c')]
7+
8+
9+
# test infinite iterator
10+
class Counter(object):
11+
counter = 0
12+
13+
def __next__(self):
14+
self.counter += 1
15+
return self.counter
16+
17+
def __iter__(self):
18+
return self
19+
20+
21+
it = enumerate(Counter())
22+
assert next(it) == (0, 1)
23+
assert next(it) == (1, 2)

tests/snippets/builtin_zip.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
assert list(zip(['a', 'b', 'c'], range(3), [9, 8, 7, 99])) == [('a', 0, 9), ('b', 1, 8), ('c', 2, 7)]
2+
3+
assert list(zip(['a', 'b', 'c'])) == [('a',), ('b',), ('c',)]
4+
assert list(zip()) == []
5+
6+
assert list(zip(*zip(['a', 'b', 'c'], range(1, 4)))) == [('a', 'b', 'c'), (1, 2, 3)]
7+
8+
9+
# test infinite iterator
10+
class Counter(object):
11+
def __init__(self, counter=0):
12+
self.counter = counter
13+
14+
def __next__(self):
15+
self.counter += 1
16+
return self.counter
17+
18+
def __iter__(self):
19+
return self
20+
21+
22+
it = zip(Counter(), Counter(3))
23+
assert next(it) == (1, 4)
24+
assert next(it) == (2, 5)

tests/snippets/builtins.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,8 @@
55
# TODO:
66
# assert callable(callable)
77

8-
assert list(enumerate(['a', 'b', 'c'])) == [(0, 'a'), (1, 'b'), (2, 'c')]
9-
108
assert type(frozenset) is type
119

12-
assert list(zip(['a', 'b', 'c'], range(3), [9, 8, 7, 99])) == [('a', 0, 9), ('b', 1, 8), ('c', 2, 7)]
13-
1410
assert 3 == eval('1+2')
1511

1612
code = compile('5+3', 'x.py', 'eval')

vm/src/builtins.rs

Lines changed: 3 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ use super::pyobject::{
2121
use super::stdlib::io::io_open;
2222

2323
use super::vm::VirtualMachine;
24-
use num_bigint::ToBigInt;
25-
use num_traits::{Signed, ToPrimitive, Zero};
24+
use num_traits::{Signed, ToPrimitive};
2625

2726
fn get_locals(vm: &mut VirtualMachine) -> PyObjectRef {
2827
let d = vm.new_dict();
@@ -180,29 +179,6 @@ fn builtin_divmod(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
180179
}
181180
}
182181

183-
fn builtin_enumerate(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
184-
arg_check!(
185-
vm,
186-
args,
187-
required = [(iterable, None)],
188-
optional = [(start, None)]
189-
);
190-
let items = vm.extract_elements(iterable)?;
191-
let start = if let Some(start) = start {
192-
objint::get_value(start)
193-
} else {
194-
Zero::zero()
195-
};
196-
let mut new_items = vec![];
197-
for (i, item) in items.into_iter().enumerate() {
198-
let element = vm
199-
.ctx
200-
.new_tuple(vec![vm.ctx.new_int(i.to_bigint().unwrap() + &start), item]);
201-
new_items.push(element);
202-
}
203-
Ok(vm.ctx.new_list(new_items))
204-
}
205-
206182
/// Implements `eval`.
207183
/// See also: https://docs.python.org/3/library/functions.html#eval
208184
fn builtin_eval(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
@@ -641,32 +617,6 @@ fn builtin_sum(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
641617
}
642618

643619
// builtin_vars
644-
645-
fn builtin_zip(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
646-
no_kwargs!(vm, args);
647-
648-
// TODO: process one element at a time from iterators.
649-
let mut iterables = vec![];
650-
for iterable in args.args.iter() {
651-
let iterable = vm.extract_elements(iterable)?;
652-
iterables.push(iterable);
653-
}
654-
655-
let minsize: usize = iterables.iter().map(|i| i.len()).min().unwrap_or(0);
656-
657-
let mut new_items = vec![];
658-
for i in 0..minsize {
659-
let items = iterables
660-
.iter()
661-
.map(|iterable| iterable[i].clone())
662-
.collect();
663-
let element = vm.ctx.new_tuple(items);
664-
new_items.push(element);
665-
}
666-
667-
Ok(vm.ctx.new_list(new_items))
668-
}
669-
670620
// builtin___import__
671621

672622
pub fn make_module(ctx: &PyContext) -> PyObjectRef {
@@ -692,7 +642,7 @@ pub fn make_module(ctx: &PyContext) -> PyObjectRef {
692642
ctx.set_attr(&py_mod, "dict", ctx.dict_type());
693643
ctx.set_attr(&py_mod, "divmod", ctx.new_rustfunc(builtin_divmod));
694644
ctx.set_attr(&py_mod, "dir", ctx.new_rustfunc(builtin_dir));
695-
ctx.set_attr(&py_mod, "enumerate", ctx.new_rustfunc(builtin_enumerate));
645+
ctx.set_attr(&py_mod, "enumerate", ctx.enumerate_type());
696646
ctx.set_attr(&py_mod, "eval", ctx.new_rustfunc(builtin_eval));
697647
ctx.set_attr(&py_mod, "exec", ctx.new_rustfunc(builtin_exec));
698648
ctx.set_attr(&py_mod, "float", ctx.float_type());
@@ -733,7 +683,7 @@ pub fn make_module(ctx: &PyContext) -> PyObjectRef {
733683
ctx.set_attr(&py_mod, "super", ctx.super_type());
734684
ctx.set_attr(&py_mod, "tuple", ctx.tuple_type());
735685
ctx.set_attr(&py_mod, "type", ctx.type_type());
736-
ctx.set_attr(&py_mod, "zip", ctx.new_rustfunc(builtin_zip));
686+
ctx.set_attr(&py_mod, "zip", ctx.zip_type());
737687

738688
// Exceptions:
739689
ctx.set_attr(

vm/src/obj/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ pub mod objbytes;
66
pub mod objcode;
77
pub mod objcomplex;
88
pub mod objdict;
9+
pub mod objenumerate;
910
pub mod objfilter;
1011
pub mod objfloat;
1112
pub mod objframe;
@@ -25,3 +26,4 @@ pub mod objstr;
2526
pub mod objsuper;
2627
pub mod objtuple;
2728
pub mod objtype;
29+
pub mod objzip;

vm/src/obj/objenumerate.rs

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
use super::super::pyobject::{
2+
PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol,
3+
};
4+
use super::super::vm::VirtualMachine;
5+
use super::objint;
6+
use super::objiter;
7+
use super::objtype; // Required for arg_check! to use isinstance
8+
use num_bigint::BigInt;
9+
use num_traits::Zero;
10+
use std::ops::AddAssign;
11+
12+
fn enumerate_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
13+
arg_check!(
14+
vm,
15+
args,
16+
required = [(cls, Some(vm.ctx.type_type())), (iterable, None)],
17+
optional = [(start, Some(vm.ctx.int_type()))]
18+
);
19+
let counter = if let Some(x) = start {
20+
objint::get_value(x)
21+
} else {
22+
BigInt::zero()
23+
};
24+
let iterator = objiter::get_iter(vm, iterable)?;
25+
Ok(PyObject::new(
26+
PyObjectPayload::EnumerateIterator { counter, iterator },
27+
cls.clone(),
28+
))
29+
}
30+
31+
fn enumerate_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
32+
arg_check!(
33+
vm,
34+
args,
35+
required = [(enumerate, Some(vm.ctx.enumerate_type()))]
36+
);
37+
38+
if let PyObjectPayload::EnumerateIterator {
39+
ref mut counter,
40+
ref mut iterator,
41+
} = enumerate.borrow_mut().payload
42+
{
43+
let next_obj = objiter::call_next(vm, iterator)?;
44+
let result = vm
45+
.ctx
46+
.new_tuple(vec![vm.ctx.new_int(counter.clone()), next_obj]);
47+
48+
AddAssign::add_assign(counter, 1);
49+
50+
Ok(result)
51+
} else {
52+
panic!("enumerate doesn't have correct payload");
53+
}
54+
}
55+
56+
pub fn init(context: &PyContext) {
57+
let enumerate_type = &context.enumerate_type;
58+
objiter::iter_type_init(context, enumerate_type);
59+
context.set_attr(
60+
enumerate_type,
61+
"__new__",
62+
context.new_rustfunc(enumerate_new),
63+
);
64+
context.set_attr(
65+
enumerate_type,
66+
"__next__",
67+
context.new_rustfunc(enumerate_next),
68+
);
69+
}

vm/src/obj/objfilter.rs

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use super::objbool;
77
use super::objiter;
88
use super::objtype; // Required for arg_check! to use isinstance
99

10-
pub fn filter_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
10+
fn filter_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
1111
arg_check!(
1212
vm,
1313
args,
@@ -23,21 +23,6 @@ pub fn filter_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
2323
))
2424
}
2525

26-
fn filter_iter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
27-
arg_check!(vm, args, required = [(filter, Some(vm.ctx.filter_type()))]);
28-
// Return self:
29-
Ok(filter.clone())
30-
}
31-
32-
fn filter_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
33-
arg_check!(
34-
vm,
35-
args,
36-
required = [(filter, Some(vm.ctx.filter_type())), (needle, None)]
37-
);
38-
objiter::contains(vm, filter, needle)
39-
}
40-
4126
fn filter_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
4227
arg_check!(vm, args, required = [(filter, Some(vm.ctx.filter_type()))]);
4328

@@ -72,12 +57,7 @@ fn filter_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
7257

7358
pub fn init(context: &PyContext) {
7459
let filter_type = &context.filter_type;
75-
context.set_attr(
76-
&filter_type,
77-
"__contains__",
78-
context.new_rustfunc(filter_contains),
79-
);
80-
context.set_attr(&filter_type, "__iter__", context.new_rustfunc(filter_iter));
60+
objiter::iter_type_init(context, filter_type);
8161
context.set_attr(&filter_type, "__new__", context.new_rustfunc(filter_new));
8262
context.set_attr(&filter_type, "__next__", context.new_rustfunc(filter_next));
8363
}

0 commit comments

Comments
 (0)