
Commit be91529

Specialized ops (RustPython#7322)

* Add CALL_ALLOC_AND_ENTER_INIT specialization

  Optimizes user-defined class instantiation MyClass(args...) when
  tp_new == object.__new__ and __init__ is a simple PyFunction. Allocates
  the object directly and calls __init__ via invoke_exact_args, bypassing
  the generic type.__call__ dispatch path.

* Invalidate JIT cache when __code__ is reassigned

  Change jitted_code from OnceCell to PyMutex<Option<CompiledCode>> so it
  can be cleared on __code__ assignment. The setter now sets the cached
  JIT code to None to prevent executing stale machine code.

* Atomic operations for specialization cache

  - range iterator: deduplicate fast_next/next_fast
  - Replace raw pointer reads/writes in CodeUnits with atomic operations
    (AtomicU8/AtomicU16) for thread safety
  - Add read_op (Acquire), read_arg (Relaxed), compare_exchange_op
  - Use Release ordering in replace_op to synchronize cache writes
  - Dispatch loop reads opcodes atomically via read_op/read_arg
  - Fix adaptive counter access: use read/write_adaptive_counter instead
    of read/write_cache_u16 (which was reading the wrong bytes)
  - Add pre-check guards to all specialize_* functions to prevent
    concurrent specialization races
  - Move modified() before attribute changes in type.__setattr__ to
    prevent use-after-free of cached descriptors
  - Use SeqCst ordering in modified() for version invalidation
  - Add Release fence after quicken() initialization

* Fix slot wrapper override for inherited attributes

  For __getattribute__: only use getattro_wrapper when the type itself
  defines the attribute; otherwise inherit the native slot from the base
  class via the MRO. For __setattr__/__delattr__: only store
  setattro_wrapper when the type has its own __setattr__ or __delattr__;
  otherwise keep the inherited base slot.

* Fix StoreAttrSlot cache overflow corrupting the next instruction

  write_cache_u32 at cache_base+3 writes 2 code units (positions 3 and 4),
  but STORE_ATTR only has 4 cache entries (positions 0-3). This overwrote
  the next instruction with the upper 16 bits of the slot offset. Changed
  to write_cache_u16/read_cache_u16 since member descriptor offsets fit
  within a u16 (max 65535 bytes).

* Exclude method_descriptor from has_python_cmp check

  has_python_cmp incorrectly treated method_descriptor attributes as
  Python-level comparison methods, causing the richcompare slot to use
  wrapper dispatch instead of inheriting the native slot.

* Fix CompareOpFloat NaN handling

  partial_cmp returns None for NaN comparisons, so is_some_and returned
  false for every NaN comparison, but NaN != x should be true per
  IEEE 754 semantics.

* Fix invoke_exact_args borrow in CallAllocAndEnterInit

* Distinguish Python method vs not-found in slot MRO lookup

  Change lookup_slot_in_mro to return a 3-state SlotLookupResult enum
  (NativeSlot/PythonMethod/NotFound) instead of Option<T>. Previously,
  both "found a Python-level method" and "found nothing" returned None,
  causing incorrect slot inheritance. For example, class Test(Mixin,
  TestCase) would inherit object.slot_init from Mixin via inherit_from_mro
  instead of using init_wrapper to dispatch TestCase.__init__. Apply this
  fix consistently to all slot update sites: update_main_slot!,
  update_sub_slot!, TpGetattro, TpSetattro, TpDescrSet, TpHash,
  TpRichcompare, SqAssItem, MpAssSubscript.

* Extract specialization helper functions to reduce boilerplate

  - deoptimize() / deoptimize_at(): replace a specialized op with its base op
  - adaptive(): decrement the warmup counter or call the specialize function
  - commit_specialization(): replace the op on success, back off on failure
  - execute_binary_op_int() / execute_binary_op_float(): typed binary ops

  Removes 10 duplicate deoptimize_* functions; consolidates 13 adaptive
  counter blocks, 6 binary op handlers, and 7 specialize tail patterns.
  Also replaces inline deopt blocks in the LoadAttr/StoreAttr handlers.

* Improve specialization guards and fix mark_stacks

  - CONTAINS_OP_SET: add frozenset support in the handler and specializer
  - TO_BOOL_ALWAYS_TRUE: cache the type version instead of checking slots
  - LOAD_GLOBAL_BUILTIN: cache the builtins dict version alongside globals
  - mark_stacks: deoptimize specialized opcodes for correct reachability

* Auto-format: cargo fmt --all

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
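The core synchronization idea of this commit, publishing a specialized opcode with Release so that Acquire readers of the opcode byte also see the inline-cache data written just before it, can be sketched in isolation. The `Unit`, `specialize`, and `dispatch` names below are illustrative stand-ins, not RustPython's actual types:

```rust
use std::sync::atomic::{AtomicU8, AtomicU16, Ordering};

// Hypothetical two-byte code unit: an opcode byte plus an inline-cache slot.
struct Unit {
    op: AtomicU8,
    cache: AtomicU16,
}

// Publish: write the cache first (Relaxed), then store the specialized
// opcode with Release so the cache write happens-before any Acquire load
// that observes the new opcode.
fn specialize(u: &Unit, new_op: u8, cache: u16) {
    u.cache.store(cache, Ordering::Relaxed);
    u.op.store(new_op, Ordering::Release);
}

// Dispatch: read the opcode with Acquire; if the specialized opcode is
// observed, the matching cache value is guaranteed to be visible.
fn dispatch(u: &Unit) -> (u8, u16) {
    let op = u.op.load(Ordering::Acquire);
    let cache = u.cache.load(Ordering::Relaxed);
    (op, cache)
}
```

This is the same pairing the diff's SAFETY comment describes: `replace_op` (Release) on the writer side, the dispatch loop's `read_op` (Acquire) on the reader side.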
1 parent 6d3fa2d commit be91529

7 files changed

Lines changed: 905 additions & 859 deletions


crates/compiler-core/src/bytecode.rs

Lines changed: 88 additions & 36 deletions
@@ -8,7 +8,12 @@ use crate::{
 };
 use alloc::{borrow::ToOwned, boxed::Box, collections::BTreeSet, fmt, string::String, vec::Vec};
 use bitflags::bitflags;
-use core::{cell::UnsafeCell, hash, mem, ops::Deref};
+use core::{
+    cell::UnsafeCell,
+    hash, mem,
+    ops::Deref,
+    sync::atomic::{AtomicU8, AtomicU16, Ordering},
+};
 use itertools::Itertools;
 use malachite_bigint::BigInt;
 use num_complex::Complex64;
@@ -367,8 +372,13 @@ impl TryFrom<&[u8]> for CodeUnit {
 
 pub struct CodeUnits(UnsafeCell<Box<[CodeUnit]>>);
 
-// SAFETY: All mutation of the inner buffer is serialized by `monitoring_data: PyMutex`
-// in `PyCode`. The `UnsafeCell` is required because `replace_op` mutates through `&self`.
+// SAFETY: All cache operations use atomic read/write instructions.
+// - replace_op / compare_exchange_op: AtomicU8 store/CAS (Release)
+// - cache read/write: AtomicU16 load/store (Relaxed)
+// - adaptive counter: AtomicU8 load/store (Relaxed)
+// Ordering is established by:
+// - replace_op (Release) ↔ dispatch loop read_op (Acquire) for cache data visibility
+// - tp_version_tag (Acquire) for descriptor pointer validity
 unsafe impl Sync for CodeUnits {}
 
 impl Clone for CodeUnits {
@@ -435,45 +445,81 @@ impl Deref for CodeUnits {
 
 impl CodeUnits {
     /// Replace the opcode at `index` in-place without changing the arg byte.
+    /// Uses atomic Release store to ensure prior cache writes are visible
+    /// to threads that subsequently read the new opcode with Acquire.
     ///
     /// # Safety
    /// - `index` must be in bounds.
     /// - `new_op` must have the same arg semantics as the original opcode.
-    /// - The caller must ensure exclusive access to the instruction buffer
-    ///   (no concurrent reads or writes to the same `CodeUnits`).
     pub unsafe fn replace_op(&self, index: usize, new_op: Instruction) {
-        unsafe {
-            let units = &mut *self.0.get();
-            let unit_ptr = units.as_mut_ptr().add(index);
-            // Write only the opcode byte (first byte of CodeUnit due to #[repr(C)])
-            let op_ptr = unit_ptr as *mut u8;
-            core::ptr::write(op_ptr, new_op.into());
-        }
+        let units = unsafe { &*self.0.get() };
+        let ptr = units.as_ptr().wrapping_add(index) as *const AtomicU8;
+        unsafe { &*ptr }.store(new_op.into(), Ordering::Release);
+    }
+
+    /// Atomically replace opcode only if it still matches `expected`.
+    /// Returns true on success. Uses Release ordering on success.
+    ///
+    /// # Safety
+    /// - `index` must be in bounds.
+    pub unsafe fn compare_exchange_op(
+        &self,
+        index: usize,
+        expected: Instruction,
+        new_op: Instruction,
+    ) -> bool {
+        let units = unsafe { &*self.0.get() };
+        let ptr = units.as_ptr().wrapping_add(index) as *const AtomicU8;
+        unsafe { &*ptr }
+            .compare_exchange(
+                expected.into(),
+                new_op.into(),
+                Ordering::Release,
+                Ordering::Relaxed,
+            )
+            .is_ok()
+    }
+
+    /// Atomically read the opcode at `index` with Acquire ordering.
+    /// Pairs with `replace_op` (Release) to ensure cache data visibility.
+    pub fn read_op(&self, index: usize) -> Instruction {
+        let units = unsafe { &*self.0.get() };
+        let ptr = units.as_ptr().wrapping_add(index) as *const AtomicU8;
+        let byte = unsafe { &*ptr }.load(Ordering::Acquire);
+        // SAFETY: Only valid Instruction values are stored via replace_op/compare_exchange_op.
+        unsafe { mem::transmute::<u8, Instruction>(byte) }
+    }
+
+    /// Atomically read the arg byte at `index` with Relaxed ordering.
+    pub fn read_arg(&self, index: usize) -> OpArgByte {
+        let units = unsafe { &*self.0.get() };
+        let ptr = units.as_ptr().wrapping_add(index) as *const u8;
+        let arg_ptr = unsafe { ptr.add(1) } as *const AtomicU8;
+        OpArgByte::from(unsafe { &*arg_ptr }.load(Ordering::Relaxed))
     }
 
     /// Write a u16 value into a CACHE code unit at `index`.
     /// Each CodeUnit is 2 bytes (#[repr(C)]: op u8 + arg u8), so one u16 fits exactly.
+    /// Uses Relaxed atomic store; ordering is provided by replace_op (Release).
     ///
     /// # Safety
     /// - `index` must be in bounds and point to a CACHE entry.
-    /// - The caller must ensure no concurrent reads/writes to the same slot.
     pub unsafe fn write_cache_u16(&self, index: usize, value: u16) {
-        unsafe {
-            let units = &mut *self.0.get();
-            let ptr = units.as_mut_ptr().add(index) as *mut u8;
-            core::ptr::write_unaligned(ptr as *mut u16, value);
-        }
+        let units = unsafe { &*self.0.get() };
+        let ptr = units.as_ptr().wrapping_add(index) as *const AtomicU16;
+        unsafe { &*ptr }.store(value, Ordering::Relaxed);
     }
 
     /// Read a u16 value from a CACHE code unit at `index`.
+    /// Uses Relaxed atomic load; ordering is provided by read_op (Acquire).
     ///
     /// # Panics
     /// Panics if `index` is out of bounds.
     pub fn read_cache_u16(&self, index: usize) -> u16 {
         let units = unsafe { &*self.0.get() };
         assert!(index < units.len(), "read_cache_u16: index out of bounds");
-        let ptr = units.as_ptr().wrapping_add(index) as *const u8;
-        unsafe { core::ptr::read_unaligned(ptr as *const u16) }
+        let ptr = units.as_ptr().wrapping_add(index) as *const AtomicU16;
+        unsafe { &*ptr }.load(Ordering::Relaxed)
     }
 
     /// Write a u32 value across two consecutive CACHE code units starting at `index`.
@@ -518,36 +564,40 @@ impl CodeUnits {
         lo | (hi << 32)
     }
 
-    /// Read the adaptive counter from the first CACHE entry's `arg` byte.
-    /// This preserves `op = Instruction::Cache`, unlike `read_cache_u16`.
+    /// Read the adaptive counter from the CACHE entry's `arg` byte at `index`.
+    /// Uses Relaxed atomic load.
     pub fn read_adaptive_counter(&self, index: usize) -> u8 {
         let units = unsafe { &*self.0.get() };
-        u8::from(units[index].arg)
+        let ptr = units.as_ptr().wrapping_add(index) as *const u8;
+        let arg_ptr = unsafe { ptr.add(1) } as *const AtomicU8;
+        unsafe { &*arg_ptr }.load(Ordering::Relaxed)
     }
 
-    /// Write the adaptive counter to the first CACHE entry's `arg` byte.
-    /// This preserves `op = Instruction::Cache`, unlike `write_cache_u16`.
+    /// Write the adaptive counter to the CACHE entry's `arg` byte at `index`.
+    /// Uses Relaxed atomic store.
     ///
     /// # Safety
     /// - `index` must be in bounds and point to a CACHE entry.
     pub unsafe fn write_adaptive_counter(&self, index: usize, value: u8) {
-        let units = unsafe { &mut *self.0.get() };
-        units[index].arg = OpArgByte::from(value);
+        let units = unsafe { &*self.0.get() };
+        let ptr = units.as_ptr().wrapping_add(index) as *const u8;
+        let arg_ptr = unsafe { ptr.add(1) } as *const AtomicU8;
+        unsafe { &*arg_ptr }.store(value, Ordering::Relaxed);
     }
 
     /// Produce a clean copy of the bytecode suitable for serialization
     /// (marshal) and `co_code`. Specialized opcodes are mapped back to their
     /// base variants via `deoptimize()` and all CACHE entries are zeroed.
     pub fn original_bytes(&self) -> Vec<u8> {
-        let units = unsafe { &*self.0.get() };
-        let mut out = Vec::with_capacity(units.len() * 2);
-        let len = units.len();
+        let len = self.len();
+        let mut out = Vec::with_capacity(len * 2);
         let mut i = 0;
         while i < len {
-            let op = units[i].op.deoptimize();
+            let op = self.read_op(i).deoptimize();
+            let arg = self.read_arg(i);
             let caches = op.cache_entries();
             out.push(u8::from(op));
-            out.push(u8::from(units[i].arg));
+            out.push(u8::from(arg));
             // Zero-fill all CACHE entries (counter + cached data)
             for _ in 0..caches {
                 i += 1;
@@ -562,20 +612,22 @@ impl CodeUnits {
     /// Initialize adaptive warmup counters for all cacheable instructions.
     /// Called lazily at RESUME (first execution of a code object).
     /// Uses the `arg` byte of the first CACHE entry, preserving `op = Instruction::Cache`.
+    /// All writes are atomic (Relaxed) to avoid data races with concurrent readers.
     pub fn quicken(&self) {
-        let units = unsafe { &mut *self.0.get() };
-        let len = units.len();
+        let len = self.len();
         let mut i = 0;
         while i < len {
-            let op = units[i].op;
+            let op = self.read_op(i);
             let caches = op.cache_entries();
             if caches > 0 {
                 // Don't write adaptive counter for instrumented opcodes;
                 // specialization is skipped while monitoring is active.
                 if !op.is_instrumented() {
                     let cache_base = i + 1;
                     if cache_base < len {
-                        units[cache_base].arg = OpArgByte::from(ADAPTIVE_WARMUP_VALUE);
+                        unsafe {
+                            self.write_adaptive_counter(cache_base, ADAPTIVE_WARMUP_VALUE);
+                        }
                     }
                 }
             }
             i += 1 + caches;
crates/vm/src/builtins/frame.rs

Lines changed: 2 additions & 2 deletions
@@ -182,8 +182,8 @@ pub(crate) mod stack_analysis {
             }
             oparg = (oparg << 8) | u32::from(u8::from(instructions[i].arg));
 
-            // De-instrument: get the underlying real instruction
-            let opcode = opcode.to_base().unwrap_or(opcode);
+            // De-instrument and de-specialize: get the underlying base instruction
+            let opcode = opcode.to_base().unwrap_or(opcode).deoptimize();
 
             let caches = opcode.cache_entries();
             let next_i = i + 1 + caches;
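The `deoptimize()` call added here maps every specialized opcode back to its canonical base form so stack-effect analysis reasons about one instruction set. A minimal sketch of that mapping, with an invented opcode enum (RustPython's real `Instruction` type is much larger):

```rust
// Illustrative opcode enum; the variant names are hypothetical.
#[derive(Clone, Copy, PartialEq, Debug)]
enum Op {
    LoadAttr,
    LoadAttrSlot,   // specialized form of LoadAttr
    CompareOp,
    CompareOpFloat, // specialized form of CompareOp
}

impl Op {
    // Map a specialized opcode back to its base form so analysis passes
    // (like mark_stacks) see the canonical instruction; base opcodes map
    // to themselves.
    fn deoptimize(self) -> Op {
        match self {
            Op::LoadAttrSlot => Op::LoadAttr,
            Op::CompareOpFloat => Op::CompareOp,
            other => other,
        }
    }
}
```

Because the mapping is total and idempotent, it is safe to apply unconditionally in the analysis loop.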

crates/vm/src/builtins/function.rs

Lines changed: 9 additions & 7 deletions
@@ -5,8 +5,6 @@ use super::{
     PyAsyncGen, PyCode, PyCoroutine, PyDictRef, PyGenerator, PyModule, PyStr, PyStrRef, PyTuple,
     PyTupleRef, PyType,
 };
-#[cfg(feature = "jit")]
-use crate::common::lock::OnceCell;
 use crate::common::lock::PyMutex;
 use crate::function::ArgMapping;
 use crate::object::{PyAtomicRef, Traverse, TraverseFn};
@@ -75,7 +73,7 @@ pub struct PyFunction {
     doc: PyMutex<PyObjectRef>,
     func_version: AtomicU32,
     #[cfg(feature = "jit")]
-    jitted_code: OnceCell<CompiledCode>,
+    jitted_code: PyMutex<Option<CompiledCode>>,
 }
 
 static FUNC_VERSION_COUNTER: AtomicU32 = AtomicU32::new(1);
@@ -214,7 +212,7 @@ impl PyFunction {
             doc: PyMutex::new(doc),
             func_version: AtomicU32::new(next_func_version()),
             #[cfg(feature = "jit")]
-            jitted_code: OnceCell::new(),
+            jitted_code: PyMutex::new(None),
         };
         Ok(func)
     }
@@ -538,7 +536,7 @@ impl Py<PyFunction> {
         vm: &VirtualMachine,
     ) -> PyResult {
         #[cfg(feature = "jit")]
-        if let Some(jitted_code) = self.jitted_code.get() {
+        if let Some(jitted_code) = self.jitted_code.lock().as_ref() {
             use crate::convert::ToPyObject;
             match jit::get_jit_args(self, &func_args, jitted_code, vm) {
                 Ok(args) => {
@@ -712,6 +710,10 @@ impl PyFunction {
     #[pygetset(setter)]
     fn set___code__(&self, code: PyRef<PyCode>, vm: &VirtualMachine) {
         self.code.swap_to_temporary_refs(code, vm);
+        #[cfg(feature = "jit")]
+        {
+            *self.jitted_code.lock() = None;
+        }
         self.func_version.store(0, Relaxed);
     }
 
@@ -948,15 +950,15 @@ impl PyFunction {
     #[cfg(feature = "jit")]
     #[pymethod]
     fn __jit__(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult<()> {
-        if zelf.jitted_code.get().is_some() {
+        if zelf.jitted_code.lock().is_some() {
             return Ok(());
         }
         let arg_types = jit::get_jit_arg_types(&zelf, vm)?;
         let ret_type = jit::jit_ret_type(&zelf, vm)?;
         let code: &Py<PyCode> = &zelf.code;
         let compiled = rustpython_jit::compile(&code.code, &arg_types, ret_type)
            .map_err(|err| jit::new_jit_error(err.to_string(), vm))?;
-        let _ = zelf.jitted_code.set(compiled);
+        *zelf.jitted_code.lock() = Some(compiled);
         Ok(())
     }
 }
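The reason for swapping `OnceCell` for a mutex-guarded `Option` is that a write-once cell cannot be cleared, but `__code__` reassignment must be able to drop stale JIT output. A minimal sketch of the pattern using the standard library's `Mutex` (the `Compiled`, `Func`, and version fields are invented for illustration):

```rust
use std::sync::Mutex;

// Stand-in for compiled machine code; tagged with the code version it
// was built from, purely for demonstration.
struct Compiled(u32);

struct Func {
    code_version: u32,
    jitted: Mutex<Option<Compiled>>,
}

impl Func {
    // Compile lazily. A later set_code must be able to clear this cache,
    // which a write-once cell cannot express -- hence Mutex<Option<_>>.
    fn jit(&self) {
        let mut guard = self.jitted.lock().unwrap();
        if guard.is_none() {
            *guard = Some(Compiled(self.code_version));
        }
    }

    // Reassigning __code__ invalidates the cached JIT output so stale
    // machine code is never executed against the new code object.
    fn set_code(&mut self, version: u32) {
        self.code_version = version;
        *self.jitted.lock().unwrap() = None;
    }

    fn jitted_version(&self) -> Option<u32> {
        self.jitted.lock().unwrap().as_ref().map(|c| c.0)
    }
}
```

RustPython uses its own `PyMutex` rather than `std::sync::Mutex`, but the shape of the API is the same.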

crates/vm/src/builtins/type.rs

Lines changed: 6 additions & 3 deletions
@@ -347,7 +347,7 @@ impl PyType {
         if old_version == 0 {
             return;
         }
-        self.tp_version_tag.store(0, Ordering::Release);
+        self.tp_version_tag.store(0, Ordering::SeqCst);
         // Release strong references held by cache entries for this version.
         // We hold owned refs that would prevent GC of class attributes after
         // type deletion.
@@ -2168,6 +2168,11 @@ impl SetAttr for PyType {
         }
         let assign = value.is_assign();
 
+        // Invalidate inline caches before modifying attributes.
+        // This ensures other threads see the version invalidation before
+        // any attribute changes, preventing use-after-free of cached descriptors.
+        zelf.modified();
+
         if let PySetterValue::Assign(value) = value {
             zelf.attributes.write().insert(attr_name, value);
         } else {
@@ -2180,8 +2185,6 @@ impl SetAttr for PyType {
                 )));
             }
         }
-        // Invalidate inline caches that depend on this type's attributes
-        zelf.modified();
 
         if attr_name.as_wtf8().starts_with("__") && attr_name.as_wtf8().ends_with("__") {
             if assign {
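The key ordering rule in this hunk is invalidate-before-mutate: the version tag is zeroed *before* the attribute table changes, so any reader that validates the tag after the store cannot trust a cache entry built against the old attributes. A small single-threaded sketch of that discipline (the `TypeAttrs` struct and its fields are hypothetical):

```rust
use std::collections::HashMap;
use std::sync::Mutex;
use std::sync::atomic::{AtomicU32, Ordering};

// Sketch of the invalidate-before-mutate rule from type.__setattr__.
struct TypeAttrs {
    // 0 means "no valid version": cached lookups keyed on the old tag
    // must be discarded.
    version: AtomicU32,
    attrs: Mutex<HashMap<String, i64>>,
}

impl TypeAttrs {
    fn set_attr(&self, name: &str, value: i64) {
        // Step 1: invalidate caches first, with SeqCst so the store is
        // globally ordered before the mutation below.
        self.version.store(0, Ordering::SeqCst);
        // Step 2: only then mutate the attribute table.
        self.attrs.lock().unwrap().insert(name.to_string(), value);
    }
}
```

Doing these two steps in the opposite order is exactly the use-after-free window the commit closes: a racing reader could validate the stale version tag while holding a descriptor pointer that the mutation just replaced.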

crates/vm/src/dict_inner.rs

Lines changed: 6 additions & 3 deletions
@@ -19,7 +19,10 @@ use crate::{
 use alloc::fmt;
 use core::mem::size_of;
 use core::ops::ControlFlow;
-use core::sync::atomic::{AtomicU64, Ordering::Relaxed};
+use core::sync::atomic::{
+    AtomicU64,
+    Ordering::{Acquire, Release},
+};
 use num_traits::ToPrimitive;
 
 // HashIndex is intended to be same size with hash::PyHash
@@ -261,12 +264,12 @@ type PopInnerResult<T> = ControlFlow<Option<DictEntry<T>>>;
 impl<T: Clone> Dict<T> {
     /// Monotonically increasing version counter for mutation tracking.
     pub fn version(&self) -> u64 {
-        self.version.load(Relaxed)
+        self.version.load(Acquire)
     }
 
     /// Bump the version counter after any mutation.
     fn bump_version(&self) {
-        self.version.fetch_add(1, Relaxed);
+        self.version.fetch_add(1, Release);
     }
 
     fn read(&self) -> PyRwLockReadGuard<'_, DictInner<T>> {
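The upgrade from Relaxed to Release/Acquire on the dict version counter gives specialized opcodes like LOAD_GLOBAL_BUILTIN a sound validation protocol: a bump with Release after a mutation, an Acquire load during cache validation. A minimal stand-alone sketch (the `VersionedDict` name is illustrative):

```rust
use std::sync::atomic::{AtomicU64, Ordering};

// Sketch of the versioned-dict protocol: mutations bump the counter with
// Release, and cache validation loads it with Acquire, so observing an
// unchanged version implies the dict contents the cache was built from
// are still the visible ones.
struct VersionedDict {
    version: AtomicU64,
}

impl VersionedDict {
    fn version(&self) -> u64 {
        self.version.load(Ordering::Acquire)
    }

    fn bump_version(&self) -> u64 {
        // fetch_add returns the value *before* the increment.
        self.version.fetch_add(1, Ordering::Release)
    }
}
```

A cache entry stores the version it observed; on the fast path the interpreter reloads the counter and deoptimizes if it no longer matches.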
