-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Implement collections.defaultdict in rust
#8132
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,18 +7,19 @@ mod _collections { | |
| atomic_func, | ||
| builtins::{ | ||
| IterStatus::{Active, Exhausted}, | ||
| PositionIterInternal, PyGenericAlias, PyInt, PyStr, PyType, PyTypeRef, | ||
| PositionIterInternal, PyDict, PyGenericAlias, PyInt, PyStr, PyType, PyTypeRef, | ||
| }, | ||
| common::lock::{PyMutex, PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard}, | ||
| function::{KwArgs, OptionalArg, PyComparisonValue}, | ||
| convert::ToPyObject, | ||
| function::{FuncArgs, KwArgs, OptionalArg, PyComparisonValue}, | ||
| iter::PyExactSizeIterator, | ||
| protocol::{PyIterReturn, PyNumberMethods, PySequenceMethods}, | ||
| protocol::{PyIterReturn, PyMappingMethods, PyNumberMethods, PySequenceMethods}, | ||
| recursion::ReprGuard, | ||
| sequence::{MutObjectSequenceOp, OptionalRangeArgs}, | ||
| sliceable::SequenceIndexOp, | ||
| types::{ | ||
| AsNumber, AsSequence, Comparable, Constructor, DefaultConstructor, Initializer, | ||
| IterNext, Iterable, PyComparisonOp, Representable, SelfIter, | ||
| AsMapping, AsNumber, AsSequence, Comparable, Constructor, DefaultConstructor, | ||
| Initializer, IterNext, Iterable, PyComparisonOp, Representable, SelfIter, | ||
| }, | ||
| utils::collection_repr, | ||
| }; | ||
|
|
@@ -746,4 +747,200 @@ mod _collections { | |
| }) | ||
| } | ||
| } | ||
|
|
||
| #[pyattr] | ||
| #[pyclass(module = "collections", name = "defaultdict", base = PyDict, unhashable = true)] | ||
| #[derive(Debug, Default)] | ||
| // Rust-side state of the Python class registered as `collections.defaultdict`. | ||
| struct PyDefaultDict { | ||
|     // Storage used by `__missing__`, `copy`, `__or__`, `init` and `repr_str`. | ||
|     dict: PyDict, | ||
|     // Factory called by `__missing__`; while it is `None`, `__missing__` raises a key error. | ||
|     default_factory: PyRwLock<Option<PyObjectRef>>, | ||
| } | ||
|
|
||
| #[pyclass( | ||
|     with(AsMapping, AsNumber, Constructor, Initializer, Representable), | ||
|     flags(BASETYPE, MAPPING, HAS_DICT) | ||
| )] | ||
| impl PyDefaultDict { | ||
|     // Getter for the `default_factory` attribute: a clone of the stored factory, if any. | ||
|     #[pygetset] | ||
|     fn default_factory(&self) -> Option<PyObjectRef> { | ||
|         self.default_factory.read().clone() | ||
|     } | ||
|
|
||
|     // Setter for the `default_factory` attribute. The `None` singleton clears the factory; | ||
|     // any other object is stored as given (callability is not checked here). | ||
|     #[pygetset(name = "default_factory", setter)] | ||
|     fn default_factory_setter(&self, value: PyObjectRef, vm: &VirtualMachine) { | ||
|         *self.default_factory.write() = if value.is(&vm.ctx.none()) { | ||
|             None | ||
|         } else { | ||
|             Some(value) | ||
|         }; | ||
|     } | ||
|
|
||
|     // Produces the value for an absent `key`. | ||
|     // With a factory set: calls it with no arguments, stores the result under `key` and | ||
|     // returns that same result. Errors from the factory or from the store are propagated. | ||
|     // Without a factory: fails with a key error carrying `key`. | ||
|     #[pymethod] | ||
|     fn __missing__(&self, key: PyObjectRef, vm: &VirtualMachine) -> PyResult { | ||
|         // Clone in a separate statement so the read guard is released before the factory runs. | ||
|         let factory = self.default_factory.read().clone(); | ||
|         match factory { | ||
|             Some(factory) => { | ||
|                 let value = factory.call((), vm)?; | ||
|                 // Store unconditionally: `setdefault` would hand back an entry that the | ||
|                 // factory itself inserted for `key` instead of the value it returned. | ||
|                 self.dict.inner_setitem(&*key, value.clone(), vm)?; | ||
|                 Ok(value) | ||
|             } | ||
|             None => Err(vm.new_key_error(key)), | ||
|         } | ||
|     } | ||
|
|
||
|     // Exposed as both `copy` and `__copy__`: builds a new value from `self.dict.copy()` | ||
|     // and a fresh lock around a clone of the current factory reference. | ||
|     #[pymethod] | ||
|     #[pymethod(name = "__copy__")] | ||
|     fn copy(&self) -> Self { | ||
|         let default_factory = self.default_factory.read().clone(); | ||
|
|
||
|         Self { | ||
|             dict: self.dict.copy(), | ||
|             default_factory: PyRwLock::new(default_factory), | ||
|         } | ||
|     } | ||
|
|
||
|     // Returns the 5-tuple `(class, factory_args, None, None, items_iterator)`, where | ||
|     // `factory_args` is `(factory,)` when a factory is set and `()` otherwise, and | ||
|     // `items_iterator` iterates the result of calling the object's `items` attribute. | ||
|     #[pymethod] | ||
|     fn __reduce__(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult { | ||
|         let cls = zelf.class().to_owned(); | ||
|
|
||
|         let factory_tuple = match &*zelf.default_factory.read() { | ||
|             Some(f) => vm.ctx.new_tuple(vec![f.clone()]), | ||
|             None => vm.ctx.new_tuple(vec![]), | ||
|         }; | ||
|
|
||
|         // `items` is looked up through attribute access rather than read from `zelf.dict`. | ||
|         let items_fn = zelf.as_object().get_attr("items", vm)?; | ||
|         let items_iter = items_fn.call((), vm)?; | ||
|         let iter = items_iter.get_iter(vm)?; | ||
|         let none = vm.ctx.none(); | ||
|
|
||
|         Ok(vm | ||
|             .ctx | ||
|             .new_tuple(vec![ | ||
|                 cls.into(), | ||
|                 factory_tuple.into(), | ||
|                 none.clone(), | ||
|                 none, | ||
|                 iter.into(), | ||
|             ]) | ||
|             .into()) | ||
|     } | ||
| } | ||
|
|
||
| impl PyDefaultDict { | ||
| fn __or__(lhs: PyObjectRef, rhs: PyObjectRef, vm: &VirtualMachine) -> PyResult { | ||
| Ok(if let Some(zelf) = lhs.downcast_ref::<Self>() { | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
@bschoenmaeckers I'd love to get your review on this as well, if you don't mind. I had lots of issues implementing this correctly for some reason; I'm sure it can be simplified. |
||
| if !rhs.fast_isinstance(vm.ctx.types.dict_type) { | ||
| vm.ctx.not_implemented.clone().into() | ||
| } else { | ||
| let default_factory = zelf.default_factory.read().clone(); | ||
| let dict = zelf.dict.copy(); | ||
|
|
||
| dict.update(rhs.into(), KwArgs::default(), vm)?; | ||
|
|
||
| Self { | ||
| dict, | ||
| default_factory: PyRwLock::new(default_factory), | ||
| } | ||
| .to_pyobject(vm) | ||
| } | ||
| } else if let Some(zelf) = rhs.downcast_ref::<Self>() { | ||
| let default_factory = zelf.default_factory.read().clone(); | ||
| if let Some(dict) = lhs.downcast_ref::<PyDict>() { | ||
| let dict = dict.copy(); | ||
| dict.update(rhs.into(), KwArgs::default(), vm)?; | ||
|
|
||
| Self { | ||
| dict, | ||
| default_factory: PyRwLock::new(default_factory), | ||
| } | ||
| .to_pyobject(vm) | ||
| } else { | ||
| vm.ctx.not_implemented.clone().into() | ||
| } | ||
| } else { | ||
| return Err(vm.new_type_error(format!( | ||
| "unsupported operand type(s) for |: '{}' and '{}'", | ||
| lhs.class().name(), | ||
| rhs.class().name() | ||
| ))); | ||
| }) | ||
| } | ||
| } | ||
|
|
||
| // Empty `DefaultConstructor` impl: no items are overridden. NOTE(review): presumably this relies on the derived `Default` of `PyDefaultDict` - confirm against the trait. | ||
| impl DefaultConstructor for PyDefaultDict {} | ||
|
|
||
| impl Initializer for PyDefaultDict { | ||
| type Args = FuncArgs; | ||
|
|
||
| fn init(zelf: PyRef<Self>, mut args: Self::Args, vm: &VirtualMachine) -> PyResult<()> { | ||
| let default_factory = args.take_positional().map_or(Ok(None), |factory| { | ||
| let is_none = factory.is(&vm.ctx.none()); | ||
|
|
||
| if !is_none && !factory.is_callable() { | ||
| Err(vm.new_type_error("first argument must be callable or None")) | ||
| } else if is_none { | ||
| Ok(None) | ||
| } else { | ||
| Ok(Some(factory)) | ||
| } | ||
| })?; | ||
|
|
||
| *zelf.default_factory.write() = default_factory; | ||
|
|
||
| zelf.dict.update( | ||
| OptionalArg::from_option(args.take_positional()), | ||
| args.kwargs, | ||
| vm, | ||
| )?; | ||
|
Comment on lines
+878
to
+897
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Inspect how Initializer(FuncArgs) callers enforce full argument consumption.
# 1) Find other Initializer impls with FuncArgs and how they reject extra args.
rg -n -C3 'impl Initializer for .*{[[:space:][:print:]]*type Args = FuncArgs;' crates/vm/src
# 2) Look for explicit checks of leftover positional args after take_positional().
rg -n -C3 'take_positional\(\)|remaining|is_empty\(\)|too many positional|takes .* positional' crates/vm/src/stdlib crates/vm/src/builtins crates/vm/src/function
# 3) Inspect defaultdict init call path for automatic leftover-arg validation.
rg -n -C3 'trait Initializer|fn init\(|py_new\(|FuncArgs' crates/vm/src
Repository: RustPython/RustPython
Length of output: 50569
🏁 Script executed:
# Find the Initializer trait definition
rg -n 'trait Initializer' crates/vm/src --type rust -A 10
# Look at how py_new/init is called and if remaining args are validated
rg -n 'fn init\(' crates/vm/src/stdlib/_collections.rs -B 2 -A 15
# Search for other defaultdict-like Initializer impls that might show the pattern
rg -n 'impl Initializer' crates/vm/src/stdlib crates/vm/src/builtins --type rust -A 8 | head -100
Repository: RustPython/RustPython
Length of output: 8999
🏁 Script executed:
# See the complete defaultdict.__init__ implementation
sed -n '876,897p' crates/vm/src/stdlib/_collections.rs
# Check what happens in slot_init wrapper - does it validate remaining args?
sed -n '1647,1680p' crates/vm/src/types/slot.rs
# See how dict.update handles the args
rg -n 'fn update' crates/vm/src/stdlib/_collections.rs -A 20 | head -60
# Check if FuncArgs has validation or if take_positional enforces limits
rg -n 'fn take_positional|impl FuncArgs' crates/vm/src/function.rs -A 5 | head -40
Repository: RustPython/RustPython
Length of output: 2164
🏁 Script executed:
# Find FuncArgs definition and bind method
rg -n 'struct FuncArgs|impl FuncArgs' crates/vm/src --type rust -A 3 | head -50
# Find take_positional method definition
rg -n 'fn take_positional' crates/vm/src --type rust -B 2 -A 8
# Check the bind method that converts FuncArgs
rg -n 'fn bind' crates/vm/src --type rust -B 2 -A 10 | head -80
Repository: RustPython/RustPython
Length of output: 4589
🏁 Script executed:
# Check FromArgs implementation for FuncArgs
rg -n 'impl FromArgs for FuncArgs|fn from_args' crates/vm/src/function/argument.rs -A 15
# Also check if there's a FromArgs impl that validates
rg -n 'trait FromArgs' crates/vm/src -A 8 | head -40
Repository: RustPython/RustPython
Length of output: 5310
🏁 Script executed:
# Check if defaultdict uses slot_init or a custom init
rg -n 'fn slot_init|fn init' crates/vm/src/stdlib/_collections.rs -B 3 -A 1
# Also verify how slot_init is invoked - does it check leftover args AFTER init?
sed -n '1652,1690p' crates/vm/src/types/slot.rs
Repository: RustPython/RustPython
Length of output: 1909

`defaultdict.__init__` silently accepts extra positional arguments

The implementation does not reject positional arguments beyond the two it consumes (`default_factory` and the initial mapping/iterable). To match CPython's behavior, add a check after the second `take_positional()` call.

Example fix:

zelf.dict.update(
OptionalArg::from_option(args.take_positional()),
args.kwargs,
vm,
)?;
if !args.args.is_empty() {
return Err(vm.new_type_error(format!(
"defaultdict expected at most 2 positional arguments, got {}",
2 + args.args.len()
)));
}

🤖 Prompt for AI Agents |
||
|
|
||
| Ok(()) | ||
| } | ||
| } | ||
|
|
||
| impl Representable for PyDefaultDict { | ||
|     // Formats the object as `<class name>(<factory repr>, <dict repr>)`. | ||
|     // The factory part is "None" when no factory is set and "..." when `ReprGuard::enter` | ||
|     // declines the factory; the dict part is the repr of a copy of `zelf.dict`. | ||
|     fn repr_str(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<String> { | ||
|         // The read guard stays alive for the whole function, as in the original layout. | ||
|         let factory_slot = zelf.default_factory.read(); | ||
|
|
||
|         let factory_text = if let Some(factory) = factory_slot.as_ref() { | ||
|             match ReprGuard::enter(vm, factory) { | ||
|                 Some(_guard) => factory.repr(vm)?.to_string(), | ||
|                 None => String::from("..."), | ||
|             } | ||
|         } else { | ||
|             String::from("None") | ||
|         }; | ||
|
|
||
|         let snapshot = zelf.dict.copy().into_ref(&vm.ctx); | ||
|         let items_text = Representable::repr(&snapshot, vm)?; | ||
|
|
||
|         let class_name = zelf.class().name(); | ||
|         Ok(format!("{}({}, {})", class_name, factory_text, items_text)) | ||
|     } | ||
| } | ||
|
|
||
| // Mapping protocol slots: hands back whatever `PyDict::as_mapping()` returns, unchanged. | ||
| impl AsMapping for PyDefaultDict { | ||
| fn as_mapping() -> &'static PyMappingMethods { | ||
| PyDict::as_mapping() | ||
| } | ||
| } | ||
|
|
||
| impl AsNumber for PyDefaultDict { | ||
|     // Number protocol slots: only `or` is filled in, and it forwards both operands to | ||
|     // `PyDefaultDict::__or__`; every other slot comes from `PyNumberMethods::NOT_IMPLEMENTED`. | ||
|     fn as_number() -> &'static PyNumberMethods { | ||
|         static NUMBER_SLOTS: PyNumberMethods = PyNumberMethods { | ||
|             or: Some(|lhs, rhs, vm| { | ||
|                 let left = lhs.to_pyobject(vm); | ||
|                 let right = rhs.to_pyobject(vm); | ||
|                 PyDefaultDict::__or__(left, right, vm) | ||
|             }), | ||
|             ..PyNumberMethods::NOT_IMPLEMENTED | ||
|         }; | ||
|         &NUMBER_SLOTS | ||
|     } | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`__missing__` should assign unconditionally, not via `setdefault`

Line 788 uses `setdefault`, which can return a pre-inserted value if the factory mutates the dict re-entrantly. `defaultdict.__missing__` must store and return the factory result.

Proposed fix

 #[pymethod]
 fn __missing__(&self, key: PyObjectRef, vm: &VirtualMachine) -> PyResult {
     let factory = self.default_factory.read().clone();
     if let Some(f) = factory {
         let value = f.call((), vm)?;
-        self.dict.setdefault(key, value.into(), vm)
+        self.dict.inner_setitem(&*key, value.clone(), vm)?;
+        Ok(value)
     } else {
         Err(vm.new_key_error(key))
     }
 }

📝 Committable suggestion
🤖 Prompt for AI Agents