Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
fix(itertools): add re-entrancy guard to tee object
  • Loading branch information
ever0de committed Jul 10, 2025
commit aaf18058f825c6a9328b9d62f2eb70b0674d2285
2 changes: 0 additions & 2 deletions Lib/test/test_itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1761,8 +1761,6 @@ def test_tee_del_backward(self):
del forward, backward
raise

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_tee_reenter(self):
class I:
first = True
Expand Down
14 changes: 12 additions & 2 deletions vm/src/stdlib/itertools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1185,21 +1185,31 @@ mod decl {
struct PyItertoolsTeeData {
iterable: PyIter,
values: PyRwLock<Vec<PyObjectRef>>,
locked: AtomicCell<bool>,
}

impl PyItertoolsTeeData {
fn new(iterable: PyIter, _vm: &VirtualMachine) -> PyResult<PyRc<Self>> {
Ok(PyRc::new(Self {
iterable,
values: PyRwLock::new(vec![]),
locked: AtomicCell::new(false),
}))
}

fn get_item(&self, vm: &VirtualMachine, index: usize) -> PyResult<PyIterReturn> {
if self.values.read().len() == index {
let result = raise_if_stop!(self.iterable.next(vm)?);
self.values.write().push(result);
if self.locked.swap(true) {
return Err(vm.new_runtime_error("cannot re-enter the tee iterator"));
}

let result = self.iterable.next(vm);
self.locked.store(false);

let obj = raise_if_stop!(result?);
self.values.write().push(obj);
}

Comment on lines 1199 to +1207
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While traditional locks prefer to lock the code, Rust locks prefer to lock the data.
After changing PyRwLock to PyMutex, this approach will work

Suggested change
if self.values.read().len() == index {
let result = raise_if_stop!(self.iterable.next(vm)?);
self.values.write().push(result);
if self.locked.swap(true) {
return Err(vm.new_runtime_error("cannot re-enter the tee iterator"));
}
let result = self.iterable.next(vm);
self.locked.store(false);
let obj = raise_if_stop!(result?);
self.values.write().push(obj);
}
let Some(values) = self.values.try_lock() else {
return Err(vm.new_runtime_error("cannot re-enter the tee iterator"));
};
if values.len() == index {
let obj = raise_if_stop!(self.iterable.next(vm)?);
values.push(obj);
}
Ok(PyIterReturn::Return(values[index].clone()))

Ok(PyIterReturn::Return(self.values.read()[index].clone()))
}
}
Expand Down
Loading