Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Add combinations.__reduce__
  • Loading branch information
mumallaeng authored and youknowone committed Nov 11, 2022
commit dd93ec3c415ece053877518af92d1b3b4417c709
89 changes: 43 additions & 46 deletions vm/src/stdlib/itertools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ pub(crate) use decl::make_module;
#[pymodule(name = "itertools")]
mod decl {
use crate::{
builtins::{int, PyGenericAlias, PyInt, PyIntRef, PyList, PyTuple, PyTupleRef, PyTypeRef},
builtins::{
int, tuple::IntoPyTuple, PyGenericAlias, PyInt, PyIntRef, PyList, PyTuple, PyTupleRef,
PyTypeRef,
},
common::{
lock::{PyMutex, PyRwLock, PyRwLockWriteGuard},
rc::PyRc,
Expand Down Expand Up @@ -1308,7 +1311,7 @@ mod decl {
struct PyItertoolsCombinations {
pool: Vec<PyObjectRef>,
indices: PyRwLock<Vec<usize>>,
result: PyRwLock<Option<Vec<usize>>>,
result: PyRwLock<Option<Vec<PyObjectRef>>>,
r: AtomicCell<usize>,
exhausted: AtomicCell<bool>,
}
Expand Down Expand Up @@ -1355,32 +1358,29 @@ mod decl {
impl PyItertoolsCombinations {
#[pymethod(magic)]
fn reduce(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyTupleRef {
let result = zelf.result.read();
if let Some(result) = &*result {
if zelf.exhausted.load() {
vm.new_tuple((
zelf.class().to_owned(),
vm.new_tuple((vm.new_tuple(()), vm.ctx.new_int(zelf.r.load()))),
))
} else {
vm.new_tuple((
zelf.class().to_owned(),
vm.new_tuple((
vm.new_tuple(zelf.pool.clone()),
vm.ctx.new_int(zelf.r.load()),
)),
vm.ctx
.new_tuple(result.iter().map(|&i| zelf.pool[i].clone()).collect()),
))
}
let r = zelf.r.load();

let class = zelf.class().to_owned();

if zelf.exhausted.load() {
return vm.new_tuple((
class,
vm.new_tuple((vm.ctx.empty_tuple.clone(), vm.ctx.new_int(r))),
));
}

let tup = vm.new_tuple((zelf.pool.clone().into_pytuple(vm), vm.ctx.new_int(r)));

if zelf.result.read().is_none() {
vm.new_tuple((class, tup))
} else {
vm.new_tuple((
zelf.class().to_owned(),
vm.new_tuple((
vm.new_tuple(zelf.pool.clone()),
vm.ctx.new_int(zelf.r.load()),
)),
))
let mut indices: Vec<PyObjectRef> = Vec::new();

for item in &zelf.indices.read()[..r] {
indices.push(vm.new_pyobj(*item));
}

vm.new_tuple((class, tup, indices.into_pytuple(vm)))
}
}
}
Expand All @@ -1401,9 +1401,8 @@ mod decl {
return Ok(PyIterReturn::Return(vm.new_tuple(()).into()));
}

let mut result = zelf.result.write();

if let Some(ref mut result) = *result {
let mut result_lock = zelf.result.write();
let result = if let Some(ref mut result) = *result_lock {
let mut indices = zelf.indices.write();

// Scan indices right-to-left until finding one that is not at its maximum (i + n - r).
Expand All @@ -1426,26 +1425,24 @@ mod decl {
for j in idx as usize + 1..r {
indices[j] = indices[j - 1] + 1;
}
for j in 0..r {
result[j] = indices[j];

// Update the result tuple for the new indices
// starting with i, the leftmost index that changed
for i in idx as usize..r {
let index = indices[i];
let elem = &zelf.pool[index];
result[i] = elem.to_owned();
}

result.to_vec()
}
} else {
*result = Some((0..r).collect());
}
let res = zelf.pool[0..r].to_vec();
*result_lock = Some(res.clone());
res
};

Ok(PyIterReturn::Return(
vm.ctx
.new_tuple(
result
.as_ref()
.unwrap()
.iter()
.map(|&i| zelf.pool[i].clone())
.collect(),
)
.into(),
))
Ok(PyIterReturn::Return(vm.ctx.new_tuple(result).into()))
}
}

Expand Down