Skip to content

Commit a03f28f

Browse files
committed
Make PyEnumerate ThreadSafe
1 parent ceb7ca6 commit a03f28f

1 file changed

Lines changed: 11 additions & 15 deletions

File tree

vm/src/obj/objenumerate.rs

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
use std::cell::RefCell;
21
use std::ops::AddAssign;
2+
use std::sync::RwLock;
33

44
use num_bigint::BigInt;
55
use num_traits::Zero;
@@ -8,16 +8,17 @@ use super::objint::PyIntRef;
88
use super::objiter;
99
use super::objtype::PyClassRef;
1010
use crate::function::OptionalArg;
11-
use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue};
11+
use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe};
1212
use crate::vm::VirtualMachine;
1313

1414
#[pyclass]
1515
#[derive(Debug)]
1616
pub struct PyEnumerate {
17-
counter: RefCell<BigInt>,
17+
counter: RwLock<BigInt>,
1818
iterator: PyObjectRef,
1919
}
2020
type PyEnumerateRef = PyRef<PyEnumerate>;
21+
impl ThreadSafe for PyEnumerate {}
2122

2223
impl PyValue for PyEnumerate {
2324
fn class(vm: &VirtualMachine) -> PyClassRef {
@@ -41,24 +42,19 @@ impl PyEnumerate {
4142

4243
let iterator = objiter::get_iter(vm, &iterable)?;
4344
PyEnumerate {
44-
counter: RefCell::new(counter),
45+
counter: RwLock::new(counter),
4546
iterator,
4647
}
4748
.into_ref_with_type(vm, cls)
4849
}
4950

5051
#[pymethod(name = "__next__")]
51-
fn next(&self, vm: &VirtualMachine) -> PyResult {
52-
let iterator = &self.iterator;
53-
let counter = &self.counter;
54-
let next_obj = objiter::call_next(vm, iterator)?;
55-
let result = vm
56-
.ctx
57-
.new_tuple(vec![vm.ctx.new_bigint(&counter.borrow()), next_obj]);
58-
59-
AddAssign::add_assign(&mut counter.borrow_mut() as &mut BigInt, 1);
60-
61-
Ok(result)
52+
fn next(&self, vm: &VirtualMachine) -> PyResult<(BigInt, PyObjectRef)> {
53+
let next_obj = objiter::call_next(vm, &self.iterator)?;
54+
let mut counter = self.counter.write().unwrap();
55+
let position = counter.clone();
56+
AddAssign::add_assign(&mut counter as &mut BigInt, 1);
57+
Ok((position, next_obj))
6258
}
6359

6460
#[pymethod(name = "__iter__")]

0 commit comments

Comments
 (0)