Skip to content

Commit 45f8c3a

Browse files
committed
remove closure lock
1 parent dc308f6 commit 45f8c3a

File tree

2 files changed

+31
-56
lines changed

2 files changed

+31
-56
lines changed

vm/src/builtins/function.rs

Lines changed: 25 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ use super::{
55
PyAsyncGen, PyCode, PyCoroutine, PyDictRef, PyGenerator, PyStr, PyStrRef, PyTuple, PyTupleRef,
66
PyType, PyTypeRef,
77
};
8-
use crate::PyAtomicRef;
98
#[cfg(feature = "jit")]
109
use crate::common::lock::OnceCell;
1110
use crate::common::lock::PyMutex;
@@ -32,23 +31,23 @@ pub struct PyFunction {
3231
code: PyRef<PyCode>,
3332
globals: PyDictRef,
3433
builtins: PyObjectRef,
35-
closure: PyAtomicRef<Option<PyTuple>>,
34+
closure: Option<PyRef<PyTuple<PyCellRef>>>,
3635
defaults_and_kwdefaults: PyMutex<(Option<PyTupleRef>, Option<PyDictRef>)>,
3736
name: PyMutex<PyStrRef>,
3837
qualname: PyMutex<PyStrRef>,
3938
type_params: PyMutex<PyTupleRef>,
40-
#[cfg(feature = "jit")]
41-
jitted_code: OnceCell<CompiledCode>,
4239
annotations: PyMutex<PyDictRef>,
4340
module: PyMutex<PyObjectRef>,
4441
doc: PyMutex<PyObjectRef>,
42+
#[cfg(feature = "jit")]
43+
jitted_code: OnceCell<CompiledCode>,
4544
}
4645

4746
unsafe impl Traverse for PyFunction {
4847
fn traverse(&self, tracer_fn: &mut TraverseFn<'_>) {
4948
self.globals.traverse(tracer_fn);
50-
if let Some(closure) = self.closure.deref() {
51-
closure.traverse(tracer_fn);
49+
if let Some(closure) = self.closure.as_ref() {
50+
closure.as_untyped().traverse(tracer_fn);
5251
}
5352
self.defaults_and_kwdefaults.traverse(tracer_fn);
5453
}
@@ -60,7 +59,7 @@ impl PyFunction {
6059
pub(crate) fn new(
6160
code: PyRef<PyCode>,
6261
globals: PyDictRef,
63-
closure: Option<PyTupleRef>,
62+
closure: Option<PyRef<PyTuple<PyCellRef>>>,
6463
defaults: Option<PyTupleRef>,
6564
kw_only_defaults: Option<PyDictRef>,
6665
qualname: PyStrRef,
@@ -84,7 +83,7 @@ impl PyFunction {
8483
code,
8584
globals,
8685
builtins,
87-
closure: PyAtomicRef::from(closure),
86+
closure,
8887
defaults_and_kwdefaults: PyMutex::new((defaults, kw_only_defaults)),
8988
name,
9089
qualname: PyMutex::new(qualname),
@@ -334,15 +333,15 @@ impl PyFunction {
334333
attr_value.class().name()
335334
))
336335
})?;
337-
self.set___defaults__(Some(defaults));
336+
self.defaults_and_kwdefaults.lock().0 = Some(defaults);
338337
} else if attr.contains(bytecode::MakeFunctionFlags::KW_ONLY_DEFAULTS) {
339338
let kwdefaults = attr_value.clone().downcast::<PyDict>().map_err(|_| {
340339
vm.new_type_error(format!(
341340
"__kwdefaults__ must be a dict, not {}",
342341
attr_value.class().name()
343342
))
344343
})?;
345-
self.set___kwdefaults__(Some(kwdefaults));
344+
self.defaults_and_kwdefaults.lock().1 = Some(kwdefaults);
346345
} else if attr.contains(bytecode::MakeFunctionFlags::ANNOTATIONS) {
347346
let annotations = attr_value.clone().downcast::<PyDict>().map_err(|_| {
348347
vm.new_type_error(format!(
@@ -354,33 +353,18 @@ impl PyFunction {
354353
} else if attr.contains(bytecode::MakeFunctionFlags::CLOSURE) {
355354
// For closure, we need special handling
356355
// The closure tuple contains cell objects
357-
let closure_tuple =
358-
attr_value
359-
.clone()
360-
.downcast_exact::<PyTuple>(vm)
361-
.map_err(|obj| {
362-
vm.new_type_error(format!(
363-
"closure must be a tuple, not {}",
364-
obj.class().name()
365-
))
366-
})?;
367-
368-
// Convert to tuple of cells
369-
let cells: Result<Vec<_>, _> = closure_tuple
370-
.iter()
371-
.map(|cell| cell.clone().downcast_exact::<PyCell>(vm))
372-
.collect();
373-
let cells = cells
374-
.map_err(|_| vm.new_type_error("closure must be a tuple of cells".to_owned()))?;
375-
376-
// Convert cells to PyTuple
377-
let cells_objects: Vec<PyObjectRef> = cells
378-
.into_iter()
379-
.map(|cell| cell.into_pyref().into())
380-
.collect();
381-
let cells_tuple = PyTuple::new_ref(cells_objects, &vm.ctx);
382-
383-
let _ = unsafe { self.closure.swap(Some(cells_tuple)) };
356+
let closure_tuple = attr_value
357+
.clone()
358+
.downcast_exact::<PyTuple>(vm)
359+
.map_err(|obj| {
360+
vm.new_type_error(format!(
361+
"closure must be a tuple, not {}",
362+
obj.class().name()
363+
))
364+
})?
365+
.into_pyref();
366+
367+
self.closure = Some(closure_tuple.try_into_typed::<PyCell>(vm)?);
384368
} else if attr.contains(bytecode::MakeFunctionFlags::TYPE_PARAMS) {
385369
let type_params = attr_value.clone().downcast::<PyTuple>().map_err(|_| {
386370
vm.new_type_error(format!(
@@ -431,15 +415,7 @@ impl Py<PyFunction> {
431415
code.clone(),
432416
Scope::new(Some(locals), self.globals.clone()),
433417
vm.builtins.dict(),
434-
self.closure.deref().as_ref().map_or(&[], |tuple| {
435-
// SAFETY: We know closure contains only cells from construction
436-
unsafe {
437-
std::slice::from_raw_parts(
438-
tuple.as_slice().as_ptr() as *const PyCellRef,
439-
tuple.len(),
440-
)
441-
}
442-
}),
418+
self.closure.as_ref().map_or(&[], |c| c.as_slice()),
443419
Some(self.to_owned().into()),
444420
vm,
445421
)
@@ -513,8 +489,7 @@ impl PyFunction {
513489
#[pymember]
514490
fn __closure__(vm: &VirtualMachine, zelf: PyObjectRef) -> PyResult {
515491
let zelf = Self::_as_pyref(&zelf, vm)?;
516-
let closure = zelf.closure.deref().map(|x| x.as_object().to_owned());
517-
Ok(vm.unwrap_or_none(closure))
492+
Ok(vm.unwrap_or_none(zelf.closure.clone().map(|x| x.into())))
518493
}
519494

520495
#[pymember]
@@ -698,9 +673,9 @@ impl Constructor for PyFunction {
698673
)));
699674
}
700675

701-
// Validate that all items are cells
676+
// Validate that all items are cells and create typed tuple
702677
let typed_closure = closure_tuple.try_into_typed::<PyCell>(vm)?;
703-
Some(typed_closure.into_untyped())
678+
Some(typed_closure)
704679
} else if !args.code.freevars.is_empty() {
705680
return Err(vm.new_type_error("arg 5 (closure) must be tuple"));
706681
} else {

vm/src/object/core.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,12 +1023,12 @@ impl<T> Clone for PyRef<T> {
10231023
}
10241024

10251025
impl<T: PyObjectPayload> PyRef<T> {
1026-
#[inline(always)]
1027-
pub(crate) const fn into_non_null(self) -> NonNull<Py<T>> {
1028-
let ptr = self.ptr;
1029-
std::mem::forget(self);
1030-
ptr
1031-
}
1026+
// #[inline(always)]
1027+
// pub(crate) const fn into_non_null(self) -> NonNull<Py<T>> {
1028+
// let ptr = self.ptr;
1029+
// std::mem::forget(self);
1030+
// ptr
1031+
// }
10321032

10331033
#[inline(always)]
10341034
pub(crate) const unsafe fn from_non_null(ptr: NonNull<Py<T>>) -> Self {

0 commit comments

Comments
 (0)