Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Implemented __reduce__, __setstate__ in product object
  • Loading branch information
MannarAmuthan committed Oct 11, 2023
commit 623e8bad9de8c3707bcedf60c01e0c746ce9cd3d
7 changes: 3 additions & 4 deletions Lib/test/test_itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def test_chain_setstate(self):
it = chain()
it.__setstate__((iter(['abc', 'def']), iter(['ghi'])))
self.assertEqual(list(it), ['ghi', 'a', 'b', 'c', 'd', 'e', 'f'])

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_combinations(self):
Expand Down Expand Up @@ -1165,8 +1165,7 @@ def test_product_tuple_reuse(self):
self.assertEqual(len(set(map(id, product('abc', 'def')))), 1)
self.assertNotEqual(len(set(map(id, list(product('abc', 'def'))))), 1)

# TODO: RUSTPYTHON
@unittest.expectedFailure

def test_product_pickling(self):
# check copy, deepcopy, pickle
for args, result in [
Expand Down Expand Up @@ -2297,7 +2296,7 @@ def __eq__(self, other):


class SubclassWithKwargsTest(unittest.TestCase):

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_keywords_in_subclass(self):
Expand Down
110 changes: 99 additions & 11 deletions vm/src/stdlib/itertools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pub(crate) use decl::make_module;

#[pymodule(name = "itertools")]
mod decl {
use crate::stdlib::itertools::decl::int::get_value;
use crate::{
builtins::{
int, tuple::IntoPyTuple, PyGenericAlias, PyInt, PyIntRef, PyList, PyTuple, PyTupleRef,
Expand Down Expand Up @@ -110,7 +111,9 @@ mod decl {
Ok(())
}
}

impl SelfIter for PyItertoolsChain {}

impl IterNext for PyItertoolsChain {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
let Some(source) = zelf.source.read().clone() else {
Expand Down Expand Up @@ -201,6 +204,7 @@ mod decl {
}

impl SelfIter for PyItertoolsCompress {}

impl IterNext for PyItertoolsCompress {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
loop {
Expand Down Expand Up @@ -268,7 +272,9 @@ mod decl {
(zelf.class().to_owned(), (zelf.cur.read().clone(),))
}
}

impl SelfIter for PyItertoolsCount {}

impl IterNext for PyItertoolsCount {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
let mut cur = zelf.cur.write();
Expand Down Expand Up @@ -316,7 +322,9 @@ mod decl {

#[pyclass(with(IterNext, Iterable, Constructor), flags(BASETYPE))]
impl PyItertoolsCycle {}

impl SelfIter for PyItertoolsCycle {}

impl IterNext for PyItertoolsCycle {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
let item = if let PyIterReturn::Return(item) = zelf.iter.next(vm)? {
Expand Down Expand Up @@ -401,6 +409,7 @@ mod decl {
}

impl SelfIter for PyItertoolsRepeat {}

impl IterNext for PyItertoolsRepeat {
fn next(zelf: &Py<Self>, _vm: &VirtualMachine) -> PyResult<PyIterReturn> {
if let Some(ref times) = zelf.times {
Expand Down Expand Up @@ -466,7 +475,9 @@ mod decl {
)
}
}

impl SelfIter for PyItertoolsStarmap {}

impl IterNext for PyItertoolsStarmap {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
let obj = zelf.iterable.next(vm)?;
Expand Down Expand Up @@ -537,7 +548,9 @@ mod decl {
Ok(())
}
}

impl SelfIter for PyItertoolsTakewhile {}

impl IterNext for PyItertoolsTakewhile {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
if zelf.stop_flag.load() {
Expand Down Expand Up @@ -618,7 +631,9 @@ mod decl {
Ok(())
}
}

impl SelfIter for PyItertoolsDropwhile {}

impl IterNext for PyItertoolsDropwhile {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
let predicate = &zelf.predicate;
Expand All @@ -629,7 +644,7 @@ mod decl {
let obj = match iterable.next(vm)? {
PyIterReturn::Return(obj) => obj,
PyIterReturn::StopIteration(v) => {
return Ok(PyIterReturn::StopIteration(v))
return Ok(PyIterReturn::StopIteration(v));
}
};
let pred = predicate.clone();
Expand Down Expand Up @@ -737,7 +752,9 @@ mod decl {
Ok(PyIterReturn::Return((new_value, new_key)))
}
}

impl SelfIter for PyItertoolsGroupBy {}

impl IterNext for PyItertoolsGroupBy {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
let mut state = zelf.state.lock();
Expand All @@ -753,7 +770,7 @@ mod decl {
let (value, new_key) = match zelf.advance(vm)? {
PyIterReturn::Return(obj) => obj,
PyIterReturn::StopIteration(v) => {
return Ok(PyIterReturn::StopIteration(v))
return Ok(PyIterReturn::StopIteration(v));
}
};
if !vm.bool_eq(&new_key, &old_key)? {
Expand All @@ -764,7 +781,7 @@ mod decl {
match zelf.advance(vm)? {
PyIterReturn::Return(obj) => obj,
PyIterReturn::StopIteration(v) => {
return Ok(PyIterReturn::StopIteration(v))
return Ok(PyIterReturn::StopIteration(v));
}
}
};
Expand Down Expand Up @@ -797,7 +814,9 @@ mod decl {

#[pyclass(with(IterNext, Iterable))]
impl PyItertoolsGrouper {}

impl SelfIter for PyItertoolsGrouper {}

impl IterNext for PyItertoolsGrouper {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
let old_key = {
Expand Down Expand Up @@ -960,6 +979,7 @@ mod decl {
}

impl SelfIter for PyItertoolsIslice {}

impl IterNext for PyItertoolsIslice {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
while zelf.cur.load() < zelf.next.load() {
Expand Down Expand Up @@ -1033,7 +1053,9 @@ mod decl {
)
}
}

impl SelfIter for PyItertoolsFilterFalse {}

impl IterNext for PyItertoolsFilterFalse {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
let predicate = &zelf.predicate;
Expand Down Expand Up @@ -1142,6 +1164,7 @@ mod decl {
}

impl SelfIter for PyItertoolsAccumulate {}

impl IterNext for PyItertoolsAccumulate {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
let iterable = &zelf.iterable;
Expand All @@ -1153,7 +1176,7 @@ mod decl {
None => match iterable.next(vm)? {
PyIterReturn::Return(obj) => obj,
PyIterReturn::StopIteration(v) => {
return Ok(PyIterReturn::StopIteration(v))
return Ok(PyIterReturn::StopIteration(v));
}
},
Some(obj) => obj.clone(),
Expand All @@ -1162,7 +1185,7 @@ mod decl {
let obj = match iterable.next(vm)? {
PyIterReturn::Return(obj) => obj,
PyIterReturn::StopIteration(v) => {
return Ok(PyIterReturn::StopIteration(v))
return Ok(PyIterReturn::StopIteration(v));
}
};
match &zelf.binop {
Expand Down Expand Up @@ -1348,7 +1371,60 @@ mod decl {
self.cur.store(idxs.len() - 1);
}
}

#[pymethod(magic)]
fn setstate(zelf: PyRef<Self>, state: PyTupleRef, vm: &VirtualMachine) -> PyResult<()> {
let args = state.as_slice();
if args.len() != zelf.pools.len() {
let msg = format!("Invalid number of arguments");
return Err(vm.new_type_error(msg));
}
let mut idxs: PyRwLockWriteGuard<'_, Vec<usize>> = zelf.idxs.write();
idxs.clear();
for s in 0..args.len() {
let index = get_value(state.get(s).unwrap()).to_usize().unwrap();
let pool_size = zelf.pools.get(s).unwrap().len();
if pool_size == 0 {
zelf.stop.store(true);
return Ok(());
}
if index >= pool_size {
idxs.push(pool_size - 1);
} else {
idxs.push(index);
}
}
zelf.stop.store(false);
return Ok(());
}

#[pymethod(magic)]
fn reduce(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyTupleRef {
let class = zelf.class().to_owned();

if zelf.stop.load() {
return vm.new_tuple((class, (vm.ctx.empty_tuple.clone(),)));
}

let mut pools: Vec<PyObjectRef> = Vec::new();
for element in zelf.pools.iter() {
pools.push(element.clone().into_pytuple(vm).into());
}

let mut indices: Vec<PyObjectRef> = Vec::new();

for item in &zelf.idxs.read()[..] {
indices.push(vm.new_pyobj(*item));
}

return vm.new_tuple((
class,
pools.clone().into_pytuple(vm),
indices.into_pytuple(vm),
));
}
}

impl SelfIter for PyItertoolsProduct {}
impl IterNext for PyItertoolsProduct {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
Expand Down Expand Up @@ -1563,6 +1639,7 @@ mod decl {
impl PyItertoolsCombinationsWithReplacement {}

impl SelfIter for PyItertoolsCombinationsWithReplacement {}

impl IterNext for PyItertoolsCombinationsWithReplacement {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
// stop signal
Expand Down Expand Up @@ -1612,12 +1689,17 @@ mod decl {
#[pyclass(name = "permutations")]
#[derive(Debug, PyPayload)]
struct PyItertoolsPermutations {
pool: Vec<PyObjectRef>, // Collected input iterable
indices: PyRwLock<Vec<usize>>, // One index per element in pool
cycles: PyRwLock<Vec<usize>>, // One rollover counter per element in the result
result: PyRwLock<Option<Vec<usize>>>, // Indexes of the most recently returned result
r: AtomicCell<usize>, // Size of result tuple
exhausted: AtomicCell<bool>, // Set when the iterator is exhausted
pool: Vec<PyObjectRef>,
// Collected input iterable
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment place here is looking weird. Isn't this expected to be put on upper line?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi,

Thanks for response,

Yeah, actually I am not sure how this is getting formatted (maybe because of my IDE).
Anyway will fix it and push this.

Now Ran (cargo fmt -all)

Copy link
Copy Markdown
Contributor Author

@MannarAmuthan MannarAmuthan Oct 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

some tests are failing after this formatting, But I don't think, it is related to this change(formatting comments).
Is it possible to run tests again?

indices: PyRwLock<Vec<usize>>,
// One index per element in pool
cycles: PyRwLock<Vec<usize>>,
// One rollover counter per element in the result
result: PyRwLock<Option<Vec<usize>>>,
// Indexes of the most recently returned result
r: AtomicCell<usize>,
// Size of result tuple
exhausted: AtomicCell<bool>, // Set when the iterator is exhausted
}

#[derive(FromArgs)]
Expand Down Expand Up @@ -1679,7 +1761,9 @@ mod decl {
))
}
}

impl SelfIter for PyItertoolsPermutations {}

impl IterNext for PyItertoolsPermutations {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
// stop signal
Expand Down Expand Up @@ -1802,7 +1886,9 @@ mod decl {
Ok(())
}
}

impl SelfIter for PyItertoolsZipLongest {}

impl IterNext for PyItertoolsZipLongest {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
if zelf.iterators.is_empty() {
Expand Down Expand Up @@ -1851,7 +1937,9 @@ mod decl {

#[pyclass(with(IterNext, Iterable, Constructor))]
impl PyItertoolsPairwise {}

impl SelfIter for PyItertoolsPairwise {}

impl IterNext for PyItertoolsPairwise {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
let old = match zelf.old.read().clone() {
Expand Down