Skip to content

Commit af01eb8

Browse files
committed
Basic defaultdict impl
1 parent fe2a7db commit af01eb8

4 files changed

Lines changed: 214 additions & 72 deletions

File tree

Lib/collections/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,7 @@
5757
try:
5858
from _collections import defaultdict
5959
except ImportError:
60-
# TODO: RUSTPYTHON - implement defaultdict in Rust
61-
from ._defaultdict import defaultdict
60+
pass
6261

6362
heapq = None # Lazily imported
6463

Lib/collections/_defaultdict.py

Lines changed: 0 additions & 62 deletions
This file was deleted.

crates/vm/src/builtins/dict.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ impl PyDict {
141141
self.entries.items()
142142
}
143143

144-
fn merge_object_with_override(
144+
pub(crate) fn merge_object_with_override(
145145
&self,
146146
other: PyObjectRef,
147147
override_existing: bool,
@@ -364,7 +364,7 @@ impl PyDict {
364364
self.entries.clear()
365365
}
366366

367-
fn __setitem__(
367+
pub(crate) fn __setitem__(
368368
&self,
369369
key: PyObjectRef,
370370
value: PyObjectRef,
@@ -406,7 +406,7 @@ impl PyDict {
406406
}
407407

408408
#[pymethod]
409-
fn update(
409+
pub fn update(
410410
&self,
411411
dict_obj: OptionalArg<PyObjectRef>,
412412
kwargs: KwArgs,

crates/vm/src/stdlib/_collections.rs

Lines changed: 210 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,19 @@ mod _collections {
77
atomic_func,
88
builtins::{
99
IterStatus::{Active, Exhausted},
10-
PositionIterInternal, PyGenericAlias, PyInt, PyStr, PyType, PyTypeRef,
10+
PositionIterInternal, PyDict, PyGenericAlias, PyInt, PyStr, PyType, PyTypeRef,
1111
},
1212
common::lock::{PyMutex, PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard},
13-
function::{KwArgs, OptionalArg, PyComparisonValue},
13+
convert::ToPyObject,
14+
function::{FuncArgs, KwArgs, OptionalArg, PyComparisonValue},
1415
iter::PyExactSizeIterator,
15-
protocol::{PyIterReturn, PyNumberMethods, PySequenceMethods},
16+
protocol::{PyIterReturn, PyMappingMethods, PyNumberMethods, PySequenceMethods},
1617
recursion::ReprGuard,
1718
sequence::{MutObjectSequenceOp, OptionalRangeArgs},
1819
sliceable::SequenceIndexOp,
1920
types::{
20-
AsNumber, AsSequence, Comparable, Constructor, DefaultConstructor, Initializer,
21-
IterNext, Iterable, PyComparisonOp, Representable, SelfIter,
21+
AsMapping, AsNumber, AsSequence, Comparable, Constructor, DefaultConstructor,
22+
Initializer, IterNext, Iterable, PyComparisonOp, Representable, SelfIter,
2223
},
2324
utils::collection_repr,
2425
};
@@ -746,4 +747,208 @@ mod _collections {
746747
})
747748
}
748749
}
750+
751+
#[pyattr]
752+
#[pyclass(
753+
module = "collections",
754+
name = "defaultdict",
755+
base = PyDict,
756+
unhashable = true
757+
)]
758+
#[derive(Debug, Default)]
759+
struct PyDefaultDict {
760+
dict: PyDict,
761+
default_factory: PyRwLock<Option<PyObjectRef>>,
762+
}
763+
764+
#[pyclass(
765+
with(
766+
AsMapping,
767+
AsNumber,
768+
Constructor,
769+
Initializer,
770+
Representable
771+
// Comparable,
772+
),
773+
flags(BASETYPE, MAPPING, HAS_DICT)
774+
)]
775+
impl PyDefaultDict {
776+
#[pygetset]
777+
fn default_factory(&self) -> Option<PyObjectRef> {
778+
self.default_factory.read().clone()
779+
}
780+
781+
#[pygetset(name = "default_factory", setter)]
782+
fn default_factory_setter(&self, value: PyObjectRef, vm: &VirtualMachine) {
783+
*self.default_factory.write() = if value.is(&vm.ctx.none()) {
784+
None
785+
} else {
786+
Some(value)
787+
};
788+
}
789+
790+
#[pymethod]
791+
fn __missing__(&self, key: PyObjectRef, vm: &VirtualMachine) -> PyResult {
792+
let factory = self.default_factory.read().clone();
793+
if let Some(f) = factory {
794+
let value = f.call((), vm)?;
795+
self.dict.__setitem__(key, value.clone(), vm)?;
796+
Ok(value)
797+
} else {
798+
Err(vm.new_key_error(key))
799+
}
800+
}
801+
802+
#[pymethod]
803+
#[pymethod(name = "__copy__")]
804+
fn copy(&self) -> Self {
805+
let default_factory = self.default_factory.read().clone();
806+
807+
Self {
808+
dict: self.dict.copy(),
809+
default_factory: PyRwLock::new(default_factory),
810+
}
811+
}
812+
813+
#[pymethod]
814+
fn __reduce__(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
815+
let cls = zelf.class().to_owned();
816+
817+
let factory_tuple = match &*zelf.default_factory.read() {
818+
Some(f) => vm.ctx.new_tuple(vec![f.clone()]),
819+
None => vm.ctx.new_tuple(vec![]),
820+
};
821+
822+
let items_fn = zelf.as_object().get_attr("items", vm)?;
823+
let items_iter = items_fn.call((), vm)?;
824+
let iter = items_iter.get_iter(vm)?;
825+
let none = vm.ctx.none();
826+
827+
Ok(vm
828+
.ctx
829+
.new_tuple(vec![
830+
cls.into(),
831+
factory_tuple.into(),
832+
none.clone(),
833+
none,
834+
iter.into(),
835+
])
836+
.into())
837+
}
838+
}
839+
840+
impl PyDefaultDict {
841+
fn __or__(lhs: PyObjectRef, rhs: PyObjectRef, vm: &VirtualMachine) -> PyResult {
842+
Ok(if let Some(zelf) = lhs.downcast_ref::<Self>() {
843+
if !rhs.fast_isinstance(vm.ctx.types.dict_type) {
844+
vm.ctx.not_implemented.clone().into()
845+
} else {
846+
let default_factory = zelf.default_factory.read().clone();
847+
let dict = zelf.dict.copy();
848+
849+
dict.update(rhs.into(), KwArgs::default(), vm)?;
850+
851+
Self {
852+
dict,
853+
default_factory: PyRwLock::new(default_factory),
854+
}
855+
.to_pyobject(vm)
856+
}
857+
} else if let Some(zelf) = rhs.downcast_ref::<Self>() {
858+
let default_factory = zelf.default_factory.read().clone();
859+
if let Some(dict) = lhs.downcast_ref::<PyDict>() {
860+
let dict = dict.copy();
861+
dict.update(rhs.into(), KwArgs::default(), vm)?;
862+
863+
Self {
864+
dict,
865+
default_factory: PyRwLock::new(default_factory),
866+
}
867+
.to_pyobject(vm)
868+
} else {
869+
vm.ctx.not_implemented.clone().into()
870+
}
871+
} else {
872+
return Err(vm.new_type_error(format!(
873+
"unsupported operand type(s) for |: '{}' and '{}'",
874+
lhs.class().name(),
875+
rhs.class().name()
876+
)));
877+
})
878+
}
879+
}
880+
881+
impl DefaultConstructor for PyDefaultDict {}
882+
883+
impl Initializer for PyDefaultDict {
884+
type Args = FuncArgs;
885+
886+
fn init(zelf: PyRef<Self>, mut args: Self::Args, vm: &VirtualMachine) -> PyResult<()> {
887+
let default_factory = args.take_positional().map_or(Ok(None), |factory| {
888+
let is_none = factory.is(&vm.ctx.none());
889+
890+
if !is_none && !factory.is_callable() {
891+
Err(vm.new_type_error("first argument must be callable or None"))
892+
} else if is_none {
893+
Ok(None)
894+
} else {
895+
Ok(Some(factory))
896+
}
897+
})?;
898+
899+
*zelf.default_factory.write() = default_factory;
900+
901+
zelf.dict.update(
902+
OptionalArg::from_option(args.take_positional()),
903+
args.kwargs,
904+
vm,
905+
)?;
906+
907+
Ok(())
908+
}
909+
}
910+
911+
impl Representable for PyDefaultDict {
912+
fn repr_str(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<String> {
913+
let default_factory = zelf.default_factory.read();
914+
915+
let factory_repr = match default_factory.as_ref() {
916+
Some(factory) => {
917+
if let Some(_guard) = ReprGuard::enter(vm, factory) {
918+
factory.repr(vm)?.to_string()
919+
} else {
920+
String::from("...")
921+
}
922+
}
923+
None => String::from("None"),
924+
};
925+
926+
let dict_repr = Representable::repr(&zelf.dict.copy().into_ref(&vm.ctx), vm)?;
927+
928+
Ok(format!(
929+
"{}({}, {})",
930+
zelf.class().name(),
931+
factory_repr,
932+
dict_repr
933+
))
934+
}
935+
}
936+
937+
impl AsMapping for PyDefaultDict {
938+
fn as_mapping() -> &'static PyMappingMethods {
939+
PyDict::as_mapping()
940+
}
941+
}
942+
943+
impl AsNumber for PyDefaultDict {
944+
fn as_number() -> &'static PyNumberMethods {
945+
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
946+
or: Some(|a, b, vm| {
947+
PyDefaultDict::__or__(a.to_pyobject(vm), b.to_pyobject(vm), vm)
948+
}),
949+
..PyNumberMethods::NOT_IMPLEMENTED
950+
};
951+
&AS_NUMBER
952+
}
953+
}
749954
}

0 commit comments

Comments
 (0)