diff --git a/.cspell.dict/python-more.txt b/.cspell.dict/python-more.txt index f1e88214ad7..7ea660a5d1f 100644 --- a/.cspell.dict/python-more.txt +++ b/.cspell.dict/python-more.txt @@ -189,6 +189,7 @@ posonlyargcount prepending profilefunc pycache +pycapsule pycodecs pycs pydatetime diff --git a/crates/capi/src/lib.rs b/crates/capi/src/lib.rs index ac7bdf5ef11..9262e4e2e65 100644 --- a/crates/capi/src/lib.rs +++ b/crates/capi/src/lib.rs @@ -14,6 +14,7 @@ pub mod ceval; pub mod import; pub mod longobject; pub mod object; +pub mod pycapsule; pub mod pyerrors; pub mod pylifecycle; pub mod pystate; diff --git a/crates/capi/src/pycapsule.rs b/crates/capi/src/pycapsule.rs new file mode 100644 index 00000000000..0e1b55b617d --- /dev/null +++ b/crates/capi/src/pycapsule.rs @@ -0,0 +1,160 @@ +use crate::PyObject; +use crate::pystate::with_vm; +use core::ffi::{CStr, c_char, c_int, c_void}; +use core::ptr::NonNull; +use rustpython_vm::builtins::PyCapsule; +use rustpython_vm::{PyObjectRef, PyResult, VirtualMachine}; + +#[allow(non_camel_case_types)] +pub type PyCapsule_Destructor = unsafe extern "C" fn(capsule: *mut PyObject); + +#[unsafe(no_mangle)] +pub extern "C" fn PyCapsule_New( + pointer: *mut c_void, + name: *const c_char, + destructor: Option, +) -> *mut PyObject { + with_vm(|vm| { + if pointer.is_null() { + return Err(vm.new_value_error("PyCapsule_New called with null pointer")); + } + let name = NonNull::new(name.cast_mut()).map(|ptr| unsafe { CStr::from_ptr(ptr.as_ptr()) }); + Ok(vm.ctx.new_capsule(pointer, name, destructor)) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyCapsule_GetPointer( + capsule: *mut PyObject, + name: *const c_char, +) -> *mut c_void { + with_vm(|vm| Ok(checked_capsule(vm, unsafe { &*capsule }, name)?.pointer())) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyCapsule_GetName(capsule: *mut PyObject) -> *const c_char { + with_vm(|vm| { + let capsule = unsafe { &*capsule } + .downcast_ref_if_exact::(vm) + .ok_or_else(|| vm.new_value_error("Invalid capsule"))?; + Ok(capsule.name().map(CStr::as_ptr).unwrap_or_default()) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyCapsule_GetContext(capsule: *mut PyObject) -> *mut c_void { + with_vm(|vm| { + let capsule = unsafe { &*capsule } + .downcast_ref_if_exact::(vm) + .ok_or_else(|| vm.new_value_error("Invalid capsule"))?; + Ok(capsule.context()) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyCapsule_SetContext( + capsule: *mut PyObject, + context: *mut c_void, +) -> c_int { + with_vm(|vm| { + let capsule = unsafe { &*capsule } + .downcast_ref_if_exact::(vm) + .ok_or_else(|| vm.new_value_error("Invalid capsule"))?; + let _: () = capsule.set_context(context); + Ok(()) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyCapsule_SetPointer( + capsule: *mut PyObject, + pointer: *mut c_void, +) -> c_int { + with_vm(|vm| { + let capsule = unsafe { &*capsule } + .downcast_ref_if_exact::(vm) + .ok_or_else(|| vm.new_value_error("Invalid capsule"))?; + let _: () = capsule.set_pointer(pointer); + Ok(()) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyCapsule_IsValid(capsule: *mut PyObject, name: *const c_char) -> c_int { + with_vm(|vm| { + if capsule.is_null() { + return false; + } + + checked_capsule(vm, unsafe { &*capsule }, name).is_ok() + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn PyCapsule_Import(name: *const c_char, _no_block: c_int) -> *mut c_void { + with_vm(|vm| { + let capsule_name = unsafe { CStr::from_ptr(name) } + .to_str() + .map_err(|_| vm.new_system_error("capsule name is not valid UTF-8"))?; + let (module_name, attrs_path) = capsule_name.split_once('.').ok_or_else(|| { + vm.new_import_error( + "capsule name is missing attribute path", + vm.ctx.new_str(capsule_name), + ) + })?; + let mut obj: PyObjectRef = vm.import(module_name, 0)?; + + for attr in attrs_path.split('.') { + obj = obj.get_attr(attr, vm)?; + } + + Ok(checked_capsule(vm, &obj, name)?.pointer()) + }) +} + +#[inline] +fn names_match(stored_name: *const c_char, expected_name: *const c_char) -> bool { + if stored_name.is_null() || expected_name.is_null() { + stored_name.is_null() && expected_name.is_null() + } else { + unsafe { CStr::from_ptr(stored_name) == CStr::from_ptr(expected_name) } + } +} + +#[inline] +fn checked_capsule<'a>( + vm: &VirtualMachine, + obj: &'a PyObject, + name: *const c_char, +) -> PyResult<&'a PyCapsule> { + let capsule = obj + .downcast_ref_if_exact::(vm) + .ok_or_else(|| vm.new_value_error("Invalid capsule"))?; + + if !names_match(capsule.name().map(CStr::as_ptr).unwrap_or_default(), name) { + return Err(vm.new_value_error("Capsule name does not match")); + } + + if capsule.pointer().is_null() { + return Err(vm.new_value_error("Capsule has null pointer")); + } + + Ok(capsule) +} + +#[cfg(test)] +mod tests { + use pyo3::prelude::*; + use pyo3::types::PyCapsule; + + #[test] + fn test_capsule_new() { + Python::attach(|py| { + let value = String::from("Some data"); + let capsule = PyCapsule::new_with_value(py, value, c"my_capsule").unwrap(); + assert!(capsule.is_valid_checked(Some(c"my_capsule"))); + let ptr = capsule.pointer_checked(Some(c"my_capsule")).unwrap(); + assert_eq!(unsafe { ptr.cast::().as_ref() }, "Some data"); + }) + } +} diff --git a/crates/capi/src/util.rs b/crates/capi/src/util.rs index 9119581902a..6eef9163a5c 100644 --- a/crates/capi/src/util.rs +++ b/crates/capi/src/util.rs @@ -76,6 +76,14 @@ impl FfiResult<*mut c_char> for *const u8 { } } +impl FfiResult for *const c_char { + const ERR_VALUE: *const c_char = core::ptr::null_mut(); + + fn into_output(self, _vm: &VirtualMachine) -> *const c_char { + self + } +} + impl FfiResult for usize { const ERR_VALUE: isize = -1; diff --git a/crates/vm/src/builtins/capsule.rs b/crates/vm/src/builtins/capsule.rs index 33680fb1973..43efa0fb214 100644 --- a/crates/vm/src/builtins/capsule.rs +++ b/crates/vm/src/builtins/capsule.rs @@ -4,7 +4,7 @@ use crate::{ class::PyClassImpl, types::{Destructor, Representable}, }; -use core::ffi::c_void; +use core::ffi::{CStr, c_void}; use core::sync::atomic::AtomicPtr; /// PyCapsule - a container for C pointers. @@ -13,6 +13,8 @@ use core::sync::atomic::AtomicPtr; #[derive(Debug)] pub struct PyCapsule { ptr: AtomicPtr, + context: AtomicPtr, + name: Option<&'static CStr>, destructor: Option, } @@ -27,10 +29,13 @@ impl PyPayload for PyCapsule { impl PyCapsule { pub fn new( ptr: *mut c_void, + name: Option<&'static CStr>, destructor: Option, ) -> Self { Self { ptr: ptr.into(), + context: core::ptr::null_mut::().into(), + name, destructor, } } @@ -39,6 +44,24 @@ impl PyCapsule { self.ptr.load(core::sync::atomic::Ordering::Relaxed) } + pub fn set_pointer(&self, pointer: *mut c_void) { + self.ptr + .store(pointer, core::sync::atomic::Ordering::Relaxed); + } + + pub fn context(&self) -> *mut c_void { + self.context.load(core::sync::atomic::Ordering::Relaxed) + } + + pub fn set_context(&self, context: *mut c_void) { + self.context + .store(context, core::sync::atomic::Ordering::Relaxed); + } + + pub fn name(&self) -> Option<&CStr> { + self.name + } + fn destructor(&self) -> Option { self.destructor } diff --git a/crates/vm/src/vm/context.rs b/crates/vm/src/vm/context.rs index 8d20e750706..4d16c5d8075 100644 --- a/crates/vm/src/vm/context.rs +++ b/crates/vm/src/vm/context.rs @@ -26,6 +26,7 @@ use crate::{ object::{Py, PyObjectPayload, PyObjectRef, PyPayload, PyRef}, types::{PyTypeFlags, PyTypeSlots, TypeZoo}, }; +use core::ffi::{CStr, c_void}; use malachite_bigint::BigInt; use num_complex::Complex64; use num_traits::ToPrimitive; @@ -754,10 +755,11 @@ impl Context { pub fn new_capsule( &self, - ptr: *mut core::ffi::c_void, + ptr: *mut c_void, + name: Option<&'static CStr>, destructor: Option, ) -> PyRef { - PyCapsule::new(ptr, destructor).into_ref(self) + PyCapsule::new(ptr, name, destructor).into_ref(self) } }