Skip to content

Commit 29c0cf2

Browse files
committed
Also check that globals is an exact dict for the LOAD_GLOBAL fast path
get_chain_exact bypasses __missing__ on dict subclasses, so it is only correct when both dicts are exact dict instances. Move get_chain_exact to the PyExact<PyDict> impl with a debug_assert, and have get_chain delegate to it. Store builtins_dict as Option<&PyExact<PyDict>> to enforce the exact type at compile time. Use PyRangeIterator::next_fast() instead of exposing pub(crate) fields. Fix comment style issues.
1 parent 5b9e078 commit 29c0cf2

File tree

4 files changed

+59
-31
lines changed

4 files changed

+59
-31
lines changed

crates/vm/src/builtins/dict.rs

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use super::{
55
use crate::common::lock::LazyLock;
66
use crate::object::{Traverse, TraverseFn};
77
use crate::{
8-
AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyRefExact, PyResult,
8+
AsObject, Context, Py, PyExact, PyObject, PyObjectRef, PyPayload, PyRef, PyRefExact, PyResult,
99
TryFromObject, atomic_func,
1010
builtins::{
1111
PyTuple,
@@ -681,22 +681,29 @@ impl Py<PyDict> {
681681
let self_exact = self.exact_dict(vm);
682682
let other_exact = other.exact_dict(vm);
683683
if self_exact && other_exact {
684-
self.entries.get_chain(&other.entries, vm, key)
684+
// SAFETY: exact_dict checks passed
685+
let self_exact = unsafe { PyExact::ref_unchecked(self) };
686+
let other_exact = unsafe { PyExact::ref_unchecked(other) };
687+
self_exact.get_chain_exact(other_exact, key, vm)
685688
} else if let Some(value) = self.get_item_opt(key, vm)? {
686689
Ok(Some(value))
687690
} else {
688691
other.get_item_opt(key, vm)
689692
}
690693
}
694+
}
691695

692-
/// Like `get_chain` but skips the exact_dict type checks. Use when both
693-
/// dicts are known to be exact dict types (e.g. globals + builtins).
694-
pub fn get_chain_exact<K: DictKey + ?Sized>(
696+
impl PyExact<PyDict> {
697+
/// Look up `key` in `self`, falling back to `other`.
698+
/// Both dicts must be exact `dict` types (enforced by `PyExact`).
699+
pub(crate) fn get_chain_exact<K: DictKey + ?Sized>(
695700
&self,
696701
other: &Self,
697702
key: &K,
698703
vm: &VirtualMachine,
699704
) -> PyResult<Option<PyObjectRef>> {
705+
debug_assert!(self.class().is(vm.ctx.types.dict_type));
706+
debug_assert!(other.class().is(vm.ctx.types.dict_type));
700707
self.entries.get_chain(&other.entries, vm, key)
701708
}
702709
}

crates/vm/src/builtins/range.rs

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -607,10 +607,23 @@ impl IterNext for PyLongRangeIterator {
607607
#[pyclass(module = false, name = "range_iterator")]
608608
#[derive(Debug)]
609609
pub struct PyRangeIterator {
610-
pub(crate) index: AtomicCell<usize>,
611-
pub(crate) start: isize,
612-
pub(crate) step: isize,
613-
pub(crate) length: usize,
610+
index: AtomicCell<usize>,
611+
start: isize,
612+
step: isize,
613+
length: usize,
614+
}
615+
616+
impl PyRangeIterator {
617+
/// Advance and return next value without going through the iterator protocol.
618+
#[inline]
619+
pub(crate) fn next_fast(&self) -> Option<isize> {
620+
let index = self.index.fetch_add(1);
621+
if index < self.length {
622+
Some(self.start + (index as isize) * self.step)
623+
} else {
624+
None
625+
}
626+
}
614627
}
615628

616629
impl PyPayload for PyRangeIterator {

crates/vm/src/frame.rs

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#[cfg(feature = "flame")]
22
use crate::bytecode::InstructionMetadata;
33
use crate::{
4-
AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, PyStackRef, TryFromObject,
5-
VirtualMachine,
4+
AsObject, Py, PyExact, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, PyStackRef,
5+
TryFromObject, VirtualMachine,
66
builtins::{
77
PyBaseException, PyBaseExceptionRef, PyCode, PyCoroutine, PyDict, PyDictRef, PyGenerator,
88
PyInterpolation, PyList, PySet, PySlice, PyStr, PyStrInterned, PyTemplate, PyTraceback,
@@ -447,7 +447,14 @@ impl Py<Frame> {
447447
locals: &self.locals,
448448
globals: &self.globals,
449449
builtins: &self.builtins,
450-
builtins_dict: self.builtins.downcast_ref_if_exact::<PyDict>(vm),
450+
builtins_dict: if self.globals.class().is(vm.ctx.types.dict_type) {
451+
self.builtins
452+
.downcast_ref_if_exact::<PyDict>(vm)
453+
// SAFETY: downcast_ref_if_exact already verified exact type
454+
.map(|d| unsafe { PyExact::ref_unchecked(d) })
455+
} else {
456+
None
457+
},
451458
lasti: &self.lasti,
452459
object: self,
453460
state: &mut state,
@@ -530,9 +537,11 @@ struct ExecutingFrame<'a> {
530537
locals: &'a ArgMapping,
531538
globals: &'a PyDictRef,
532539
builtins: &'a PyObjectRef,
533-
/// Cached downcast of builtins to PyDict. builtins never changes during
534-
/// frame execution, so we avoid repeating the downcast on every LOAD_GLOBAL.
535-
builtins_dict: Option<&'a Py<PyDict>>,
540+
/// Cached downcast of builtins to PyDict for fast LOAD_GLOBAL.
541+
/// Only set when both globals and builtins are exact dict types (not
542+
/// subclasses), so that `__missing__` / `__getitem__` overrides are
543+
/// not bypassed.
544+
builtins_dict: Option<&'a PyExact<PyDict>>,
536545
object: &'a Py<Frame>,
537546
lasti: &'a PyAtomic<u32>,
538547
state: &'a mut FrameState,
@@ -3028,8 +3037,10 @@ impl ExecutingFrame<'_> {
30283037
#[inline]
30293038
fn load_global_or_builtin(&self, name: &Py<PyStr>, vm: &VirtualMachine) -> PyResult {
30303039
if let Some(builtins_dict) = self.builtins_dict {
3031-
// Fast path: both globals (PyDictRef) and builtins are exact dicts
3032-
self.globals
3040+
// Fast path: both globals and builtins are exact dicts
3041+
// SAFETY: builtins_dict is only set when globals is also exact dict
3042+
let globals_exact = unsafe { PyExact::ref_unchecked(self.globals.as_ref()) };
3043+
globals_exact
30333044
.get_chain_exact(builtins_dict, name, vm)?
30343045
.ok_or_else(|| {
30353046
vm.new_name_error(format!("name '{name}' is not defined"), name.to_owned())
@@ -3720,13 +3731,10 @@ impl ExecutingFrame<'_> {
37203731

37213732
// FOR_ITER_RANGE: bypass generic iterator protocol for range iterators
37223733
if let Some(range_iter) = top.downcast_ref_if_exact::<PyRangeIterator>(vm) {
3723-
let index = range_iter.index.fetch_add(1);
3724-
if index < range_iter.length {
3725-
let value = range_iter.start + (index as isize) * range_iter.step;
3734+
if let Some(value) = range_iter.next_fast() {
37263735
self.push_value(vm.ctx.new_int(value).into());
37273736
return Ok(true);
37283737
}
3729-
// Exhausted
37303738
if vm.use_tracing.get() && !vm.is_none(&self.object.trace.lock()) {
37313739
let stop_exc = vm.new_stop_iteration(None);
37323740
self.fire_exception_trace(&stop_exc, vm)?;
@@ -4085,17 +4093,17 @@ impl ExecutingFrame<'_> {
40854093
self.push_value(vm.ctx.new_bool(result).into());
40864094
return Ok(None);
40874095
}
4088-
// COMPARE_OP_FLOAT
4096+
// COMPARE_OP_FLOAT: leaf type, cannot recurse — skip rich_compare dispatch.
4097+
// Falls through on NaN (partial_cmp returns None) for correct != semantics.
40894098
if let (Some(a_f), Some(b_f)) = (
40904099
a.downcast_ref_if_exact::<PyFloat>(vm),
40914100
b.downcast_ref_if_exact::<PyFloat>(vm),
40924101
) {
4093-
let result = a_f
4094-
.to_f64()
4095-
.partial_cmp(&b_f.to_f64())
4096-
.is_some_and(|ord| cmp_op.eval_ord(ord));
4097-
self.push_value(vm.ctx.new_bool(result).into());
4098-
return Ok(None);
4102+
if let Some(ord) = a_f.to_f64().partial_cmp(&b_f.to_f64()) {
4103+
let result = cmp_op.eval_ord(ord);
4104+
self.push_value(vm.ctx.new_bool(result).into());
4105+
return Ok(None);
4106+
}
40994107
}
41004108

41014109
let value = a.rich_compare(b, cmp_op, vm)?;

crates/vm/src/protocol/object.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -277,16 +277,16 @@ impl PyObject {
277277

278278
// Perform a comparison, raising TypeError when the requested comparison
279279
// operator is not supported.
280-
// see: CPython PyObject_RichCompare / do_richcompare
280+
// see: PyObject_RichCompare / do_richcompare
281281
#[inline] // called by ExecutingFrame::execute_compare with const op
282282
fn _cmp(
283283
&self,
284284
other: &Self,
285285
op: PyComparisonOp,
286286
vm: &VirtualMachine,
287287
) -> PyResult<Either<PyObjectRef, bool>> {
288-
// Single recursion guard for the entire comparison (matching CPython's
289-
// Py_EnterRecursiveCallTstate placement in do_richcompare).
288+
// Single recursion guard for the entire comparison
289+
// (do_richcompare in Objects/object.c).
290290
vm.with_recursion("in comparison", || self._cmp_inner(other, op, vm))
291291
}
292292

0 commit comments

Comments
 (0)