Skip to content

Commit 8326db9

Browse files
committed
Align type lock behavior with CPython
1 parent 6d2f650 commit 8326db9

4 files changed

Lines changed: 125 additions & 84 deletions

File tree

crates/vm/src/builtins/type.rs

Lines changed: 74 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,24 @@ pub fn type_cache_clear() {
216216
TYPE_CACHE_CLEARING.store(false, Ordering::Release);
217217
}
218218

219+
/// Repair type-cache SeqLock state in the post-fork child.
220+
///
221+
/// If fork happens while a writer holds an entry SeqLock, the child inherits
222+
/// the odd sequence value with no surviving writer to release it. Clear only
223+
/// those in-progress entries, matching CPython's `_PyTypes_AfterFork()`.
224+
pub unsafe fn type_cache_after_fork() {
225+
for entry in TYPE_CACHE.iter() {
226+
let seq = entry.sequence.load(Ordering::Relaxed);
227+
if (seq & 1) == 0 {
228+
continue;
229+
}
230+
entry.value.store(core::ptr::null_mut(), Ordering::Relaxed);
231+
entry.name.store(core::ptr::null_mut(), Ordering::Relaxed);
232+
entry.version.store(0, Ordering::Relaxed);
233+
entry.sequence.store(0, Ordering::Relaxed);
234+
}
235+
}
236+
219237
unsafe impl crate::object::Traverse for PyType {
220238
fn traverse(&self, tracer_fn: &mut crate::object::TraverseFn<'_>) {
221239
self.base.traverse(tracer_fn);
@@ -487,10 +505,28 @@ impl PyType {
487505
}
488506

489507
pub fn assign_version_tag(&self) -> u32 {
490-
self.assign_version_tag_inner()
508+
let version = self.tp_version_tag.load(Ordering::Acquire);
509+
if version != 0 {
510+
return version;
511+
}
512+
crate::vm::thread::try_with_current_vm(|vm| {
513+
Self::with_type_lock(vm, || {
514+
let version = self.tp_version_tag.load(Ordering::Acquire);
515+
if version == 0 {
516+
self.assign_version_tag_inner()
517+
} else {
518+
version
519+
}
520+
})
521+
})
522+
.unwrap_or_else(|| self.assign_version_tag_inner())
491523
}
492524

493525
pub(crate) fn version_for_specialization(&self, vm: &VirtualMachine) -> u32 {
526+
let version = self.tp_version_tag.load(Ordering::Acquire);
527+
if version != 0 {
528+
return version;
529+
}
494530
Self::with_type_lock(vm, || {
495531
let version = self.tp_version_tag.load(Ordering::Acquire);
496532
if version == 0 {
@@ -503,28 +539,34 @@ impl PyType {
503539

504540
/// Invalidate this type's version tag and cascade to all subclasses.
505541
fn modified_inner(&self) {
506-
if let Some(ext) = self.heaptype_ext.as_ref() {
507-
ext.specialization_cache.invalidate_for_type_modified();
508-
}
509-
// If already invalidated, all subclasses must also be invalidated
510-
// (guaranteed by the MRO invariant in assign_version_tag).
511542
let old_version = self.tp_version_tag.load(Ordering::Acquire);
512543
if old_version == 0 {
513544
return;
514545
}
515-
self.tp_version_tag.store(0, Ordering::SeqCst);
516-
// Nullify borrowed pointers in cache entries for this version
517-
// so they don't dangle after the dict is modified.
518-
type_cache_clear_version(old_version);
519546
let subclasses = self.subclasses.read();
520547
for weak_ref in subclasses.iter() {
521548
if let Some(sub) = weak_ref.upgrade() {
522549
sub.downcast_ref::<PyType>().unwrap().modified_inner();
523550
}
524551
}
552+
self.tp_version_tag.store(0, Ordering::SeqCst);
553+
// Nullify borrowed pointers in cache entries for this version
554+
// so they don't dangle after the dict is modified.
555+
type_cache_clear_version(old_version);
556+
if let Some(ext) = self.heaptype_ext.as_ref() {
557+
ext.specialization_cache.invalidate_for_type_modified();
558+
}
525559
}
526560

527561
pub fn modified(&self) {
562+
if self.tp_version_tag.load(Ordering::Acquire) == 0 {
563+
return;
564+
}
565+
if let Some(()) = crate::vm::thread::try_with_current_vm(|vm| {
566+
Self::with_type_lock(vm, || self.modified_inner());
567+
}) {
568+
return;
569+
}
528570
self.modified_inner();
529571
}
530572

@@ -1049,11 +1091,11 @@ impl PyType {
10491091
if func_version == 0 {
10501092
return false;
10511093
}
1094+
ext.specialization_cache
1095+
.swap_getitem(Some(getitem), Some(vm));
10521096
ext.specialization_cache
10531097
.getitem_version
10541098
.store(func_version, Ordering::Release);
1055-
ext.specialization_cache
1056-
.swap_getitem(Some(getitem), Some(vm));
10571099
true
10581100
})
10591101
}
@@ -1076,18 +1118,7 @@ impl PyType {
10761118
Some((getitem, cached_version))
10771119
}
10781120

1079-
pub fn get_direct_attr(&self, attr_name: &'static PyStrInterned) -> Option<PyObjectRef> {
1080-
self.attributes.read().get(attr_name).cloned()
1081-
}
1082-
1083-
/// find_name_in_mro with method cache (MCACHE).
1084-
/// Looks in tp_dict of types in MRO, bypasses descriptors.
1085-
///
1086-
/// Uses a lock-free SeqLock-style pattern:
1087-
/// Read: load sequence/version/name → load value + try_to_owned →
1088-
/// validate value pointer + sequence
1089-
/// Write: sequence(begin) → version=0 → swap value/name → version=assigned → sequence(end)
1090-
fn find_name_in_mro(&self, name: &'static PyStrInterned) -> Option<PyObjectRef> {
1121+
fn find_name_in_mro_without_vm(&self, name: &'static PyStrInterned) -> Option<PyObjectRef> {
10911122
let version = self.tp_version_tag.load(Ordering::Acquire);
10921123
if version != 0 {
10931124
let idx = type_cache_hash(version, name);
@@ -1109,8 +1140,6 @@ impl PyType {
11091140
}
11101141
continue;
11111142
}
1112-
// _Py_TryIncrefCompare-style validation:
1113-
// safe_inc via raw pointer, then ensure source is unchanged.
11141143
if let Some(cloned) = unsafe { PyObject::try_to_owned_from_ptr(ptr) } {
11151144
let same_ptr = core::ptr::eq(entry.value.load(Ordering::Relaxed), ptr);
11161145
if same_ptr && entry.end_read(seq1) {
@@ -1123,20 +1152,12 @@ impl PyType {
11231152
}
11241153
}
11251154

1126-
// Assign version BEFORE the MRO walk so that any concurrent
1127-
// modified() call during the walk invalidates this version.
11281155
let assigned = if version == 0 {
11291156
self.assign_version_tag()
11301157
} else {
11311158
version
11321159
};
1133-
1134-
// MRO walk
11351160
let result = self.find_name_in_mro_uncached(name);
1136-
1137-
// Only cache positive results. Negative results are not cached to
1138-
// avoid stale entries from transient MRO walk failures during
1139-
// concurrent type modifications.
11401161
if let Some(ref found) = result
11411162
&& assigned != 0
11421163
&& !TYPE_CACHE_CLEARING.load(Ordering::Acquire)
@@ -1146,20 +1167,34 @@ impl PyType {
11461167
let entry = &TYPE_CACHE[idx];
11471168
let name_ptr = name as *const _ as *mut _;
11481169
entry.begin_write();
1149-
// Invalidate first to prevent readers from seeing partial state
11501170
entry.version.store(0, Ordering::Release);
1151-
// Store borrowed pointer (no refcount increment).
11521171
let new_ptr = &**found as *const PyObject as *mut PyObject;
11531172
entry.value.store(new_ptr, Ordering::Relaxed);
11541173
entry.name.store(name_ptr, Ordering::Relaxed);
1155-
// Activate entry — Release ensures value/name writes are visible
11561174
entry.version.store(assigned, Ordering::Release);
11571175
entry.end_write();
11581176
}
1159-
11601177
result
11611178
}
11621179

1180+
/// Look up `attr_name` in this type's own attribute map only and return a
/// clone of the stored object, or `None` when the name is absent.
pub fn get_direct_attr(&self, attr_name: &'static PyStrInterned) -> Option<PyObjectRef> {
    let own_attrs = self.attributes.read();
    own_attrs.get(attr_name).cloned()
}
1183+
1184+
/// find_name_in_mro with method cache (MCACHE).
1185+
/// Looks in tp_dict of types in MRO, bypasses descriptors.
1186+
///
1187+
/// Uses a lock-free SeqLock-style pattern:
1188+
/// Read: load sequence/version/name → load value + try_to_owned →
1189+
/// validate value pointer + sequence
1190+
/// Write: sequence(begin) → version=0 → swap value/name → version=assigned → sequence(end)
1191+
fn find_name_in_mro(&self, name: &'static PyStrInterned) -> Option<PyObjectRef> {
1192+
crate::vm::thread::try_with_current_vm(|vm| {
1193+
self.lookup_ref_and_version_interned(name, vm).0
1194+
})
1195+
.unwrap_or_else(|| self.find_name_in_mro_without_vm(name))
1196+
}
1197+
11631198
/// Raw MRO walk without cache.
11641199
fn find_name_in_mro_uncached(&self, name: &'static PyStrInterned) -> Option<PyObjectRef> {
11651200
for cls in self.mro.read().iter() {
@@ -1173,7 +1208,7 @@ impl PyType {
11731208
/// _PyType_LookupRef: look up a name through the MRO without setting an exception.
11741209
pub fn lookup_ref(&self, name: &Py<PyStr>, vm: &VirtualMachine) -> Option<PyObjectRef> {
11751210
let interned_name = vm.ctx.interned_str(name)?;
1176-
self.find_name_in_mro(interned_name)
1211+
self.lookup_ref_and_version_interned(interned_name, vm).0
11771212
}
11781213

11791214
pub fn get_super_attr(&self, attr_name: &'static PyStrInterned) -> Option<PyObjectRef> {

crates/vm/src/frame.rs

Lines changed: 44 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -7456,12 +7456,13 @@ impl ExecutingFrame<'_> {
74567456
.load()
74577457
.is_some_and(|f| f as usize == PyBaseObject::getattro as *const () as usize);
74587458
if !is_default_getattro {
7459-
let type_version = cls.version_for_specialization(_vm);
7459+
let (getattribute, type_version) =
7460+
cls.lookup_ref_and_version_interned(identifier!(_vm, __getattribute__), _vm);
74607461
if type_version != 0
74617462
&& !oparg.is_method()
74627463
&& !self.specialization_eval_frame_active(_vm)
74637464
&& cls.get_attr(identifier!(_vm, __getattr__)).is_none()
7464-
&& let Some(getattribute) = cls.get_attr(identifier!(_vm, __getattribute__))
7465+
&& let Some(getattribute) = getattribute
74657466
&& let Some(func) = getattribute.downcast_ref_if_exact::<PyFunction>(_vm)
74667467
&& func.can_specialize_call(2)
74677468
{
@@ -7493,27 +7494,24 @@ impl ExecutingFrame<'_> {
74937494
return;
74947495
}
74957496

7496-
// Get or assign type version
7497-
let type_version = cls.version_for_specialization(_vm);
7498-
if type_version == 0 {
7499-
// Version counter overflow — backoff to avoid re-attempting every execution
7500-
unsafe {
7501-
self.code.instructions.write_adaptive_counter(
7502-
cache_base,
7503-
bytecode::adaptive_counter_backoff(
7504-
self.code.instructions.read_adaptive_counter(cache_base),
7505-
),
7506-
);
7507-
}
7508-
return;
7509-
}
7510-
75117497
let attr_name = self.code.names[oparg.name_idx() as usize];
75127498

75137499
// Match CPython: only specialize module attribute loads when the
75147500
// current module dict has no __getattr__ override and the attribute is
75157501
// already present.
75167502
if let Some(module) = obj.downcast_ref_if_exact::<PyModule>(_vm) {
7503+
let type_version = cls.version_for_specialization(_vm);
7504+
if type_version == 0 {
7505+
unsafe {
7506+
self.code.instructions.write_adaptive_counter(
7507+
cache_base,
7508+
bytecode::adaptive_counter_backoff(
7509+
self.code.instructions.read_adaptive_counter(cache_base),
7510+
),
7511+
);
7512+
}
7513+
return;
7514+
}
75177515
let module_dict = module.dict();
75187516
match (
75197517
module_dict.get_item_opt(identifier!(_vm, __getattr__), _vm),
@@ -7540,8 +7538,18 @@ impl ExecutingFrame<'_> {
75407538
return;
75417539
}
75427540

7543-
// Look up attr in class via MRO
7544-
let cls_attr = cls.get_attr(attr_name);
7541+
let (cls_attr, type_version) = cls.lookup_ref_and_version_interned(attr_name, _vm);
7542+
if type_version == 0 {
7543+
unsafe {
7544+
self.code.instructions.write_adaptive_counter(
7545+
cache_base,
7546+
bytecode::adaptive_counter_backoff(
7547+
self.code.instructions.read_adaptive_counter(cache_base),
7548+
),
7549+
);
7550+
}
7551+
return;
7552+
}
75457553
let class_has_dict = cls.slots.flags.has_feature(PyTypeFlags::HAS_DICT);
75467554

75477555
if oparg.is_method() {
@@ -7712,26 +7720,11 @@ impl ExecutingFrame<'_> {
77127720
) {
77137721
let obj = self.top_value();
77147722
let owner_type = obj.downcast_ref::<PyType>().unwrap();
7715-
7716-
// Get or assign type version for the type object itself
7717-
let type_version = owner_type.version_for_specialization(_vm);
7718-
if type_version == 0 {
7719-
unsafe {
7720-
self.code.instructions.write_adaptive_counter(
7721-
cache_base,
7722-
bytecode::adaptive_counter_backoff(
7723-
self.code.instructions.read_adaptive_counter(cache_base),
7724-
),
7725-
);
7726-
}
7727-
return;
7728-
}
7729-
77307723
let attr_name = self.code.names[oparg.name_idx() as usize];
77317724

77327725
// Check metaclass: ensure no data descriptor on metaclass for this name
77337726
let mcl = obj.class();
7734-
let mcl_attr = mcl.get_attr(attr_name);
7727+
let (mcl_attr, mut metaclass_version) = mcl.lookup_ref_and_version_interned(attr_name, _vm);
77357728
if let Some(ref attr) = mcl_attr {
77367729
let attr_class = attr.class();
77377730
if attr_class.slots.descr_set.load().is_some() {
@@ -7747,9 +7740,7 @@ impl ExecutingFrame<'_> {
77477740
return;
77487741
}
77497742
}
7750-
let mut metaclass_version = 0;
77517743
if !mcl.slots.flags.has_feature(PyTypeFlags::IMMUTABLETYPE) {
7752-
metaclass_version = mcl.version_for_specialization(_vm);
77537744
if metaclass_version == 0 {
77547745
unsafe {
77557746
self.code.instructions.write_adaptive_counter(
@@ -7761,10 +7752,22 @@ impl ExecutingFrame<'_> {
77617752
}
77627753
return;
77637754
}
7755+
} else {
7756+
metaclass_version = 0;
77647757
}
77657758

7766-
// Look up attr in the type's own MRO
7767-
let cls_attr = owner_type.get_attr(attr_name);
7759+
let (cls_attr, type_version) = owner_type.lookup_ref_and_version_interned(attr_name, _vm);
7760+
if type_version == 0 {
7761+
unsafe {
7762+
self.code.instructions.write_adaptive_counter(
7763+
cache_base,
7764+
bytecode::adaptive_counter_backoff(
7765+
self.code.instructions.read_adaptive_counter(cache_base),
7766+
),
7767+
);
7768+
}
7769+
return;
7770+
}
77687771
if let Some(ref descr) = cls_attr {
77697772
let descr_class = descr.class();
77707773
let has_descr_get = descr_class.slots.descr_get.load().is_some();
@@ -9140,8 +9143,8 @@ impl ExecutingFrame<'_> {
91409143
return;
91419144
}
91429145

9143-
// Get or assign type version
9144-
let type_version = cls.version_for_specialization(vm);
9146+
let attr_name = self.code.names[attr_idx as usize];
9147+
let (cls_attr, type_version) = cls.lookup_ref_and_version_interned(attr_name, vm);
91459148
if type_version == 0 {
91469149
unsafe {
91479150
self.code.instructions.write_adaptive_counter(
@@ -9153,10 +9156,6 @@ impl ExecutingFrame<'_> {
91539156
}
91549157
return;
91559158
}
9156-
9157-
// Check for data descriptor
9158-
let attr_name = self.code.names[attr_idx as usize];
9159-
let cls_attr = cls.get_attr(attr_name);
91609159
let has_data_descr = cls_attr.as_ref().is_some_and(|descr| {
91619160
let descr_cls = descr.class();
91629161
descr_cls.slots.descr_get.load().is_some() && descr_cls.slots.descr_set.load().is_some()

crates/vm/src/stdlib/posix.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -801,6 +801,9 @@ pub mod module {
801801
#[cfg(feature = "threading")]
802802
crate::object::reset_weakref_locks_after_fork();
803803

804+
// Repair any type-cache entries left mid-update at fork time.
805+
unsafe { crate::builtins::type_::type_cache_after_fork() };
806+
804807
// Phase 3: Clean up thread state. Locks are now reinit'd so we can
805808
// acquire them normally instead of using try_lock().
806809
#[cfg(feature = "threading")]

crates/vm/src/vm/thread.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ pub fn with_current_vm<R>(f: impl FnOnce(&VirtualMachine) -> R) -> R {
7575
VM_CURRENT.with(f)
7676
}
7777

78+
pub fn try_with_current_vm<R>(f: impl FnOnce(&VirtualMachine) -> R) -> Option<R> {
79+
VM_CURRENT.is_set().then(|| VM_CURRENT.with(f))
80+
}
81+
7882
pub fn enter_vm<R>(vm: &VirtualMachine, f: impl FnOnce() -> R) -> R {
7983
VM_STACK.with(|vms| {
8084
// Outermost enter_vm: transition DETACHED → ATTACHED

0 commit comments

Comments
 (0)