Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Implement pickle more
  • Loading branch information
youknowone committed Jan 16, 2026
commit 9efc8d1acc53376f1321f41260859276d98655e1
210 changes: 188 additions & 22 deletions crates/vm/src/builtins/object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,15 +184,15 @@ fn type_slot_names(typ: &Py<PyType>, vm: &VirtualMachine) -> PyResult<Option<sup
Ok(result)
}

// object_getstate_default in CPython
// object_getstate_default
fn object_getstate_default(obj: &PyObject, required: bool, vm: &VirtualMachine) -> PyResult {
// TODO: itemsize
// if required && obj.class().slots.itemsize > 0 {
// return vm.new_type_error(format!(
// "cannot pickle {:.200} objects",
// obj.class().name()
// ));
// }
// Check itemsize
if required && obj.class().slots.itemsize > 0 {
return Err(vm.new_type_error(format!(
"cannot pickle {:.200} objects",
obj.class().name()
)));
}

let state = if obj.dict().is_none_or(|d| d.is_empty()) {
vm.ctx.none()
Expand All @@ -208,21 +208,23 @@ fn object_getstate_default(obj: &PyObject, required: bool, vm: &VirtualMachine)
type_slot_names(obj.class(), vm).map_err(|_| vm.new_type_error("cannot pickle object"))?;

if required {
let mut basicsize = obj.class().slots.basicsize;
// if obj.class().slots.dict_offset > 0
// && !obj.class().slots.flags.has_feature(PyTypeFlags::MANAGED_DICT)
// {
// basicsize += std::mem::size_of::<PyObjectRef>();
// }
// if obj.class().slots.weaklist_offset > 0 {
// basicsize += std::mem::size_of::<PyObjectRef>();
// }
// Start with PyBaseObject_Type's basicsize
let mut basicsize = vm.ctx.types.object_type.slots.basicsize;

// Add __dict__ size if type has dict
if obj.class().slots.flags.has_feature(PyTypeFlags::HAS_DICT) {
basicsize += core::mem::size_of::<PyObjectRef>();
}

// Add slots size
if let Some(ref slot_names) = slot_names {
basicsize += core::mem::size_of::<PyObjectRef>() * slot_names.__len__();
}

// Fail if actual type's basicsize > expected basicsize
if obj.class().slots.basicsize > basicsize {
return Err(
vm.new_type_error(format!("cannot pickle {:.200} object", obj.class().name()))
vm.new_type_error(format!("cannot pickle '{}' object", obj.class().name()))
);
}
}
Expand All @@ -249,7 +251,7 @@ fn object_getstate_default(obj: &PyObject, required: bool, vm: &VirtualMachine)
Ok(state)
}

// object_getstate in CPython
// object_getstate
// fn object_getstate(
// obj: &PyObject,
// required: bool,
Expand Down Expand Up @@ -550,11 +552,175 @@ pub fn init(ctx: &Context) {
PyBaseObject::extend_class(ctx, ctx.types.object_type);
}

/// Get arguments for __new__ from __getnewargs_ex__ or __getnewargs__
/// Returns (args, kwargs) tuple where either can be None
fn get_new_arguments(
obj: &PyObject,
vm: &VirtualMachine,
) -> PyResult<(Option<super::PyTupleRef>, Option<super::PyDictRef>)> {
// First try __getnewargs_ex__
if let Some(getnewargs_ex) = vm.get_special_method(obj, identifier!(vm, __getnewargs_ex__))? {
let newargs = getnewargs_ex.invoke((), vm)?;

let newargs_tuple: PyRef<super::PyTuple> = newargs.downcast().map_err(|obj| {
vm.new_type_error(format!(
"__getnewargs_ex__ should return a tuple, not '{}'",
obj.class().name()
))
})?;

if newargs_tuple.len() != 2 {
return Err(vm.new_value_error(format!(
"__getnewargs_ex__ should return a tuple of length 2, not {}",
newargs_tuple.len()
)));
}

let args = newargs_tuple.as_slice()[0].clone();
let kwargs = newargs_tuple.as_slice()[1].clone();

let args_tuple: PyRef<super::PyTuple> = args.downcast().map_err(|obj| {
vm.new_type_error(format!(
"first item of the tuple returned by __getnewargs_ex__ must be a tuple, not '{}'",
obj.class().name()
))
})?;

let kwargs_dict: PyRef<super::PyDict> = kwargs.downcast().map_err(|obj| {
vm.new_type_error(format!(
"second item of the tuple returned by __getnewargs_ex__ must be a dict, not '{}'",
obj.class().name()
))
})?;

return Ok((Some(args_tuple), Some(kwargs_dict)));
}

// Fall back to __getnewargs__
if let Some(getnewargs) = vm.get_special_method(obj, identifier!(vm, __getnewargs__))? {
let args = getnewargs.invoke((), vm)?;

let args_tuple: PyRef<super::PyTuple> = args.downcast().map_err(|obj| {
vm.new_type_error(format!(
"__getnewargs__ should return a tuple, not '{}'",
obj.class().name()
))
})?;

return Ok((Some(args_tuple), None));
}

// No __getnewargs_ex__ or __getnewargs__
Ok((None, None))
}

/// Check if __getstate__ is overridden by comparing with object.__getstate__
fn is_getstate_overridden(obj: &PyObject, vm: &VirtualMachine) -> bool {
let obj_cls = obj.class();
let object_type = vm.ctx.types.object_type;

// If the class is object itself, not overridden
if obj_cls.is(object_type) {
return false;
}

// Check if __getstate__ in the MRO comes from object or elsewhere
// If the type has its own __getstate__, it's overridden
if let Some(getstate) = obj_cls.get_attr(identifier!(vm, __getstate__))
&& let Some(obj_getstate) = object_type.get_attr(identifier!(vm, __getstate__))
{
return !getstate.is(&obj_getstate);
}
false
}

/// object_getstate - calls __getstate__ method or default implementation
fn object_getstate(obj: &PyObject, required: bool, vm: &VirtualMachine) -> PyResult {
// If __getstate__ is not overridden, use the default implementation with required flag
if !is_getstate_overridden(obj, vm) {
return object_getstate_default(obj, required, vm);
}

// __getstate__ is overridden, call it without required
let getstate = obj.get_attr(identifier!(vm, __getstate__), vm)?;
getstate.call((), vm)
}

/// Get list items iterator if obj is a list (or subclass), None iterator otherwise
fn get_items_iter(obj: &PyObjectRef, vm: &VirtualMachine) -> PyResult<(PyObjectRef, PyObjectRef)> {
let listitems: PyObjectRef = if obj.fast_isinstance(vm.ctx.types.list_type) {
obj.get_iter(vm)?.into()
} else {
vm.ctx.none()
};

let dictitems: PyObjectRef = if obj.fast_isinstance(vm.ctx.types.dict_type) {
let items = vm.call_method(obj, "items", ())?;
items.get_iter(vm)?.into()
} else {
vm.ctx.none()
};

Ok((listitems, dictitems))
}

/// reduce_newobj - creates reduce tuple for protocol >= 2
fn reduce_newobj(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult {
// Check if type has tp_new
let cls = obj.class();
if cls.slots.new.load().is_none() {
return Err(
vm.new_type_error(format!("cannot pickle '{}' object", cls.name()))
);
}

let (args, kwargs) = get_new_arguments(&obj, vm)?;

let copyreg = vm.import("copyreg", 0)?;

let has_args = args.is_some();

let (newobj, newargs): (PyObjectRef, PyObjectRef) = if kwargs.is_none() || kwargs.as_ref().is_some_and(|k| k.is_empty()) {
// Use copyreg.__newobj__
let newobj = copyreg.get_attr("__newobj__", vm)?;

let args_vec: Vec<PyObjectRef> = args
.map(|a| a.as_slice().to_vec())
.unwrap_or_default();

// Create (cls, *args) tuple
let mut newargs_vec: Vec<PyObjectRef> = vec![cls.to_owned().into()];
newargs_vec.extend(args_vec);
let newargs = vm.ctx.new_tuple(newargs_vec);

(newobj, newargs.into())
} else {
// Use copyreg.__newobj_ex__
let newobj = copyreg.get_attr("__newobj_ex__", vm)?;
let args_tuple: PyObjectRef = args.map(|a| a.into()).unwrap_or_else(|| vm.ctx.empty_tuple.clone().into());
let kwargs_dict: PyObjectRef = kwargs.map(|k| k.into()).unwrap_or_else(|| vm.ctx.new_dict().into());

let newargs = vm.ctx.new_tuple(vec![cls.to_owned().into(), args_tuple, kwargs_dict]);
(newobj, newargs.into())
};

// Determine if state is required
// required = !(has_args || is_list || is_dict)
let is_list = obj.fast_isinstance(vm.ctx.types.list_type);
let is_dict = obj.fast_isinstance(vm.ctx.types.dict_type);
let required = !(has_args || is_list || is_dict);

let state = object_getstate(&obj, required, vm)?;

let (listitems, dictitems) = get_items_iter(&obj, vm)?;

let result = vm.ctx.new_tuple(vec![newobj, newargs, state, listitems, dictitems]);
Ok(result.into())
}

fn common_reduce(obj: PyObjectRef, proto: usize, vm: &VirtualMachine) -> PyResult {
if proto >= 2 {
let reducelib = vm.import("__reducelib", 0)?;
let reduce_2 = reducelib.get_attr("reduce_2", vm)?;
reduce_2.call((obj,), vm)
reduce_newobj(obj, vm)
} else {
let copyreg = vm.import("copyreg", 0)?;
let reduce_ex = copyreg.get_attr("_reduce_ex", vm)?;
Expand Down
124 changes: 122 additions & 2 deletions crates/vm/src/stdlib/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ mod _io {
AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult,
TryFromBorrowedObject, TryFromObject,
builtins::{
PyBaseExceptionRef, PyBool, PyByteArray, PyBytes, PyBytesRef, PyMemoryView, PyStr,
PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef, PyUtf8StrRef,
PyBaseExceptionRef, PyBool, PyByteArray, PyBytes, PyBytesRef, PyDict, PyMemoryView,
PyStr, PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef, PyUtf8StrRef,
},
class::StaticType,
common::lock::{
Expand Down Expand Up @@ -4077,6 +4077,67 @@ mod _io {
const fn line_buffering(&self) -> bool {
false
}

#[pymethod]
fn __getstate__(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult<PyTupleRef> {
let buffer = zelf.buffer(vm)?;
let content = Wtf8Buf::from_bytes(buffer.getvalue())
.map_err(|_| vm.new_value_error("Error Retrieving Value"))?;
let pos = buffer.tell();
drop(buffer);

// Get __dict__ if it exists and is non-empty
let dict_obj: PyObjectRef = match zelf.as_object().dict() {
Some(d) if !d.is_empty() => d.into(),
_ => vm.ctx.none(),
};

// Return (content, newline, position, dict)
// TODO: store actual newline setting when it's implemented
Ok(vm.ctx.new_tuple(vec![
vm.ctx.new_str(content).into(),
vm.ctx.new_str("\n").into(),
vm.ctx.new_int(pos).into(),
dict_obj,
]))
}

#[pymethod]
fn __setstate__(zelf: PyRef<Self>, state: PyTupleRef, vm: &VirtualMachine) -> PyResult<()> {
if state.len() != 4 {
return Err(vm.new_type_error(format!(
"__setstate__ argument should be 4-tuple, got {}",
state.len()
)));
}

let content: PyStrRef = state[0].clone().try_into_value(vm)?;
// state[1] is newline - TODO: use when newline handling is implemented
let pos: u64 = state[2].clone().try_into_value(vm)?;
let dict = &state[3];

// Set content
let raw_bytes = content.as_bytes().to_vec();
*zelf.buffer.write() = BufferedIO::new(Cursor::new(raw_bytes));

// Set position
zelf.buffer(vm)?
.seek(SeekFrom::Start(pos))
.map_err(|err| os_err(vm, err))?;

// Set __dict__ if provided
if !vm.is_none(dict) {
let dict_ref: PyRef<PyDict> = dict.clone().try_into_value(vm)?;
if let Some(obj_dict) = zelf.as_object().dict() {
obj_dict.clear();
for (key, value) in dict_ref.into_iter() {
obj_dict.set_item(&*key, value, vm)?;
}
}
}

Ok(())
}
Comment thread
youknowone marked this conversation as resolved.
}

#[pyattr]
Expand Down Expand Up @@ -4225,6 +4286,65 @@ mod _io {
self.closed.store(true);
Ok(())
}

#[pymethod]
fn __getstate__(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult<PyTupleRef> {
let buffer = zelf.buffer(vm)?;
let content = buffer.getvalue();
let pos = buffer.tell();
drop(buffer);

// Get __dict__ if it exists and is non-empty
let dict_obj: PyObjectRef = match zelf.as_object().dict() {
Some(d) if !d.is_empty() => d.into(),
_ => vm.ctx.none(),
};

// Return (content, position, dict)
Ok(vm.ctx.new_tuple(vec![
vm.ctx.new_bytes(content).into(),
vm.ctx.new_int(pos).into(),
dict_obj,
]))
}

#[pymethod]
fn __setstate__(zelf: PyRef<Self>, state: PyTupleRef, vm: &VirtualMachine) -> PyResult<()> {
if zelf.closed.load() {
return Err(vm.new_value_error("__setstate__ on closed file"));
}
if state.len() != 3 {
return Err(vm.new_type_error(format!(
"__setstate__ argument should be 3-tuple, got {}",
state.len()
)));
}

let content: PyBytesRef = state[0].clone().try_into_value(vm)?;
let pos: u64 = state[1].clone().try_into_value(vm)?;
let dict = &state[2];

// Set content
*zelf.buffer.write() = BufferedIO::new(Cursor::new(content.as_bytes().to_vec()));

// Set position
zelf.buffer(vm)?
.seek(SeekFrom::Start(pos))
.map_err(|err| os_err(vm, err))?;

// Set __dict__ if provided
if !vm.is_none(dict) {
let dict_ref: PyRef<PyDict> = dict.clone().try_into_value(vm)?;
if let Some(obj_dict) = zelf.as_object().dict() {
obj_dict.clear();
for (key, value) in dict_ref.into_iter() {
obj_dict.set_item(&*key, value, vm)?;
}
}
}

Ok(())
}
}

#[pyclass]
Expand Down
1 change: 1 addition & 0 deletions crates/vm/src/vm/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ declare_const_name! {
__getformat__,
__getitem__,
__getnewargs__,
__getnewargs_ex__,
__getstate__,
__gt__,
__hash__,
Expand Down