Skip to content

Commit 9fd4a09

Browse files
committed
defaultdict rust impl
1 parent fe2a7db commit 9fd4a09

4 files changed

Lines changed: 214 additions & 73 deletions

File tree

Lib/collections/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,7 @@
5757
try:
5858
from _collections import defaultdict
5959
except ImportError:
60-
# TODO: RUSTPYTHON - implement defaultdict in Rust
61-
from ._defaultdict import defaultdict
60+
pass
6261

6362
heapq = None # Lazily imported
6463

Lib/collections/_defaultdict.py

Lines changed: 0 additions & 62 deletions
This file was deleted.

crates/vm/src/builtins/dict.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ impl PyDict {
141141
self.entries.items()
142142
}
143143

144-
fn merge_object_with_override(
144+
pub(crate) fn merge_object_with_override(
145145
&self,
146146
other: PyObjectRef,
147147
override_existing: bool,
@@ -364,7 +364,7 @@ impl PyDict {
364364
self.entries.clear()
365365
}
366366

367-
fn __setitem__(
367+
pub(crate) fn __setitem__(
368368
&self,
369369
key: PyObjectRef,
370370
value: PyObjectRef,
@@ -387,7 +387,7 @@ impl PyDict {
387387
}
388388

389389
#[pymethod]
390-
fn setdefault(
390+
pub(crate) fn setdefault(
391391
&self,
392392
key: PyObjectRef,
393393
default: OptionalArg<PyObjectRef>,
@@ -406,7 +406,7 @@ impl PyDict {
406406
}
407407

408408
#[pymethod]
409-
fn update(
409+
pub fn update(
410410
&self,
411411
dict_obj: OptionalArg<PyObjectRef>,
412412
kwargs: KwArgs,

crates/vm/src/stdlib/_collections.rs

Lines changed: 209 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,19 @@ mod _collections {
77
atomic_func,
88
builtins::{
99
IterStatus::{Active, Exhausted},
10-
PositionIterInternal, PyGenericAlias, PyInt, PyStr, PyType, PyTypeRef,
10+
PositionIterInternal, PyDict, PyGenericAlias, PyInt, PyStr, PyType, PyTypeRef,
1111
},
1212
common::lock::{PyMutex, PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard},
13-
function::{KwArgs, OptionalArg, PyComparisonValue},
13+
convert::ToPyObject,
14+
function::{FuncArgs, KwArgs, OptionalArg, PyComparisonValue},
1415
iter::PyExactSizeIterator,
15-
protocol::{PyIterReturn, PyNumberMethods, PySequenceMethods},
16+
protocol::{PyIterReturn, PyMappingMethods, PyNumberMethods, PySequenceMethods},
1617
recursion::ReprGuard,
1718
sequence::{MutObjectSequenceOp, OptionalRangeArgs},
1819
sliceable::SequenceIndexOp,
1920
types::{
20-
AsNumber, AsSequence, Comparable, Constructor, DefaultConstructor, Initializer,
21-
IterNext, Iterable, PyComparisonOp, Representable, SelfIter,
21+
AsMapping, AsNumber, AsSequence, Comparable, Constructor, DefaultConstructor,
22+
Initializer, IterNext, Iterable, PyComparisonOp, Representable, SelfIter,
2223
},
2324
utils::collection_repr,
2425
};
@@ -746,4 +747,207 @@ mod _collections {
746747
})
747748
}
748749
}
750+
751+
#[pyattr]
752+
#[pyclass(
753+
module = "collections",
754+
name = "defaultdict",
755+
base = PyDict,
756+
unhashable = true
757+
)]
758+
#[derive(Debug, Default)]
759+
struct PyDefaultDict {
760+
dict: PyDict,
761+
default_factory: PyRwLock<Option<PyObjectRef>>,
762+
}
763+
764+
#[pyclass(
765+
with(
766+
AsMapping,
767+
AsNumber,
768+
Constructor,
769+
Initializer,
770+
Representable
771+
// Comparable,
772+
),
773+
flags(BASETYPE, MAPPING, HAS_DICT)
774+
)]
775+
impl PyDefaultDict {
776+
#[pygetset]
777+
fn default_factory(&self) -> Option<PyObjectRef> {
778+
self.default_factory.read().clone()
779+
}
780+
781+
#[pygetset(name = "default_factory", setter)]
782+
fn default_factory_setter(&self, value: PyObjectRef, vm: &VirtualMachine) {
783+
*self.default_factory.write() = if value.is(&vm.ctx.none()) {
784+
None
785+
} else {
786+
Some(value)
787+
};
788+
}
789+
790+
#[pymethod]
791+
fn __missing__(&self, key: PyObjectRef, vm: &VirtualMachine) -> PyResult {
792+
let factory = self.default_factory.read().clone();
793+
if let Some(f) = factory {
794+
let value = f.call((), vm)?;
795+
self.dict.setdefault(key, value.into(), vm)
796+
} else {
797+
Err(vm.new_key_error(key))
798+
}
799+
}
800+
801+
#[pymethod]
802+
#[pymethod(name = "__copy__")]
803+
fn copy(&self) -> Self {
804+
let default_factory = self.default_factory.read().clone();
805+
806+
Self {
807+
dict: self.dict.copy(),
808+
default_factory: PyRwLock::new(default_factory),
809+
}
810+
}
811+
812+
#[pymethod]
813+
fn __reduce__(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
814+
let cls = zelf.class().to_owned();
815+
816+
let factory_tuple = match &*zelf.default_factory.read() {
817+
Some(f) => vm.ctx.new_tuple(vec![f.clone()]),
818+
None => vm.ctx.new_tuple(vec![]),
819+
};
820+
821+
let items_fn = zelf.as_object().get_attr("items", vm)?;
822+
let items_iter = items_fn.call((), vm)?;
823+
let iter = items_iter.get_iter(vm)?;
824+
let none = vm.ctx.none();
825+
826+
Ok(vm
827+
.ctx
828+
.new_tuple(vec![
829+
cls.into(),
830+
factory_tuple.into(),
831+
none.clone(),
832+
none,
833+
iter.into(),
834+
])
835+
.into())
836+
}
837+
}
838+
839+
impl PyDefaultDict {
840+
fn __or__(lhs: PyObjectRef, rhs: PyObjectRef, vm: &VirtualMachine) -> PyResult {
841+
Ok(if let Some(zelf) = lhs.downcast_ref::<Self>() {
842+
if !rhs.fast_isinstance(vm.ctx.types.dict_type) {
843+
vm.ctx.not_implemented.clone().into()
844+
} else {
845+
let default_factory = zelf.default_factory.read().clone();
846+
let dict = zelf.dict.copy();
847+
848+
dict.update(rhs.into(), KwArgs::default(), vm)?;
849+
850+
Self {
851+
dict,
852+
default_factory: PyRwLock::new(default_factory),
853+
}
854+
.to_pyobject(vm)
855+
}
856+
} else if let Some(zelf) = rhs.downcast_ref::<Self>() {
857+
let default_factory = zelf.default_factory.read().clone();
858+
if let Some(dict) = lhs.downcast_ref::<PyDict>() {
859+
let dict = dict.copy();
860+
dict.update(rhs.into(), KwArgs::default(), vm)?;
861+
862+
Self {
863+
dict,
864+
default_factory: PyRwLock::new(default_factory),
865+
}
866+
.to_pyobject(vm)
867+
} else {
868+
vm.ctx.not_implemented.clone().into()
869+
}
870+
} else {
871+
return Err(vm.new_type_error(format!(
872+
"unsupported operand type(s) for |: '{}' and '{}'",
873+
lhs.class().name(),
874+
rhs.class().name()
875+
)));
876+
})
877+
}
878+
}
879+
880+
impl DefaultConstructor for PyDefaultDict {}
881+
882+
impl Initializer for PyDefaultDict {
883+
type Args = FuncArgs;
884+
885+
fn init(zelf: PyRef<Self>, mut args: Self::Args, vm: &VirtualMachine) -> PyResult<()> {
886+
let default_factory = args.take_positional().map_or(Ok(None), |factory| {
887+
let is_none = factory.is(&vm.ctx.none());
888+
889+
if !is_none && !factory.is_callable() {
890+
Err(vm.new_type_error("first argument must be callable or None"))
891+
} else if is_none {
892+
Ok(None)
893+
} else {
894+
Ok(Some(factory))
895+
}
896+
})?;
897+
898+
*zelf.default_factory.write() = default_factory;
899+
900+
zelf.dict.update(
901+
OptionalArg::from_option(args.take_positional()),
902+
args.kwargs,
903+
vm,
904+
)?;
905+
906+
Ok(())
907+
}
908+
}
909+
910+
impl Representable for PyDefaultDict {
911+
fn repr_str(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<String> {
912+
let default_factory = zelf.default_factory.read();
913+
914+
let factory_repr = match default_factory.as_ref() {
915+
Some(factory) => {
916+
if let Some(_guard) = ReprGuard::enter(vm, factory) {
917+
factory.repr(vm)?.to_string()
918+
} else {
919+
String::from("...")
920+
}
921+
}
922+
None => String::from("None"),
923+
};
924+
925+
let dict_repr = Representable::repr(&zelf.dict.copy().into_ref(&vm.ctx), vm)?;
926+
927+
Ok(format!(
928+
"{}({}, {})",
929+
zelf.class().name(),
930+
factory_repr,
931+
dict_repr
932+
))
933+
}
934+
}
935+
936+
impl AsMapping for PyDefaultDict {
937+
fn as_mapping() -> &'static PyMappingMethods {
938+
PyDict::as_mapping()
939+
}
940+
}
941+
942+
impl AsNumber for PyDefaultDict {
943+
fn as_number() -> &'static PyNumberMethods {
944+
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
945+
or: Some(|a, b, vm| {
946+
PyDefaultDict::__or__(a.to_pyobject(vm), b.to_pyobject(vm), vm)
947+
}),
948+
..PyNumberMethods::NOT_IMPLEMENTED
949+
};
950+
&AS_NUMBER
951+
}
952+
}
749953
}

0 commit comments

Comments
 (0)