diff --git a/crates/codegen/src/compile.rs b/crates/codegen/src/compile.rs index b961be42517..438078d3b49 100644 --- a/crates/codegen/src/compile.rs +++ b/crates/codegen/src/compile.rs @@ -1015,16 +1015,31 @@ impl Compiler { // Use varnames from symbol table (already collected in definition order) let varname_cache: IndexSet = ste.varnames.iter().cloned().collect(); - // Build cellvars using dictbytype (CELL scope, sorted) + // Build cellvars in localsplus order: + // 1. cell+local vars in varnames definition order + // 2. cell-only vars sorted alphabetically let mut cellvar_cache = IndexSet::default(); - let mut cell_names: Vec<_> = ste + let cell_scope_names: IndexSet = ste .symbols .iter() .filter(|(_, s)| s.scope == SymbolScope::Cell) .map(|(name, _)| name.clone()) .collect(); - cell_names.sort(); - for name in cell_names { + + // First: cell vars that are also in varnames (in varnames order) + for var in varname_cache.iter() { + if cell_scope_names.contains(var) { + cellvar_cache.insert(var.clone()); + } + } + // Second: cell-only vars (not in varnames, sorted for determinism) + let mut cell_only: Vec<_> = cell_scope_names + .iter() + .filter(|name| !varname_cache.contains(name.as_str())) + .cloned() + .collect(); + cell_only.sort(); + for name in cell_only { cellvar_cache.insert(name); } @@ -4168,11 +4183,11 @@ impl Compiler { flags: bytecode::MakeFunctionFlags, ) -> CompileResult<()> { // Handle free variables (closure) - let has_freevars = !code.freevars.is_empty(); + let has_freevars = !code.freevars().is_empty(); if has_freevars { // Build closure tuple by loading free variables - for var in &code.freevars { + for var in code.freevars() { // Special case: If a class contains a method with a // free variable that has the same name as a method, // the name will be considered free *and* local in the @@ -4229,7 +4244,7 @@ impl Compiler { emit!( self, Instruction::BuildTuple { - size: code.freevars.len().to_u32(), + size: code.freevars().len().to_u32(), } ); } diff --git a/crates/codegen/src/ir.rs b/crates/codegen/src/ir.rs index 59b5f7d4d09..13a15f7bf8d 100644 --- a/crates/codegen/src/ir.rs +++ b/crates/codegen/src/ir.rs @@ -8,8 +8,8 @@ use rustpython_compiler_core::{ OneIndexed, SourceLocation, bytecode::{ AnyInstruction, Arg, CodeFlags, CodeObject, CodeUnit, CodeUnits, ConstantData, - ExceptionTableEntry, InstrDisplayContext, Instruction, InstructionMetadata, Label, OpArg, - PseudoInstruction, PyCodeLocationInfoKind, encode_exception_table, + ExceptionTableEntry, InstrDisplayContext, Instruction, InstructionMetadata, Label, + LocalKind, OpArg, PseudoInstruction, PyCodeLocationInfoKind, encode_exception_table, }, varint::{write_signed_varint, write_varint}, }; @@ -193,7 +193,6 @@ impl CodeInfo { self.optimize_load_fast_borrow(); let max_stackdepth = self.max_stackdepth()?; - let cell2arg = self.cell2arg(); let Self { flags, @@ -219,13 +218,44 @@ impl CodeInfo { varnames: varname_cache, cellvars: cellvar_cache, freevars: freevar_cache, - fast_hidden: _, + fast_hidden, argcount: arg_count, posonlyargcount: posonlyarg_count, kwonlyargcount: kwonlyarg_count, firstlineno: first_line_number, } = metadata; + // Build localsplusnames and localspluskinds + let mut localsplusnames_vec: Vec = Vec::new(); + let mut localspluskinds_vec: Vec = Vec::new(); + + // 1. For each var in varnames + for var in varname_cache.iter() { + let mut kind = LocalKind::LOCAL; + if cellvar_cache.contains(var) { + kind |= LocalKind::CELL; + } + if fast_hidden.get(var).copied().unwrap_or(false) { + kind |= LocalKind::HIDDEN; + } + localsplusnames_vec.push(var.clone()); + localspluskinds_vec.push(kind); + } + + // 2. For each var in cellvars that is NOT in varnames + for var in cellvar_cache.iter() { + if !varname_cache.contains(var) { + localsplusnames_vec.push(var.clone()); + localspluskinds_vec.push(LocalKind::CELL); + } + } + + // 3. For each var in freevars + for var in freevar_cache.iter() { + localsplusnames_vec.push(var.clone()); + localspluskinds_vec.push(LocalKind::FREE); + } + let mut instructions = Vec::new(); let mut locations = Vec::new(); let mut linetable_locations: Vec = Vec::new(); @@ -389,46 +419,16 @@ impl CodeInfo { locations: locations.into_boxed_slice(), constants: constants.into_iter().collect(), names: name_cache.into_iter().collect(), - varnames: varname_cache.into_iter().collect(), - cellvars: cellvar_cache.into_iter().collect(), - freevars: freevar_cache.into_iter().collect(), - cell2arg, + nlocals: varname_cache.len() as u32, + ncellvars: cellvar_cache.len() as u32, + nfreevars: freevar_cache.len() as u32, + localsplusnames: localsplusnames_vec.into_iter().collect(), + localspluskinds: localspluskinds_vec.into_boxed_slice(), linetable, exceptiontable, }) } - fn cell2arg(&self) -> Option> { - if self.metadata.cellvars.is_empty() { - return None; - } - - let total_args = self.metadata.argcount - + self.metadata.kwonlyargcount - + self.flags.contains(CodeFlags::VARARGS) as u32 - + self.flags.contains(CodeFlags::VARKEYWORDS) as u32; - - let mut found_cellarg = false; - let cell2arg = self - .metadata - .cellvars - .iter() - .map(|var| { - self.metadata - .varnames - .get_index_of(var) - // check that it's actually an arg - .filter(|i| *i < total_args as usize) - .map_or(-1, |i| { - found_cellarg = true; - i as i32 - }) - }) - .collect::>(); - - if found_cellarg { Some(cell2arg) } else { None } - } - fn dce(&mut self) { for block in &mut self.blocks { let mut last_instr = None; diff --git a/crates/compiler-core/src/bytecode.rs b/crates/compiler-core/src/bytecode.rs index d177b728efa..6885759aefb 100644 --- a/crates/compiler-core/src/bytecode.rs +++ b/crates/compiler-core/src/bytecode.rs @@ -276,12 +276,13 @@ pub struct CodeObject { pub obj_name: C::Name, /// Qualified name of the object (like CPython's co_qualname) pub qualname: C::Name, - pub cell2arg: Option>, + pub localsplusnames: Box<[C::Name]>, + pub localspluskinds: Box<[LocalKind]>, + pub nlocals: u32, + pub ncellvars: u32, + pub nfreevars: u32, pub constants: Box<[C]>, pub names: Box<[C::Name]>, - pub varnames: Box<[C::Name]>, - pub cellvars: Box<[C::Name]>, - pub freevars: Box<[C::Name]>, /// Line number table (CPython 3.11+ format) pub linetable: Box<[u8]>, /// Exception handling table @@ -304,6 +305,16 @@ bitflags! { } } +bitflags! { + #[derive(Copy, Clone, Debug, PartialEq)] + pub struct LocalKind: u8 { + const LOCAL = 0x20; + const CELL = 0x40; + const FREE = 0x80; + const HIDDEN = 0x10; + } +} + #[derive(Copy, Clone)] #[repr(C)] pub struct CodeUnit { @@ -561,25 +572,42 @@ impl> fmt::Debug for Arguments<'_, N> { } impl CodeObject { + /// Returns varnames (first nlocals entries of localsplusnames) + pub fn varnames(&self) -> &[C::Name] { + &self.localsplusnames[..self.nlocals as usize] + } + + /// Returns freevars (last nfreevars entries of localsplusnames) + pub fn freevars(&self) -> &[C::Name] { + let start = self.localsplusnames.len() - self.nfreevars as usize; + &self.localsplusnames[start..] + } + + /// Total localsplus count + pub fn nlocalsplus(&self) -> usize { + self.localsplusnames.len() + } + /// Get all arguments of the code object /// like inspect.getargs pub fn arg_names(&self) -> Arguments<'_, C::Name> { + let varnames = self.varnames(); let nargs = self.arg_count as usize; let nkwargs = self.kwonlyarg_count as usize; let mut varargs_pos = nargs + nkwargs; - let posonlyargs = &self.varnames[..self.posonlyarg_count as usize]; - let args = &self.varnames[..nargs]; - let kwonlyargs = &self.varnames[nargs..varargs_pos]; + let posonlyargs = &varnames[..self.posonlyarg_count as usize]; + let args = &varnames[..nargs]; + let kwonlyargs = &varnames[nargs..varargs_pos]; let vararg = if self.flags.contains(CodeFlags::VARARGS) { - let vararg = &self.varnames[varargs_pos]; + let vararg = &varnames[varargs_pos]; varargs_pos += 1; Some(vararg) } else { None }; let varkwarg = if self.flags.contains(CodeFlags::VARKEYWORDS) { - Some(&self.varnames[varargs_pos]) + Some(&varnames[varargs_pos]) } else { None }; @@ -682,9 +710,11 @@ impl CodeObject { .map(|x| bag.make_constant(x.borrow_constant())) .collect(), names: map_names(self.names), - varnames: map_names(self.varnames), - cellvars: map_names(self.cellvars), - freevars: map_names(self.freevars), + localsplusnames: map_names(self.localsplusnames), + localspluskinds: self.localspluskinds, + nlocals: self.nlocals, + ncellvars: self.ncellvars, + nfreevars: self.nfreevars, source_path: bag.make_name(self.source_path.as_ref()), obj_name: bag.make_name(self.obj_name.as_ref()), qualname: bag.make_name(self.qualname.as_ref()), @@ -697,7 +727,6 @@ impl CodeObject { kwonlyarg_count: self.kwonlyarg_count, first_line_number: self.first_line_number, max_stackdepth: self.max_stackdepth, - cell2arg: self.cell2arg, linetable: self.linetable, exceptiontable: self.exceptiontable, } @@ -714,9 +743,11 @@ impl CodeObject { .map(|x| bag.make_constant(x.borrow_constant())) .collect(), names: map_names(&self.names), - varnames: map_names(&self.varnames), - cellvars: map_names(&self.cellvars), - freevars: map_names(&self.freevars), + localsplusnames: map_names(&self.localsplusnames), + localspluskinds: self.localspluskinds.clone(), + nlocals: self.nlocals, + ncellvars: self.ncellvars, + nfreevars: self.nfreevars, source_path: bag.make_name(self.source_path.as_ref()), obj_name: bag.make_name(self.obj_name.as_ref()), qualname: bag.make_name(self.qualname.as_ref()), @@ -729,7 +760,6 @@ impl CodeObject { kwonlyarg_count: self.kwonlyarg_count, first_line_number: self.first_line_number, max_stackdepth: self.max_stackdepth, - cell2arg: self.cell2arg.clone(), linetable: self.linetable.clone(), exceptiontable: self.exceptiontable.clone(), } @@ -773,14 +803,20 @@ impl InstrDisplayContext for CodeObject { } fn get_varname(&self, i: usize) -> &str { - self.varnames[i].as_ref() + self.localsplusnames[i].as_ref() } fn get_cell_name(&self, i: usize) -> &str { - self.cellvars - .get(i) - .unwrap_or_else(|| &self.freevars[i - self.cellvars.len()]) - .as_ref() + let mut count = 0; + for (name, &kind) in self.localsplusnames.iter().zip(self.localspluskinds.iter()) { + if kind.intersects(LocalKind::CELL | LocalKind::FREE) { + if count == i { + return name.as_ref(); + } + count += 1; + } + } + panic!("cell/free index {i} out of bounds") } } diff --git a/crates/compiler-core/src/marshal.rs b/crates/compiler-core/src/marshal.rs index decb25d5283..bab91944f59 100644 --- a/crates/compiler-core/src/marshal.rs +++ b/crates/compiler-core/src/marshal.rs @@ -1,4 +1,4 @@ -use crate::{OneIndexed, SourceLocation, bytecode::*}; +use crate::{OneIndexed, SourceLocation, bytecode::LocalKind, bytecode::*}; use core::convert::Infallible; use malachite_bigint::{BigInt, Sign}; use num_complex::Complex64; @@ -223,15 +223,6 @@ pub fn deserialize_code( let len = rdr.read_u32()?; let qualname = bag.make_name(rdr.read_str(len)?); - let len = rdr.read_u32()?; - let cell2arg = (len != 0) - .then(|| { - (0..len) - .map(|_| Ok(rdr.read_u32()? as i32)) - .collect::>>() - }) - .transpose()?; - let len = rdr.read_u32()?; let constants = (0..len) .map(|_| deserialize_value(rdr, bag)) @@ -248,9 +239,32 @@ pub fn deserialize_code( }; let names = read_names()?; - let varnames = read_names()?; - let cellvars = read_names()?; - let freevars = read_names()?; + + // Read localsplusnames and localspluskinds + let localsplusnames = read_names()?; + + let localspluskinds_len = rdr.read_u32()?; + let localspluskinds_raw = rdr.read_slice(localspluskinds_len)?; + let localspluskinds: Box<[LocalKind]> = localspluskinds_raw + .iter() + .map(|&b| LocalKind::from_bits_truncate(b)) + .collect(); + + // Compute counts from localspluskinds + let mut nlocals: u32 = 0; + let mut ncellvars: u32 = 0; + let mut nfreevars: u32 = 0; + for &kind in localspluskinds.iter() { + if kind.contains(LocalKind::LOCAL) { + nlocals += 1; + } + if kind.contains(LocalKind::CELL) { + ncellvars += 1; + } + if kind.contains(LocalKind::FREE) { + nfreevars += 1; + } + } // Read linetable and exceptiontable let linetable_len = rdr.read_u32()?; @@ -274,12 +288,13 @@ pub fn deserialize_code( max_stackdepth, obj_name, qualname, - cell2arg, + localsplusnames, + localspluskinds, constants, names, - varnames, - cellvars, - freevars, + nlocals, + ncellvars, + nfreevars, linetable, exceptiontable, }) @@ -683,12 +698,6 @@ pub fn serialize_code(buf: &mut W, code: &CodeObject) write_vec(buf, code.obj_name.as_ref().as_bytes()); write_vec(buf, code.qualname.as_ref().as_bytes()); - let cell2arg = code.cell2arg.as_deref().unwrap_or(&[]); - write_len(buf, cell2arg.len()); - for &i in cell2arg { - buf.write_u32(i as u32) - } - write_len(buf, code.constants.len()); for constant in &*code.constants { serialize_value(buf, constant.borrow_constant().into()).unwrap_or_else(|x| match x {}) @@ -702,9 +711,15 @@ pub fn serialize_code(buf: &mut W, code: &CodeObject) }; write_names(&code.names); - write_names(&code.varnames); - write_names(&code.cellvars); - write_names(&code.freevars); + + // Serialize localsplusnames + write_names(&code.localsplusnames); + + // Serialize localspluskinds as raw bytes + write_len(buf, code.localspluskinds.len()); + for kind in &*code.localspluskinds { + buf.write_u8(kind.bits()); + } // Serialize linetable and exceptiontable write_vec(buf, &code.linetable); diff --git a/crates/jit/src/lib.rs b/crates/jit/src/lib.rs index 1e278617661..f58d92d3d47 100644 --- a/crates/jit/src/lib.rs +++ b/crates/jit/src/lib.rs @@ -92,7 +92,7 @@ impl Jit { let sig = { let mut compiler = FunctionCompiler::new( &mut builder, - bytecode.varnames.len(), + bytecode.nlocals as usize, args, ret, entry_block, diff --git a/crates/vm/src/builtins/code.rs b/crates/vm/src/builtins/code.rs index 932100db94f..d3bfa85dd23 100644 --- a/crates/vm/src/builtins/code.rs +++ b/crates/vm/src/builtins/code.rs @@ -522,6 +522,42 @@ impl Constructor for PyCode { )], > = vec![(loc, loc); instructions.len()].into_boxed_slice(); + // Build localsplusnames/localspluskinds from varnames, cellvars, freevars + use rustpython_compiler_core::bytecode::LocalKind; + + // Save counts before consuming the arrays + let nlocals = varnames.len() as u32; + let ncellvars = cellvars.len() as u32; + let nfreevars = freevars.len() as u32; + + let cellvar_set: std::collections::HashSet<&str> = + cellvars.iter().map(|s| s.as_str()).collect(); + + let mut lp_names = Vec::new(); + let mut lp_kinds = Vec::new(); + + // 1. varnames (locals, some may be cells too) + for vn in varnames.iter() { + let mut kind = LocalKind::LOCAL; + if cellvar_set.contains(vn.as_str()) { + kind |= LocalKind::CELL; + } + lp_names.push(*vn); + lp_kinds.push(kind); + } + // 2. cell-only vars + for cn in cellvars.iter() { + if !varnames.iter().any(|vn| vn.as_str() == cn.as_str()) { + lp_names.push(*cn); + lp_kinds.push(LocalKind::CELL); + } + } + // 3. free vars + for fv in freevars.iter() { + lp_names.push(*fv); + lp_kinds.push(LocalKind::FREE); + } + // Build the CodeObject let code = CodeObject { instructions, @@ -539,12 +575,13 @@ impl Constructor for PyCode { max_stackdepth: args.stacksize, obj_name: vm.ctx.intern_str(args.name.as_str()), qualname: vm.ctx.intern_str(args.qualname.as_str()), - cell2arg: None, // TODO: reuse `fn cell2arg` + localsplusnames: lp_names.into_boxed_slice(), + localspluskinds: lp_kinds.into_boxed_slice(), + nlocals, + ncellvars, + nfreevars, constants, names, - varnames, - cellvars, - freevars, linetable: args.linetable.as_bytes().to_vec().into_boxed_slice(), exceptiontable: args.exceptiontable.as_bytes().to_vec().into_boxed_slice(), }; @@ -577,17 +614,20 @@ impl PyCode { #[pygetset] pub fn co_cellvars(&self, vm: &VirtualMachine) -> PyTupleRef { - let cellvars = self - .cellvars + let cellvars: Vec = self + .code + .localsplusnames .iter() - .map(|name| name.to_pyobject(vm)) + .zip(self.code.localspluskinds.iter()) + .filter(|&(_, kind)| kind.contains(bytecode::LocalKind::CELL)) + .map(|(name, _)| name.to_object()) .collect(); vm.ctx.new_tuple(cellvars) } #[pygetset] fn co_nlocals(&self) -> usize { - self.code.varnames.len() + self.code.nlocals as usize } #[pygetset] @@ -634,7 +674,7 @@ impl PyCode { #[pygetset] pub fn co_varnames(&self, vm: &VirtualMachine) -> PyTupleRef { - let varnames = self.code.varnames.iter().map(|s| s.to_object()).collect(); + let varnames = self.code.varnames().iter().map(|s| s.to_object()).collect(); vm.ctx.new_tuple(varnames) } @@ -660,14 +700,35 @@ impl PyCode { pub fn co_freevars(&self, vm: &VirtualMachine) -> PyTupleRef { let names = self .code - .freevars - .deref() + .freevars() .iter() .map(|name| name.to_pyobject(vm)) .collect(); vm.ctx.new_tuple(names) } + #[pygetset] + fn co_localsplusnames(&self, vm: &VirtualMachine) -> PyTupleRef { + let names: Vec = self + .code + .localsplusnames + .iter() + .map(|name| name.to_pyobject(vm)) + .collect(); + vm.ctx.new_tuple(names) + } + + #[pygetset] + fn co_localspluskinds(&self, vm: &VirtualMachine) -> crate::builtins::PyBytesRef { + let bytes: Vec = self + .code + .localspluskinds + .iter() + .map(|kind| kind.bits()) + .collect(); + vm.ctx.new_bytes(bytes) + } + #[pygetset] pub fn co_linetable(&self, vm: &VirtualMachine) -> crate::builtins::PyBytesRef { // Return the actual linetable from the code object @@ -947,7 +1008,7 @@ impl PyCode { let varnames = match co_varnames { OptionalArg::Present(varnames) => varnames, - OptionalArg::Missing => self.code.varnames.iter().map(|s| s.to_object()).collect(), + OptionalArg::Missing => self.code.varnames().iter().map(|s| s.to_object()).collect(), }; let qualname = match co_qualname { @@ -969,20 +1030,27 @@ impl PyCode { OptionalArg::Missing => self.code.instructions.clone(), }; - let cellvars = match co_cellvars { + let cellvars: Box<[&'static PyStrInterned]> = match co_cellvars { OptionalArg::Present(cellvars) => cellvars .into_iter() .map(|o| o.as_interned_str(vm).unwrap()) .collect(), - OptionalArg::Missing => self.code.cellvars.clone(), + OptionalArg::Missing => self + .code + .localsplusnames + .iter() + .zip(self.code.localspluskinds.iter()) + .filter(|&(_, kind)| kind.contains(bytecode::LocalKind::CELL)) + .map(|(name, _)| *name) + .collect(), }; - let freevars = match co_freevars { + let freevars: Box<[&'static PyStrInterned]> = match co_freevars { OptionalArg::Present(freevars) => freevars .into_iter() .map(|o| o.as_interned_str(vm).unwrap()) .collect(), - OptionalArg::Missing => self.code.freevars.clone(), + OptionalArg::Missing => self.code.freevars().iter().copied().collect(), }; // Validate co_nlocals if provided @@ -1009,6 +1077,50 @@ impl PyCode { OptionalArg::Missing => self.code.exceptiontable.clone(), }; + let varnames_interned: Box<[&'static PyStrInterned]> = varnames + .into_iter() + .map(|o| o.as_interned_str(vm).unwrap()) + .collect(); + + // Build localsplusnames/localspluskinds from varnames, cellvars, freevars + use rustpython_compiler_core::bytecode::LocalKind; + + // Save counts before consuming the arrays + let nlocals = varnames_interned.len() as u32; + let ncellvars = cellvars.len() as u32; + let nfreevars = freevars.len() as u32; + + let cellvar_set: std::collections::HashSet<&str> = + cellvars.iter().map(|s| s.as_str()).collect(); + + let mut lp_names = Vec::new(); + let mut lp_kinds = Vec::new(); + + // 1. varnames (locals, some may be cells too) + for vn in varnames_interned.iter() { + let mut kind = LocalKind::LOCAL; + if cellvar_set.contains(vn.as_str()) { + kind |= LocalKind::CELL; + } + lp_names.push(*vn); + lp_kinds.push(kind); + } + // 2. cell-only vars + for cn in cellvars.iter() { + if !varnames_interned + .iter() + .any(|vn| vn.as_str() == cn.as_str()) + { + lp_names.push(*cn); + lp_kinds.push(LocalKind::CELL); + } + } + // 3. free vars + for fv in freevars.iter() { + lp_names.push(*fv); + lp_kinds.push(LocalKind::FREE); + } + let new_code = CodeObject { flags: CodeFlags::from_bits_truncate(flags), posonlyarg_count, @@ -1029,13 +1141,11 @@ impl PyCode { .into_iter() .map(|o| o.as_interned_str(vm).unwrap()) .collect(), - varnames: varnames - .into_iter() - .map(|o| o.as_interned_str(vm).unwrap()) - .collect(), - cellvars, - freevars, - cell2arg: self.code.cell2arg.clone(), + localsplusnames: lp_names.into_boxed_slice(), + localspluskinds: lp_kinds.into_boxed_slice(), + nlocals, + ncellvars, + nfreevars, linetable, exceptiontable, }; @@ -1045,30 +1155,14 @@ impl PyCode { #[pymethod] fn _varname_from_oparg(&self, opcode: i32, vm: &VirtualMachine) -> PyResult { - let idx_err = |vm: &VirtualMachine| vm.new_index_error("tuple index out of range"); - - let idx = usize::try_from(opcode).map_err(|_| idx_err(vm))?; - - let varnames_len = self.code.varnames.len(); - let cellvars_len = self.code.cellvars.len(); - - let name = if idx < varnames_len { - // Index in varnames - self.code.varnames.get(idx).ok_or_else(|| idx_err(vm))? - } else if idx < varnames_len + cellvars_len { - // Index in cellvars - self.code - .cellvars - .get(idx - varnames_len) - .ok_or_else(|| idx_err(vm))? - } else { - // Index in freevars - self.code - .freevars - .get(idx - varnames_len - cellvars_len) - .ok_or_else(|| idx_err(vm))? - }; - Ok(name.to_object()) + let idx = usize::try_from(opcode) + .map_err(|_| vm.new_index_error("tuple index out of range".to_owned()))?; + let name = self + .code + .localsplusnames + .get(idx) + .ok_or_else(|| vm.new_index_error("tuple index out of range".to_owned()))?; + Ok(name.to_pyobject(vm)) } } diff --git a/crates/vm/src/builtins/function.rs b/crates/vm/src/builtins/function.rs index 7eb74dec41e..0d452180d7d 100644 --- a/crates/vm/src/builtins/function.rs +++ b/crates/vm/src/builtins/function.rs @@ -247,7 +247,7 @@ impl PyFunction { }; let arg_pos = |range: core::ops::Range<_>, name: &str| { - code.varnames + code.varnames() .iter() .enumerate() .skip(range.start) @@ -312,7 +312,7 @@ impl PyFunction { let mut missing: Vec<_> = (nargs..n_required) .filter_map(|i| { if fastlocals[i].is_none() { - Some(&code.varnames[i]) + Some(&code.varnames()[i]) } else { None } @@ -372,7 +372,7 @@ impl PyFunction { // Check if kw only arguments are all present: for (slot, kwarg) in fastlocals .iter_mut() - .zip(&*code.varnames) + .zip(code.varnames()) .skip(code.arg_count as usize) .take(code.kwonlyarg_count as usize) .filter(|(slot, _)| slot.is_none()) @@ -391,10 +391,24 @@ impl PyFunction { } } - if let Some(cell2arg) = code.cell2arg.as_deref() { - for (cell_idx, arg_idx) in cell2arg.iter().enumerate().filter(|(_, i)| **i != -1) { - let x = fastlocals[*arg_idx as usize].take(); - frame.cells_frees[cell_idx].set(x); + { + let total_args = (code.arg_count + + code.kwonlyarg_count + + code.flags.contains(bytecode::CodeFlags::VARARGS) as u32 + + code.flags.contains(bytecode::CodeFlags::VARKEYWORDS) as u32) + as usize; + let mut cell_idx = 0; + for (i, &kind) in code.localspluskinds[..code.nlocals as usize] + .iter() + .enumerate() + { + if kind.contains(bytecode::LocalKind::CELL) { + if i < total_args { + let arg = fastlocals[i].take(); + frame.cells_frees[cell_idx].set(arg); + } + cell_idx += 1; + } } } @@ -872,11 +886,11 @@ impl Constructor for PyFunction { // Handle closure - must be a tuple of cells let closure = if let Some(closure_tuple) = args.closure { // Check that closure length matches code's free variables - if closure_tuple.len() != args.code.freevars.len() { + if closure_tuple.len() != args.code.nfreevars as usize { return Err(vm.new_value_error(format!( "{} requires closure of length {}, not {}", args.code.obj_name, - args.code.freevars.len(), + args.code.nfreevars, closure_tuple.len() ))); } @@ -884,7 +898,7 @@ impl Constructor for PyFunction { // Validate that all items are cells and create typed tuple let typed_closure = closure_tuple.try_into_typed::(vm)?; Some(typed_closure) - } else if !args.code.freevars.is_empty() { + } else if args.code.nfreevars > 0 { return Err(vm.new_type_error("arg 5 (closure) must be tuple")); } else { None diff --git a/crates/vm/src/builtins/super.rs b/crates/vm/src/builtins/super.rs index b7bc3004332..58d51c79c12 100644 --- a/crates/vm/src/builtins/super.rs +++ b/crates/vm/src/builtins/super.rs @@ -6,7 +6,7 @@ See also [CPython source code.](https://github.com/python/cpython/blob/50b48572d use super::{PyStr, PyType, PyTypeRef}; use crate::{ - AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, + AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, bytecode, class::PyClassImpl, common::lock::PyRwLock, function::{FuncArgs, IntoFuncArgs, OptionalArg}, @@ -88,12 +88,14 @@ impl Initializer for PySuper { let obj = frame.fastlocals.lock()[0] .clone() .or_else(|| { - if let Some(cell2arg) = frame.code.cell2arg.as_deref() { - cell2arg[..frame.code.cellvars.len()] - .iter() - .enumerate() - .find(|(_, arg_idx)| **arg_idx == 0) - .and_then(|(cell_idx, _)| frame.cells_frees[cell_idx].get()) + if frame + .code + .localspluskinds + .first() + .map_or(false, |k| k.contains(bytecode::LocalKind::CELL)) + { + // First argument (self) is captured as a cell + frame.cells_frees[0].get() } else { None } @@ -101,9 +103,9 @@ impl Initializer for PySuper { .ok_or_else(|| vm.new_runtime_error("super(): arg[0] deleted"))?; let mut typ = None; - for (i, var) in frame.code.freevars.iter().enumerate() { + for (i, var) in frame.code.freevars().iter().enumerate() { if var.as_bytes() == b"__class__" { - let i = frame.code.cellvars.len() + i; + let i = frame.code.ncellvars as usize + i; let class = frame.cells_frees[i] .get() .ok_or_else(|| vm.new_runtime_error("super(): empty __class__ cell"))?; diff --git a/crates/vm/src/frame.rs b/crates/vm/src/frame.rs index b7df081a642..16bab636200 100644 --- a/crates/vm/src/frame.rs +++ b/crates/vm/src/frame.rs @@ -132,8 +132,8 @@ impl Frame { func_obj: Option, vm: &VirtualMachine, ) -> Self { - let nlocals = code.varnames.len(); - let num_cells = code.cellvars.len(); + let nlocals = code.nlocals as usize; + let num_cells = code.ncellvars as usize; let nfrees = closure.len(); let cells_frees: Box<[PyCellRef]> = @@ -191,9 +191,9 @@ impl Frame { pub fn locals(&self, vm: &VirtualMachine) -> PyResult { let locals = &self.locals; let code = &**self.code; - let map = &code.varnames; - let j = core::cmp::min(map.len(), code.varnames.len()); - if !code.varnames.is_empty() { + let map = code.varnames(); + let j = map.len(); + if code.nlocals > 0 { let fastlocals = self.fastlocals.lock(); for (&k, v) in zip(&map[..j], &**fastlocals) { match locals.mapping().ass_subscript(k, v.clone(), vm) { @@ -203,24 +203,40 @@ impl Frame { } } } - if !code.cellvars.is_empty() || !code.freevars.is_empty() { - let map_to_dict = |keys: &[&PyStrInterned], values: &[PyCellRef]| { - for (&k, v) in zip(keys, values) { - if let Some(value) = v.get() { - locals.mapping().ass_subscript(k, Some(value), vm)?; + if code.ncellvars > 0 || code.nfreevars > 0 { + // Add cell variables + let mut cf_idx = 0; + for (&name, &kind) in code.localsplusnames.iter().zip(code.localspluskinds.iter()) { + if kind.contains(bytecode::LocalKind::CELL) { + if let Some(value) = self.cells_frees[cf_idx].get() { + locals.mapping().ass_subscript(name, Some(value), vm)?; } else { - match locals.mapping().ass_subscript(k, None, vm) { + match locals.mapping().ass_subscript(name, None, vm) { Ok(()) => {} Err(e) if e.fast_isinstance(vm.ctx.exceptions.key_error) => {} Err(e) => return Err(e), } } + cf_idx += 1; + } else if kind.contains(bytecode::LocalKind::FREE) { + cf_idx += 1; } - Ok(()) - }; - map_to_dict(&code.cellvars, &self.cells_frees)?; + } + // Add free variables only in optimized mode if code.flags.contains(bytecode::CodeFlags::OPTIMIZED) { - map_to_dict(&code.freevars, &self.cells_frees[code.cellvars.len()..])?; + let freevars = code.freevars(); + for (i, &name) in freevars.iter().enumerate() { + let cf_idx = code.ncellvars as usize + i; + if let Some(value) = self.cells_frees[cf_idx].get() { + locals.mapping().ass_subscript(name, Some(value), vm)?; + } else { + match locals.mapping().ass_subscript(name, None, vm) { + Ok(()) => {} + Err(e) if e.fast_isinstance(vm.ctx.exceptions.key_error) => {} + Err(e) => return Err(e), + } + } + } } } Ok(locals.clone()) @@ -565,16 +581,33 @@ impl ExecutingFrame<'_> { } fn unbound_cell_exception(&self, i: usize, vm: &VirtualMachine) -> PyBaseExceptionRef { - if let Some(&name) = self.code.cellvars.get(i) { + // Find the i-th cell/free variable name + let mut count = 0; + let mut found_name = None; + for (name, &kind) in self + .code + .localsplusnames + .iter() + .zip(self.code.localspluskinds.iter()) + { + if kind.intersects(bytecode::LocalKind::CELL | bytecode::LocalKind::FREE) { + if count == i { + found_name = Some(name); + break; + } + count += 1; + } + } + let name = found_name.expect("cell/free index out of bounds"); + if i < self.code.ncellvars as usize { vm.new_exception_msg( vm.ctx.exceptions.unbound_local_error.to_owned(), format!("local variable '{name}' referenced before assignment"), ) } else { - let name = self.code.freevars[i - self.code.cellvars.len()]; vm.new_name_error( format!("free variable '{name}' referenced before assignment in enclosing scope"), - name.to_owned(), + (*name).to_owned(), ) } } @@ -824,7 +857,7 @@ impl ExecutingFrame<'_> { vm.ctx.exceptions.unbound_local_error.to_owned(), format!( "local variable '{}' referenced before assignment", - self.code.varnames[idx] + self.code.localsplusnames[idx] ), )); } @@ -1184,11 +1217,23 @@ impl ExecutingFrame<'_> { // Pop dict from stack (locals or classdict depending on context) let class_dict = self.pop_value(); let i = i.get(arg) as usize; - let name = if i < self.code.cellvars.len() { - self.code.cellvars[i] - } else { - self.code.freevars[i - self.code.cellvars.len()] - }; + // Find the i-th cell/free variable name + let mut count = 0; + let mut name = self.code.localsplusnames[0]; // placeholder + for (n, &kind) in self + .code + .localsplusnames + .iter() + .zip(self.code.localspluskinds.iter()) + { + if kind.intersects(bytecode::LocalKind::CELL | bytecode::LocalKind::FREE) { + if count == i { + name = *n; + break; + } + count += 1; + } + } // Only treat KeyError as "not found", propagate other exceptions let value = if let Some(dict_obj) = class_dict.downcast_ref::() { dict_obj.get_item_opt(name, vm)? @@ -1278,7 +1323,7 @@ impl ExecutingFrame<'_> { let idx = idx.get(arg) as usize; let x = self.fastlocals.lock()[idx] .clone() - .ok_or_else(|| reference_error(self.code.varnames[idx], vm))?; + .ok_or_else(|| reference_error(self.code.localsplusnames[idx], vm))?; self.push_value(x); Ok(None) } @@ -1301,7 +1346,7 @@ impl ExecutingFrame<'_> { vm.ctx.exceptions.unbound_local_error.to_owned(), format!( "local variable '{}' referenced before assignment", - self.code.varnames[idx] + self.code.localsplusnames[idx] ), ) })?; @@ -1320,7 +1365,7 @@ impl ExecutingFrame<'_> { vm.ctx.exceptions.unbound_local_error.to_owned(), format!( "local variable '{}' referenced before assignment", - self.code.varnames[idx1] + self.code.localsplusnames[idx1] ), ) })?; @@ -1329,7 +1374,7 @@ impl ExecutingFrame<'_> { vm.ctx.exceptions.unbound_local_error.to_owned(), format!( "local variable '{}' referenced before assignment", - self.code.varnames[idx2] + self.code.localsplusnames[idx2] ), ) })?; @@ -1353,7 +1398,7 @@ impl ExecutingFrame<'_> { vm.ctx.exceptions.unbound_local_error.to_owned(), format!( "local variable '{}' referenced before assignment", - self.code.varnames[idx] + self.code.localsplusnames[idx] ), ) })?; @@ -1371,7 +1416,7 @@ impl ExecutingFrame<'_> { vm.ctx.exceptions.unbound_local_error.to_owned(), format!( "local variable '{}' referenced before assignment", - self.code.varnames[idx1] + self.code.localsplusnames[idx1] ), ) })?; @@ -1380,7 +1425,7 @@ impl ExecutingFrame<'_> { vm.ctx.exceptions.unbound_local_error.to_owned(), format!( "local variable '{}' referenced before assignment", - self.code.varnames[idx2] + self.code.localsplusnames[idx2] ), ) })?; diff --git a/crates/vm/src/stdlib/builtins.rs b/crates/vm/src/stdlib/builtins.rs index 1f14f6f5b04..224b217fe98 100644 --- a/crates/vm/src/stdlib/builtins.rs +++ b/crates/vm/src/stdlib/builtins.rs @@ -377,7 +377,7 @@ mod builtins { Either::B(code_obj) => code_obj, }; - if !code_obj.freevars.is_empty() { + if code_obj.nfreevars > 0 { return Err(vm.new_type_error(format!( "code object passed to {func}() may not contain free variables" )));