Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Apply TypeDataSlot to ctypes
  • Loading branch information
youknowone committed Dec 11, 2025
commit 7a2b967abc807253230396f6200fc77a85fffb40
58 changes: 29 additions & 29 deletions crates/vm/src/stdlib/ctypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,39 +389,35 @@ pub(crate) mod _ctypes {
/// Get the size of a ctypes type or instance
#[pyfunction(name = "sizeof")]
pub fn size_of(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
use super::array::{PyCArray, PyCArrayType};
use super::pointer::PyCPointer;
use super::structure::{PyCStructType, PyCStructure};
use super::union::{PyCUnion, PyCUnionType};
use super::union::PyCUnionType;
use super::util::StgInfo;
use crate::builtins::PyType;

// 1. Instances with stg_info
if obj.fast_isinstance(PyCArray::static_type()) {
// Get stg_info from the type
if let Some(type_obj) = obj.class().as_object().downcast_ref::<PyCArrayType>() {
return Ok(type_obj.stg_info.size);
}
// 1. Check TypeDataSlot on class (for instances)
if let Some(stg_info) = obj.class().get_type_data::<StgInfo>() {
return Ok(stg_info.size);
}

// 2. Check TypeDataSlot on type itself (for type objects)
if let Some(type_obj) = obj.downcast_ref::<PyType>()
&& let Some(stg_info) = type_obj.get_type_data::<StgInfo>()
{
return Ok(stg_info.size);
}

// 3. Instances with cdata buffer
if let Some(structure) = obj.downcast_ref::<PyCStructure>() {
return Ok(structure.cdata.read().size());
}
if obj.fast_isinstance(PyCUnion::static_type()) {
// Get stg_info from the type
if let Some(type_obj) = obj.class().as_object().downcast_ref::<PyCUnionType>() {
return Ok(type_obj.stg_info.size);
}
}
if let Some(simple) = obj.downcast_ref::<PyCSimple>() {
return Ok(simple.cdata.read().size());
}
if obj.fast_isinstance(PyCPointer::static_type()) {
return Ok(std::mem::size_of::<usize>());
}

// 2. Types (metatypes with stg_info)
if let Some(array_type) = obj.downcast_ref::<PyCArrayType>() {
return Ok(array_type.stg_info.size);
}

// 3. Type objects
if let Ok(type_ref) = obj.clone().downcast::<crate::builtins::PyType>() {
// Structure types - check if metaclass is or inherits from PyCStructType
Expand Down Expand Up @@ -659,33 +655,37 @@ pub(crate) mod _ctypes {

#[pyfunction]
fn alignment(tp: Either<PyTypeRef, PyObjectRef>, vm: &VirtualMachine) -> PyResult<usize> {
use super::array::{PyCArray, PyCArrayType};
use super::base::PyCSimpleType;
use super::pointer::PyCPointer;
use super::structure::PyCStructure;
use super::union::PyCUnion;
use super::util::StgInfo;
use crate::builtins::PyType;

let obj = match &tp {
Either::A(t) => t.as_object(),
Either::B(o) => o.as_ref(),
};

// Try to get alignment from stg_info directly (for instances)
if let Some(array_type) = obj.downcast_ref::<PyCArrayType>() {
return Ok(array_type.stg_info.align);
// 1. Check TypeDataSlot on class (for instances)
if let Some(stg_info) = obj.class().get_type_data::<StgInfo>() {
return Ok(stg_info.align);
}

// 2. Check TypeDataSlot on type itself (for type objects)
if let Some(type_obj) = obj.downcast_ref::<PyType>()
&& let Some(stg_info) = type_obj.get_type_data::<StgInfo>()
{
return Ok(stg_info.align);
}

// 3. Fallback for simple types without TypeDataSlot
if obj.fast_isinstance(PyCSimple::static_type()) {
// Get stg_info from the type by reading _type_ attribute
let cls = obj.class().to_owned();
let stg_info = PyCSimpleType::get_stg_info(&cls, vm);
return Ok(stg_info.align);
}
if obj.fast_isinstance(PyCArray::static_type()) {
// Get stg_info from the type
if let Some(type_obj) = obj.class().as_object().downcast_ref::<PyCArrayType>() {
return Ok(type_obj.stg_info.align);
}
}
if obj.fast_isinstance(PyCStructure::static_type()) {
// Calculate alignment from _fields_
let cls = obj.class();
Expand Down
167 changes: 87 additions & 80 deletions crates/vm/src/stdlib/ctypes/array.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use crate::atomic_func;
use crate::builtins::{PyBytes, PyInt};
use crate::convert::ToPyObject;
use crate::class::StaticType;
use crate::function::FuncArgs;
use crate::protocol::{
BufferDescriptor, BufferMethods, PyBuffer, PyNumberMethods, PySequenceMethods,
};
use crate::stdlib::ctypes::base::CDataObject;
use crate::stdlib::ctypes::util::StgInfo;
use crate::types::{AsBuffer, AsNumber, AsSequence, Callable};
use crate::types::{AsBuffer, AsNumber, AsSequence};
use crate::{AsObject, Py, PyObjectRef, PyPayload};
use crate::{
PyResult, VirtualMachine,
Expand All @@ -20,56 +20,49 @@ use rustpython_common::lock::PyRwLock;
use rustpython_vm::stdlib::ctypes::_ctypes::get_size;
use rustpython_vm::stdlib::ctypes::base::PyCData;

/// PyCArrayType - metatype for Array types
/// CPython stores array info (type, length) in StgInfo via type_data
#[pyclass(name = "PyCArrayType", base = PyType, module = "_ctypes")]
#[derive(PyPayload)]
pub struct PyCArrayType {
pub(super) stg_info: StgInfo,
pub(super) typ: PyRwLock<PyObjectRef>,
pub(super) length: AtomicCell<usize>,
pub(super) element_size: AtomicCell<usize>,
}

impl std::fmt::Debug for PyCArrayType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PyCArrayType")
.field("typ", &self.typ)
.field("length", &self.length)
.finish()
}
}

impl Callable for PyCArrayType {
type Args = FuncArgs;
fn call(zelf: &Py<Self>, args: Self::Args, vm: &VirtualMachine) -> PyResult {
// Create an instance of the array
let element_type = zelf.typ.read().clone();
let length = zelf.length.load();
let element_size = zelf.element_size.load();
let total_size = element_size * length;
let mut buffer = vec![0u8; total_size];

// Initialize from positional arguments
for (i, value) in args.args.iter().enumerate() {
if i >= length {
break;
}
let offset = i * element_size;
if let Ok(int_val) = value.try_int(vm) {
let bytes = PyCArray::int_to_bytes(int_val.as_bigint(), element_size);
if offset + element_size <= buffer.len() {
buffer[offset..offset + element_size].copy_from_slice(&bytes);
}
}
#[derive(Debug, Default, PyPayload)]
pub struct PyCArrayType {}

/// Create a new Array type with StgInfo stored in type_data (CPython style)
pub fn create_array_type_with_stg_info(stg_info: StgInfo, vm: &VirtualMachine) -> PyResult {
// Get PyCArrayType as metaclass
let metaclass = PyCArrayType::static_type().to_owned();

// Create a unique name for the array type
let type_name = format!("Array_{}", stg_info.length);

// Create args for type(): (name, bases, dict)
let name = vm.ctx.new_str(type_name);
let bases = vm
.ctx
.new_tuple(vec![PyCArray::static_type().to_owned().into()]);
let dict = vm.ctx.new_dict();

let args = FuncArgs::new(
vec![name.into(), bases.into(), dict.into()],
crate::function::KwArgs::default(),
);

// Create the new type using PyType::slot_new with PyCArrayType as metaclass
let new_type = crate::builtins::type_::PyType::slot_new(metaclass, args, vm)?;

// Set StgInfo in type_data
let type_ref: PyTypeRef = new_type
.clone()
.downcast()
.map_err(|_| vm.new_type_error("Failed to create array type".to_owned()))?;

if type_ref.init_type_data(stg_info.clone()).is_err() {
// Type data already initialized - update it
if let Some(mut existing) = type_ref.get_type_data_mut::<StgInfo>() {
*existing = stg_info;
}

Ok(PyCArray {
typ: PyRwLock::new(element_type),
length: AtomicCell::new(length),
element_size: AtomicCell::new(element_size),
cdata: PyRwLock::new(CDataObject::from_bytes(buffer, None)),
}
.into_pyobject(vm))
}

Ok(new_type)
}

impl Constructor for PyCArrayType {
Expand All @@ -80,54 +73,62 @@ impl Constructor for PyCArrayType {
}
}

#[pyclass(flags(IMMUTABLETYPE), with(Callable, Constructor, AsNumber))]
#[pyclass(flags(IMMUTABLETYPE), with(Constructor, AsNumber))]
impl PyCArrayType {
#[pygetset(name = "_type_")]
fn typ(&self) -> PyObjectRef {
self.typ.read().clone()
fn typ(zelf: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
zelf.downcast_ref::<PyType>()
.and_then(|t| t.get_type_data::<StgInfo>())
.and_then(|stg| stg.element_type.clone())
.unwrap_or_else(|| vm.ctx.none())
}

#[pygetset(name = "_length_")]
fn length(&self) -> usize {
self.length.load()
fn length(zelf: PyObjectRef) -> usize {
zelf.downcast_ref::<PyType>()
.and_then(|t| t.get_type_data::<StgInfo>())
.map(|stg| stg.length)
.unwrap_or(0)
}

#[pymethod]
fn __mul__(zelf: &Py<Self>, n: isize, vm: &VirtualMachine) -> PyResult {
fn __mul__(zelf: PyObjectRef, n: isize, vm: &VirtualMachine) -> PyResult {
if n < 0 {
return Err(vm.new_value_error(format!("Array length must be >= 0, not {n}")));
}
// Create a nested array type: (inner_type * inner_length) * n
// The new array has n elements, each element is the current array type
// e.g., (c_int * 5) * 3 = Array of 3 elements, each is (c_int * 5)
let inner_length = zelf.length.load();
let inner_element_size = zelf.element_size.load();

// Get inner array info from TypeDataSlot
let type_ref = zelf.downcast_ref::<PyType>().unwrap();
let (_inner_length, inner_size) = type_ref
.get_type_data::<StgInfo>()
.map(|stg| (stg.length, stg.size))
.unwrap_or((0, 0));

// The element type of the new array is the current array type itself
let current_array_type: PyObjectRef = zelf.as_object().to_owned();
let current_array_type: PyObjectRef = zelf.clone();

// Element size is the total size of the inner array
let new_element_size = inner_length * inner_element_size;
let new_element_size = inner_size;
let total_size = new_element_size * (n as usize);
let stg_info = StgInfo::new(total_size, inner_element_size);

Ok(PyCArrayType {
stg_info,
typ: PyRwLock::new(current_array_type),
length: AtomicCell::new(n as usize),
element_size: AtomicCell::new(new_element_size),
}
.to_pyobject(vm))
let stg_info = StgInfo::new_array(
total_size,
new_element_size,
n as usize,
current_array_type,
new_element_size,
);

create_array_type_with_stg_info(stg_info, vm)
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

#[pyclassmethod]
fn in_dll(
zelf: &Py<Self>,
zelf: PyObjectRef,
dll: PyObjectRef,
name: crate::builtins::PyStrRef,
vm: &VirtualMachine,
) -> PyResult {
use crate::stdlib::ctypes::_ctypes::size_of;
use libloading::Symbol;

// Get the library handle from dll object
Expand Down Expand Up @@ -168,10 +169,18 @@ impl PyCArrayType {
return Err(vm.new_attribute_error("Library is closed".to_owned()));
};

// Get size from the array type
let element_type = zelf.typ.read().clone();
let length = zelf.length.load();
let element_size = size_of(element_type.clone(), vm)?;
// Get size from the array type via TypeDataSlot
let type_ref = zelf.downcast_ref::<PyType>().unwrap();
let (element_type, length, element_size) = type_ref
.get_type_data::<StgInfo>()
.map(|stg| {
(
stg.element_type.clone().unwrap_or_else(|| vm.ctx.none()),
stg.length,
stg.element_size,
)
})
.unwrap_or_else(|| (vm.ctx.none(), 0, 0));
let total_size = element_size * length;
Comment on lines +172 to 184
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Same `unwrap()` issue in `in_dll`.

Line 173 repeats the `.unwrap()` on `downcast_ref::<PyType>()`, which panics if `zelf` is not a type object. Apply the same fix as suggested for `__mul__`: return a `TypeError` instead of unwrapping.

         // Get size from the array type via TypeDataSlot
-        let type_ref = zelf.downcast_ref::<PyType>().unwrap();
+        let type_ref = zelf.downcast_ref::<PyType>().ok_or_else(|| {
+            vm.new_type_error("in_dll requires a type object".to_owned())
+        })?;
         let (element_type, length, element_size) = type_ref
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
// Get size from the array type via TypeDataSlot
let type_ref = zelf.downcast_ref::<PyType>().unwrap();
let (element_type, length, element_size) = type_ref
.get_type_data::<StgInfo>()
.map(|stg| {
(
stg.element_type.clone().unwrap_or_else(|| vm.ctx.none()),
stg.length,
stg.element_size,
)
})
.unwrap_or_else(|| (vm.ctx.none(), 0, 0));
let total_size = element_size * length;
// Get size from the array type via TypeDataSlot
let type_ref = zelf.downcast_ref::<PyType>().ok_or_else(|| {
vm.new_type_error("in_dll requires a type object".to_owned())
})?;
let (element_type, length, element_size) = type_ref
.get_type_data::<StgInfo>()
.map(|stg| {
(
stg.element_type.clone().unwrap_or_else(|| vm.ctx.none()),
stg.length,
stg.element_size,
)
})
.unwrap_or_else(|| (vm.ctx.none(), 0, 0));
let total_size = element_size * length;
🤖 Prompt for AI Agents
In crates/vm/src/stdlib/ctypes/array.rs around lines 172-184, the code calls
zelf.downcast_ref::<PyType>().unwrap(), which can panic. Replace the unwrap with
the defensive pattern suggested for __mul__: check downcast_ref::<PyType>() and,
if it returns None, return a Python TypeError (or another appropriate Err) with a
clear message instead of panicking. Only read the TypeDataSlot once you have a
valid type_ref.


// Read data from symbol address
Expand Down Expand Up @@ -206,15 +215,13 @@ impl AsNumber for PyCArrayType {
fn as_number() -> &'static PyNumberMethods {
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
multiply: Some(|a, b, vm| {
let zelf = a
.downcast_ref::<PyCArrayType>()
.ok_or_else(|| vm.new_type_error("expected PyCArrayType".to_owned()))?;
// a is a type object whose metaclass is PyCArrayType (e.g., Array_5)
let n = b
.try_index(vm)?
.as_bigint()
.to_isize()
.ok_or_else(|| vm.new_overflow_error("array size too large".to_owned()))?;
PyCArrayType::__mul__(zelf, n, vm)
PyCArrayType::__mul__(a.to_owned(), n, vm)
}),
..PyNumberMethods::NOT_IMPLEMENTED
};
Expand Down
25 changes: 11 additions & 14 deletions crates/vm/src/stdlib/ctypes/base.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use super::_ctypes::bytes_to_pyobject;
use super::array::PyCArrayType;
use super::util::StgInfo;
use crate::builtins::{PyBytes, PyFloat, PyInt, PyNone, PyStr, PyStrRef, PyType, PyTypeRef};
use crate::convert::ToPyObject;
use crate::function::{ArgBytesLike, Either, FuncArgs, KwArgs, OptionalArg};
use crate::protocol::{BufferDescriptor, BufferMethods, PyBuffer, PyNumberMethods};
use crate::stdlib::ctypes::_ctypes::new_simple_type;
Expand Down Expand Up @@ -231,10 +229,7 @@ impl PyCData {

#[pyclass(module = "_ctypes", name = "PyCSimpleType", base = PyType)]
#[derive(Debug, PyPayload, Default)]
pub struct PyCSimpleType {
#[allow(dead_code)]
pub stg_info: StgInfo,
}
pub struct PyCSimpleType {}

#[pyclass(flags(BASETYPE), with(AsNumber))]
impl PyCSimpleType {
Expand Down Expand Up @@ -747,6 +742,8 @@ impl PyCSimple {
#[pyclassmethod]
fn repeat(cls: PyTypeRef, n: isize, vm: &VirtualMachine) -> PyResult {
use super::_ctypes::get_size;
use super::array::create_array_type_with_stg_info;

if n < 0 {
return Err(vm.new_value_error(format!("Array length must be >= 0, not {n}")));
}
Expand All @@ -766,14 +763,14 @@ impl PyCSimple {
std::mem::size_of::<usize>()
};
let total_size = element_size * (n as usize);
let stg_info = super::util::StgInfo::new(total_size, element_size);
Ok(PyCArrayType {
stg_info,
typ: PyRwLock::new(cls.clone().into()),
length: AtomicCell::new(n as usize),
element_size: AtomicCell::new(element_size),
}
.to_pyobject(vm))
let stg_info = super::util::StgInfo::new_array(
total_size,
element_size,
n as usize,
cls.clone().into(),
element_size,
);
create_array_type_with_stg_info(stg_info, vm)
}

#[pyclassmethod]
Expand Down
Loading
Loading