Skip to content
Merged
Changes from 1 commit
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
vm: use exact-type guards for call specializations
  • Loading branch information
youknowone committed Mar 5, 2026
commit 78dbbb0152f39139d75e4a74392ac943b15aee45
81 changes: 44 additions & 37 deletions crates/vm/src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4023,14 +4023,14 @@ impl ExecutingFrame<'_> {
let nargs: u32 = arg.into();
// Stack: [callable, self_or_null, arg1, ..., argN]
let callable = self.nth_value(nargs + 1);
if let Some(func) = callable.downcast_ref::<PyFunction>()
if let Some(func) = callable.downcast_ref_if_exact::<PyFunction>(vm)
&& func.func_version() == cached_version
&& cached_version != 0
{
let pos_args: Vec<PyObjectRef> = self.pop_multiple(nargs as usize).collect();
let self_or_null = self.pop_value_opt();
let callable = self.pop_value();
let func = callable.downcast_ref::<PyFunction>().unwrap();
let func = callable.downcast_ref_if_exact::<PyFunction>(vm).unwrap();
let args = if let Some(self_val) = self_or_null {
let mut all_args = Vec::with_capacity(pos_args.len() + 1);
all_args.push(self_val);
Expand Down Expand Up @@ -4063,11 +4063,11 @@ impl ExecutingFrame<'_> {
.is_some();
let callable = self.nth_value(nargs + 1);
if !self_or_null_is_some
&& let Some(bound_method) = callable.downcast_ref::<PyBoundMethod>()
&& let Some(bound_method) = callable.downcast_ref_if_exact::<PyBoundMethod>(vm)
{
let bound_function = bound_method.function_obj().clone();
let bound_self = bound_method.self_obj().clone();
if let Some(func) = bound_function.downcast_ref::<PyFunction>()
if let Some(func) = bound_function.downcast_ref_if_exact::<PyFunction>(vm)
&& func.func_version() == cached_version
&& cached_version != 0
{
Expand Down Expand Up @@ -4308,7 +4308,7 @@ impl ExecutingFrame<'_> {
let cached_version = self.code.instructions.read_cache_u32(cache_base + 1);
let nargs: u32 = arg.into();
let callable = self.nth_value(nargs + 1);
if let Some(func) = callable.downcast_ref::<PyFunction>()
if let Some(func) = callable.downcast_ref_if_exact::<PyFunction>(vm)
&& func.func_version() == cached_version
&& cached_version != 0
{
Expand Down Expand Up @@ -4348,11 +4348,11 @@ impl ExecutingFrame<'_> {
.is_some();
let callable = self.nth_value(nargs + 1);
if !self_or_null_is_some
&& let Some(bound_method) = callable.downcast_ref::<PyBoundMethod>()
&& let Some(bound_method) = callable.downcast_ref_if_exact::<PyBoundMethod>(vm)
{
let bound_function = bound_method.function_obj().clone();
let bound_self = bound_method.self_obj().clone();
if let Some(func) = bound_function.downcast_ref::<PyFunction>()
if let Some(func) = bound_function.downcast_ref_if_exact::<PyFunction>(vm)
&& func.func_version() == cached_version
&& cached_version != 0
{
Expand Down Expand Up @@ -4392,14 +4392,13 @@ impl ExecutingFrame<'_> {
.stack_index(stack_len - 2)
.as_ref()
.is_some_and(|obj| obj.downcast_ref::<PyList>().is_some());
let is_list_append =
callable
.downcast_ref::<PyMethodDescriptor>()
.is_some_and(|descr| {
descr.method.flags.contains(PyMethodFlags::METHOD)
&& descr.method.name == "append"
&& descr.objclass.is(vm.ctx.types.list_type)
});
let is_list_append = callable
.downcast_ref_if_exact::<PyMethodDescriptor>(vm)
.is_some_and(|descr| {
descr.method.flags.contains(PyMethodFlags::METHOD)
&& descr.method.name == "append"
&& descr.objclass.is(vm.ctx.types.list_type)
});
if is_list_append && self_or_null_is_some && self_is_list {
let item = self.pop_value();
let self_or_null = self.pop_value_opt();
Expand Down Expand Up @@ -4434,7 +4433,7 @@ impl ExecutingFrame<'_> {
let self_or_null_is_some = self.localsplus.stack_index(stack_len - 1).is_some();
let callable = self.nth_value(1);
let descr = if self_or_null_is_some {
callable.downcast_ref::<PyMethodDescriptor>()
callable.downcast_ref_if_exact::<PyMethodDescriptor>(vm)
} else {
None
};
Expand Down Expand Up @@ -4471,7 +4470,7 @@ impl ExecutingFrame<'_> {
let self_or_null_is_some = self.localsplus.stack_index(stack_len - 2).is_some();
let callable = self.nth_value(2);
let descr = if self_or_null_is_some {
callable.downcast_ref::<PyMethodDescriptor>()
callable.downcast_ref_if_exact::<PyMethodDescriptor>(vm)
} else {
None
};
Expand Down Expand Up @@ -4510,7 +4509,7 @@ impl ExecutingFrame<'_> {
.stack_index(stack_len - nargs as usize - 1)
.is_some();
let descr = if self_or_null_is_some {
callable.downcast_ref::<PyMethodDescriptor>()
callable.downcast_ref_if_exact::<PyMethodDescriptor>(vm)
} else {
None
};
Expand Down Expand Up @@ -4591,7 +4590,7 @@ impl ExecutingFrame<'_> {
{
// Look up __init__ (guarded by type_version)
if let Some(init) = cls.get_attr(identifier!(vm, __init__))
&& let Some(init_func) = init.downcast_ref::<PyFunction>()
&& let Some(init_func) = init.downcast_ref_if_exact::<PyFunction>(vm)
&& init_func.can_specialize_call(nargs + 1)
{
// Allocate object directly (tp_new == object.__new__)
Expand Down Expand Up @@ -4646,7 +4645,7 @@ impl ExecutingFrame<'_> {
.stack_index(stack_len - nargs as usize - 1)
.is_some();
let descr = if self_or_null_is_some {
callable.downcast_ref::<PyMethodDescriptor>()
callable.downcast_ref_if_exact::<PyMethodDescriptor>(vm)
} else {
None
};
Expand Down Expand Up @@ -4720,8 +4719,10 @@ impl ExecutingFrame<'_> {
.stack_index(stack_len - nargs as usize - 1)
.is_some();
let callable = self.nth_value(nargs + 1);
if callable.downcast_ref::<PyFunction>().is_some()
|| callable.downcast_ref::<PyBoundMethod>().is_some()
if callable.downcast_ref_if_exact::<PyFunction>(vm).is_some()
|| callable
.downcast_ref_if_exact::<PyBoundMethod>(vm)
.is_some()
{
self.deoptimize(Instruction::Call {
argc: Arg::marker(),
Expand Down Expand Up @@ -4755,7 +4756,7 @@ impl ExecutingFrame<'_> {
let nargs: u32 = arg.into();
// Stack: [callable, self_or_null, arg1, ..., argN, kwarg_names]
let callable = self.nth_value(nargs + 2);
if let Some(func) = callable.downcast_ref::<PyFunction>()
if let Some(func) = callable.downcast_ref_if_exact::<PyFunction>(vm)
&& func.func_version() == cached_version
&& cached_version != 0
{
Expand Down Expand Up @@ -4807,11 +4808,11 @@ impl ExecutingFrame<'_> {
.is_some();
let callable = self.nth_value(nargs + 2);
if !self_or_null_is_some
&& let Some(bound_method) = callable.downcast_ref::<PyBoundMethod>()
&& let Some(bound_method) = callable.downcast_ref_if_exact::<PyBoundMethod>(vm)
{
let bound_function = bound_method.function_obj().clone();
let bound_self = bound_method.self_obj().clone();
if let Some(func) = bound_function.downcast_ref::<PyFunction>()
if let Some(func) = bound_function.downcast_ref_if_exact::<PyFunction>(vm)
&& func.func_version() == cached_version
&& cached_version != 0
{
Expand Down Expand Up @@ -4854,8 +4855,10 @@ impl ExecutingFrame<'_> {
.stack_index(stack_len - nargs as usize - 2)
.is_some();
let callable = self.nth_value(nargs + 2);
if callable.downcast_ref::<PyFunction>().is_some()
|| callable.downcast_ref::<PyBoundMethod>().is_some()
if callable.downcast_ref_if_exact::<PyFunction>(vm).is_some()
|| callable
.downcast_ref_if_exact::<PyBoundMethod>(vm)
.is_some()
{
self.deoptimize(Instruction::CallKw {
argc: Arg::marker(),
Expand Down Expand Up @@ -6425,7 +6428,7 @@ impl ExecutingFrame<'_> {
args
};

let is_python_call = callable.downcast_ref::<PyFunction>().is_some();
let is_python_call = callable.downcast_ref_if_exact::<PyFunction>(vm).is_some();

// Fire CALL event
let call_arg0 = if self.monitoring_mask & monitoring::EVENT_CALL != 0 {
Expand Down Expand Up @@ -6728,7 +6731,7 @@ impl ExecutingFrame<'_> {
let func = self.top_value();
// Get the function reference and call the new method
let func_ref = func
.downcast_ref::<PyFunction>()
.downcast_ref_if_exact::<PyFunction>(vm)
.expect("SET_FUNCTION_ATTRIBUTE expects function on stack");

let payload: &PyFunction = func_ref.payload();
Expand Down Expand Up @@ -7706,7 +7709,7 @@ impl ExecutingFrame<'_> {
.is_some();
let callable = self.nth_value(nargs + 1);

if let Some(func) = callable.downcast_ref::<PyFunction>() {
if let Some(func) = callable.downcast_ref_if_exact::<PyFunction>(vm) {
if self.specialization_eval_frame_active(vm) {
unsafe {
self.code.instructions.write_adaptive_counter(
Expand Down Expand Up @@ -7753,8 +7756,10 @@ impl ExecutingFrame<'_> {

// Bound Python method object (`method`) specialization.
if !self_or_null_is_some
&& let Some(bound_method) = callable.downcast_ref::<PyBoundMethod>()
&& let Some(func) = bound_method.function_obj().downcast_ref::<PyFunction>()
&& let Some(bound_method) = callable.downcast_ref_if_exact::<PyBoundMethod>(vm)
&& let Some(func) = bound_method
.function_obj()
.downcast_ref_if_exact::<PyFunction>(vm)
{
if self.specialization_eval_frame_active(vm) {
unsafe {
Expand Down Expand Up @@ -7796,7 +7801,7 @@ impl ExecutingFrame<'_> {

// Try to specialize method descriptor calls
if self_or_null_is_some
&& let Some(descr) = callable.downcast_ref::<PyMethodDescriptor>()
&& let Some(descr) = callable.downcast_ref_if_exact::<PyMethodDescriptor>(vm)
&& descr.method.flags.contains(PyMethodFlags::METHOD)
{
let call_cache_entries = Instruction::CallListAppend.cache_entries();
Expand Down Expand Up @@ -7892,7 +7897,7 @@ impl ExecutingFrame<'_> {
if let (Some(cls_new_fn), Some(obj_new_fn)) = (cls_new, object_new)
&& cls_new_fn as usize == obj_new_fn as usize
&& let Some(init) = cls.get_attr(identifier!(vm, __init__))
&& let Some(init_func) = init.downcast_ref::<PyFunction>()
&& let Some(init_func) = init.downcast_ref_if_exact::<PyFunction>(vm)
&& init_func.can_specialize_call(nargs + 1)
{
let version = cls.tp_version_tag.load(Acquire);
Expand Down Expand Up @@ -7941,7 +7946,7 @@ impl ExecutingFrame<'_> {
.is_some();
let callable = self.nth_value(nargs + 2);

if let Some(func) = callable.downcast_ref::<PyFunction>() {
if let Some(func) = callable.downcast_ref_if_exact::<PyFunction>(vm) {
if self.specialization_eval_frame_active(vm) {
unsafe {
self.code.instructions.write_adaptive_counter(
Expand Down Expand Up @@ -7976,8 +7981,10 @@ impl ExecutingFrame<'_> {
}

if !self_or_null_is_some
&& let Some(bound_method) = callable.downcast_ref::<PyBoundMethod>()
&& let Some(func) = bound_method.function_obj().downcast_ref::<PyFunction>()
&& let Some(bound_method) = callable.downcast_ref_if_exact::<PyBoundMethod>(vm)
&& let Some(func) = bound_method
.function_obj()
.downcast_ref_if_exact::<PyFunction>(vm)
{
if self.specialization_eval_frame_active(vm) {
unsafe {
Expand Down