Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions Lib/collections/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@
try:
from _collections import defaultdict
except ImportError:
# TODO: RUSTPYTHON - implement defaultdict in Rust
from ._defaultdict import defaultdict
pass

heapq = None # Lazily imported

Expand Down
62 changes: 0 additions & 62 deletions Lib/collections/_defaultdict.py

This file was deleted.

2 changes: 0 additions & 2 deletions Lib/test/test_dataclasses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1795,7 +1795,6 @@ class C:
self.assertIsNot(d['f'], t)
self.assertEqual(d['f'].my_a(), 6)

@unittest.expectedFailure # TODO: RUSTPYTHON
def test_helper_asdict_defaultdict(self):
# Ensure asdict() does not throw exceptions when a
# defaultdict is a member of a dataclass
Expand Down Expand Up @@ -1938,7 +1937,6 @@ class C:
t = astuple(c, tuple_factory=list)
self.assertEqual(t, ['outer', T(1, ['inner', T(11, 12, 13)], 2)])

@unittest.expectedFailure # TODO: RUSTPYTHON
def test_helper_astuple_defaultdict(self):
# Ensure astuple() does not throw exceptions when a
# defaultdict is a member of a dataclass
Expand Down
4 changes: 2 additions & 2 deletions crates/vm/src/builtins/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ impl PyDict {
}

#[pymethod]
fn setdefault(
pub(crate) fn setdefault(
&self,
key: PyObjectRef,
default: OptionalArg<PyObjectRef>,
Expand All @@ -406,7 +406,7 @@ impl PyDict {
}

#[pymethod]
fn update(
pub(crate) fn update(
&self,
dict_obj: OptionalArg<PyObjectRef>,
kwargs: KwArgs,
Expand Down
207 changes: 202 additions & 5 deletions crates/vm/src/stdlib/_collections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,19 @@ mod _collections {
atomic_func,
builtins::{
IterStatus::{Active, Exhausted},
PositionIterInternal, PyGenericAlias, PyInt, PyStr, PyType, PyTypeRef,
PositionIterInternal, PyDict, PyGenericAlias, PyInt, PyStr, PyType, PyTypeRef,
},
common::lock::{PyMutex, PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard},
function::{KwArgs, OptionalArg, PyComparisonValue},
convert::ToPyObject,
function::{FuncArgs, KwArgs, OptionalArg, PyComparisonValue},
iter::PyExactSizeIterator,
protocol::{PyIterReturn, PyNumberMethods, PySequenceMethods},
protocol::{PyIterReturn, PyMappingMethods, PyNumberMethods, PySequenceMethods},
recursion::ReprGuard,
sequence::{MutObjectSequenceOp, OptionalRangeArgs},
sliceable::SequenceIndexOp,
types::{
AsNumber, AsSequence, Comparable, Constructor, DefaultConstructor, Initializer,
IterNext, Iterable, PyComparisonOp, Representable, SelfIter,
AsMapping, AsNumber, AsSequence, Comparable, Constructor, DefaultConstructor,
Initializer, IterNext, Iterable, PyComparisonOp, Representable, SelfIter,
},
utils::collection_repr,
};
Expand Down Expand Up @@ -746,4 +747,200 @@ mod _collections {
})
}
}

#[pyattr]
#[pyclass(
module = "collections",
name = "defaultdict",
base = PyDict,
unhashable = true
)]
#[derive(Debug, Default)]
struct PyDefaultDict {
dict: PyDict,
default_factory: PyRwLock<Option<PyObjectRef>>,
}

#[pyclass(
with(AsMapping, AsNumber, Constructor, Initializer, Representable),
flags(BASETYPE, MAPPING, HAS_DICT)
)]
impl PyDefaultDict {
#[pygetset]
fn default_factory(&self) -> Option<PyObjectRef> {
self.default_factory.read().clone()
}

#[pygetset(name = "default_factory", setter)]
fn default_factory_setter(&self, value: PyObjectRef, vm: &VirtualMachine) {
*self.default_factory.write() = if value.is(&vm.ctx.none()) {
None
} else {
Some(value)
};
}

#[pymethod]
fn __missing__(&self, key: PyObjectRef, vm: &VirtualMachine) -> PyResult {
let factory = self.default_factory.read().clone();
if let Some(f) = factory {
let value = f.call((), vm)?;
self.dict.setdefault(key, value.into(), vm)
} else {
Comment on lines +784 to +789

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

__missing__ should assign unconditionally, not via setdefault

Line 788 uses setdefault, which can return a pre-inserted value if the factory mutates the dict re-entrantly. defaultdict.__missing__ must store and return the factory result.

Proposed fix
         #[pymethod]
         fn __missing__(&self, key: PyObjectRef, vm: &VirtualMachine) -> PyResult {
             let factory = self.default_factory.read().clone();
             if let Some(f) = factory {
                 let value = f.call((), vm)?;
-                self.dict.setdefault(key, value.into(), vm)
+                self.dict.inner_setitem(&*key, value.clone(), vm)?;
+                Ok(value)
             } else {
                 Err(vm.new_key_error(key))
             }
         }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
fn __missing__(&self, key: PyObjectRef, vm: &VirtualMachine) -> PyResult {
let factory = self.default_factory.read().clone();
if let Some(f) = factory {
let value = f.call((), vm)?;
self.dict.setdefault(key, value.into(), vm)
} else {
fn __missing__(&self, key: PyObjectRef, vm: &VirtualMachine) -> PyResult {
let factory = self.default_factory.read().clone();
if let Some(f) = factory {
let value = f.call((), vm)?;
self.dict.inner_setitem(&*key, value.clone(), vm)?;
Ok(value)
} else {
Err(vm.new_key_error(key))
}
}
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@crates/vm/src/stdlib/_collections.rs` around lines 784 - 789, In the
`__missing__` method, the current implementation uses `setdefault` to store the
factory result, but this can return a pre-existing value if the factory function
mutates the dict during execution. Replace the `setdefault` call with a direct
assignment using the dict's set method instead, ensuring the factory result is
unconditionally stored and returned regardless of any re-entrant mutations that
occur during factory execution.

Err(vm.new_key_error(key))
}
}

#[pymethod]
#[pymethod(name = "__copy__")]
fn copy(&self) -> Self {
let default_factory = self.default_factory.read().clone();

Self {
dict: self.dict.copy(),
default_factory: PyRwLock::new(default_factory),
}
}

#[pymethod]
fn __reduce__(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
let cls = zelf.class().to_owned();

let factory_tuple = match &*zelf.default_factory.read() {
Some(f) => vm.ctx.new_tuple(vec![f.clone()]),
None => vm.ctx.new_tuple(vec![]),
};

let items_fn = zelf.as_object().get_attr("items", vm)?;
let items_iter = items_fn.call((), vm)?;
let iter = items_iter.get_iter(vm)?;
let none = vm.ctx.none();

Ok(vm
.ctx
.new_tuple(vec![
cls.into(),
factory_tuple.into(),
none.clone(),
none,
iter.into(),
])
.into())
}
}

impl PyDefaultDict {
fn __or__(lhs: PyObjectRef, rhs: PyObjectRef, vm: &VirtualMachine) -> PyResult {
Ok(if let Some(zelf) = lhs.downcast_ref::<Self>() {

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bschoenmaeckers I'd love to get your review on this as well if you don't mind. I had lots of issues with implementing this correctly for some reason, I'm sure it can be simplifiers

if !rhs.fast_isinstance(vm.ctx.types.dict_type) {
vm.ctx.not_implemented.clone().into()
} else {
let default_factory = zelf.default_factory.read().clone();
let dict = zelf.dict.copy();

dict.update(rhs.into(), KwArgs::default(), vm)?;

Self {
dict,
default_factory: PyRwLock::new(default_factory),
}
.to_pyobject(vm)
}
} else if let Some(zelf) = rhs.downcast_ref::<Self>() {
let default_factory = zelf.default_factory.read().clone();
if let Some(dict) = lhs.downcast_ref::<PyDict>() {
let dict = dict.copy();
dict.update(rhs.into(), KwArgs::default(), vm)?;

Self {
dict,
default_factory: PyRwLock::new(default_factory),
}
.to_pyobject(vm)
} else {
vm.ctx.not_implemented.clone().into()
}
} else {
return Err(vm.new_type_error(format!(
"unsupported operand type(s) for |: '{}' and '{}'",
lhs.class().name(),
rhs.class().name()
)));
})
}
}

impl DefaultConstructor for PyDefaultDict {}

impl Initializer for PyDefaultDict {
type Args = FuncArgs;

fn init(zelf: PyRef<Self>, mut args: Self::Args, vm: &VirtualMachine) -> PyResult<()> {
let default_factory = args.take_positional().map_or(Ok(None), |factory| {
let is_none = factory.is(&vm.ctx.none());

if !is_none && !factory.is_callable() {
Err(vm.new_type_error("first argument must be callable or None"))
} else if is_none {
Ok(None)
} else {
Ok(Some(factory))
}
})?;

*zelf.default_factory.write() = default_factory;

zelf.dict.update(
OptionalArg::from_option(args.take_positional()),
args.kwargs,
vm,
)?;
Comment on lines +878 to +897

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Inspect how Initializer(FuncArgs) callers enforce full argument consumption.
# 1) Find other Initializer impls with FuncArgs and how they reject extra args.
rg -n -C3 'impl Initializer for .*{[[:space:][:print:]]*type Args = FuncArgs;' crates/vm/src

# 2) Look for explicit checks of leftover positional args after take_positional().
rg -n -C3 'take_positional\(\)|remaining|is_empty\(\)|too many positional|takes .* positional' crates/vm/src/stdlib crates/vm/src/builtins crates/vm/src/function

# 3) Inspect defaultdict init call path for automatic leftover-arg validation.
rg -n -C3 'trait Initializer|fn init\(|py_new\(|FuncArgs' crates/vm/src

Repository: RustPython/RustPython

Length of output: 50569


🏁 Script executed:

# Find the Initializer trait definition
rg -n 'trait Initializer' crates/vm/src --type rust -A 10

# Look at how py_new/init is called and if remaining args are validated
rg -n 'fn init\(' crates/vm/src/stdlib/_collections.rs -B 2 -A 15

# Search for other defaultdict-like Initializer impls that might show the pattern
rg -n 'impl Initializer' crates/vm/src/stdlib crates/vm/src/builtins --type rust -A 8 | head -100

Repository: RustPython/RustPython

Length of output: 8999


🏁 Script executed:

# See the complete defaultdict.__init__ implementation
sed -n '876,897p' crates/vm/src/stdlib/_collections.rs

# Check what happens in slot_init wrapper - does it validate remaining args?
sed -n '1647,1680p' crates/vm/src/types/slot.rs

# See how dict.update handles the args
rg -n 'fn update' crates/vm/src/stdlib/_collections.rs -A 20 | head -60

# Check if FuncArgs has validation or if take_positional enforces limits
rg -n 'fn take_positional|impl FuncArgs' crates/vm/src/function.rs -A 5 | head -40

Repository: RustPython/RustPython

Length of output: 2164


🏁 Script executed:

# Find FuncArgs definition and bind method
rg -n 'struct FuncArgs|impl FuncArgs' crates/vm/src --type rust -A 3 | head -50

# Find take_positional method definition
rg -n 'fn take_positional' crates/vm/src --type rust -B 2 -A 8

# Check the bind method that converts FuncArgs
rg -n 'fn bind' crates/vm/src --type rust -B 2 -A 10 | head -80

Repository: RustPython/RustPython

Length of output: 4589


🏁 Script executed:

# Check FromArgs implementation for FuncArgs
rg -n 'impl FromArgs for FuncArgs|fn from_args' crates/vm/src/function/argument.rs -A 15

# Also check if there's a FromArgs impl that validates
rg -n 'trait FromArgs' crates/vm/src -A 8 | head -40

Repository: RustPython/RustPython

Length of output: 5310


🏁 Script executed:

# Check if defaultdict uses slot_init or a custom init
rg -n 'fn slot_init|fn init' crates/vm/src/stdlib/_collections.rs -B 3 -A 1

# Also verify how slot_init is invoked - does it check leftover args AFTER init?
sed -n '1652,1690p' crates/vm/src/types/slot.rs

Repository: RustPython/RustPython

Length of output: 1909


defaultdict.init silently accepts extra positional arguments

The implementation does not reject positional arguments beyond default_factory and the optional mapping/iterable. It calls take_positional() twice (lines 879, 894) but leaves any 3rd+ argument in the FuncArgs struct. Due to how FuncArgs::from_args consumes the entire struct in slot_init's binding step, the validation check at the framework level cannot catch these leftovers.

To match CPython's behavior, add a check after the second take_positional() to reject any remaining args:

Example fix
zelf.dict.update(
    OptionalArg::from_option(args.take_positional()),
    args.kwargs,
    vm,
)?;

if !args.args.is_empty() {
    return Err(vm.new_type_error(format!(
        "defaultdict expected at most 2 positional arguments, got {}",
        2 + args.args.len()
    )));
}
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@crates/vm/src/stdlib/_collections.rs` around lines 878 - 897, The defaultdict
init method does not validate that only the expected number of positional
arguments are provided. After the second take_positional() call in the init
method (which consumes the default_factory and the optional mapping), add a
validation check to ensure args.args is empty. If args.args is not empty, return
a TypeError with a message indicating the maximum number of positional arguments
expected (2) and the actual number provided. This will prevent the method from
silently accepting extra positional arguments beyond default_factory and the
optional mapping/iterable parameter.


Ok(())
}
}

impl Representable for PyDefaultDict {
fn repr_str(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<String> {
let default_factory = zelf.default_factory.read();

let factory_repr = match default_factory.as_ref() {
Some(factory) => {
if let Some(_guard) = ReprGuard::enter(vm, factory) {
factory.repr(vm)?.to_string()
} else {
String::from("...")
}
}
None => String::from("None"),
};

let dict_repr = Representable::repr(&zelf.dict.copy().into_ref(&vm.ctx), vm)?;

Ok(format!(
"{}({}, {})",
zelf.class().name(),
factory_repr,
dict_repr
))
}
}

impl AsMapping for PyDefaultDict {
fn as_mapping() -> &'static PyMappingMethods {
PyDict::as_mapping()
}
}

impl AsNumber for PyDefaultDict {
fn as_number() -> &'static PyNumberMethods {
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
or: Some(|a, b, vm| {
PyDefaultDict::__or__(a.to_pyobject(vm), b.to_pyobject(vm), vm)
}),
..PyNumberMethods::NOT_IMPLEMENTED
};
&AS_NUMBER
}
}
}
Loading