Skip to content

Commit 4857fbb

Browse files
committed
defaultdict rust impl
1 parent fe2a7db commit 4857fbb

4 files changed

Lines changed: 205 additions & 71 deletions

File tree

Lib/collections/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,7 @@
5757
try:
5858
from _collections import defaultdict
5959
except ImportError:
60-
# TODO: RUSTPYTHON - implement defaultdict in Rust
61-
from ._defaultdict import defaultdict
60+
pass
6261

6362
heapq = None # Lazily imported
6463

Lib/collections/_defaultdict.py

Lines changed: 0 additions & 62 deletions
This file was deleted.

crates/vm/src/builtins/dict.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ impl PyDict {
387387
}
388388

389389
#[pymethod]
390-
fn setdefault(
390+
pub(crate) fn setdefault(
391391
&self,
392392
key: PyObjectRef,
393393
default: OptionalArg<PyObjectRef>,
@@ -406,7 +406,7 @@ impl PyDict {
406406
}
407407

408408
#[pymethod]
409-
fn update(
409+
pub(crate) fn update(
410410
&self,
411411
dict_obj: OptionalArg<PyObjectRef>,
412412
kwargs: KwArgs,

crates/vm/src/stdlib/_collections.rs

Lines changed: 202 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,19 @@ mod _collections {
77
atomic_func,
88
builtins::{
99
IterStatus::{Active, Exhausted},
10-
PositionIterInternal, PyGenericAlias, PyInt, PyStr, PyType, PyTypeRef,
10+
PositionIterInternal, PyDict, PyGenericAlias, PyInt, PyStr, PyType, PyTypeRef,
1111
},
1212
common::lock::{PyMutex, PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard},
13-
function::{KwArgs, OptionalArg, PyComparisonValue},
13+
convert::ToPyObject,
14+
function::{FuncArgs, KwArgs, OptionalArg, PyComparisonValue},
1415
iter::PyExactSizeIterator,
15-
protocol::{PyIterReturn, PyNumberMethods, PySequenceMethods},
16+
protocol::{PyIterReturn, PyMappingMethods, PyNumberMethods, PySequenceMethods},
1617
recursion::ReprGuard,
1718
sequence::{MutObjectSequenceOp, OptionalRangeArgs},
1819
sliceable::SequenceIndexOp,
1920
types::{
20-
AsNumber, AsSequence, Comparable, Constructor, DefaultConstructor, Initializer,
21-
IterNext, Iterable, PyComparisonOp, Representable, SelfIter,
21+
AsMapping, AsNumber, AsSequence, Comparable, Constructor, DefaultConstructor,
22+
Initializer, IterNext, Iterable, PyComparisonOp, Representable, SelfIter,
2223
},
2324
utils::collection_repr,
2425
};
@@ -746,4 +747,200 @@ mod _collections {
746747
})
747748
}
748749
}
750+
751+
#[pyattr]
752+
#[pyclass(
753+
module = "collections",
754+
name = "defaultdict",
755+
base = PyDict,
756+
unhashable = true
757+
)]
758+
#[derive(Debug, Default)]
759+
struct PyDefaultDict {
760+
dict: PyDict,
761+
default_factory: PyRwLock<Option<PyObjectRef>>,
762+
}
763+
764+
#[pyclass(
765+
with(AsMapping, AsNumber, Constructor, Initializer, Representable),
766+
flags(BASETYPE, MAPPING, HAS_DICT)
767+
)]
768+
impl PyDefaultDict {
769+
#[pygetset]
770+
fn default_factory(&self) -> Option<PyObjectRef> {
771+
self.default_factory.read().clone()
772+
}
773+
774+
#[pygetset(name = "default_factory", setter)]
775+
fn default_factory_setter(&self, value: PyObjectRef, vm: &VirtualMachine) {
776+
*self.default_factory.write() = if value.is(&vm.ctx.none()) {
777+
None
778+
} else {
779+
Some(value)
780+
};
781+
}
782+
783+
#[pymethod]
784+
fn __missing__(&self, key: PyObjectRef, vm: &VirtualMachine) -> PyResult {
785+
let factory = self.default_factory.read().clone();
786+
if let Some(f) = factory {
787+
let value = f.call((), vm)?;
788+
self.dict.setdefault(key, value.into(), vm)
789+
} else {
790+
Err(vm.new_key_error(key))
791+
}
792+
}
793+
794+
#[pymethod]
795+
#[pymethod(name = "__copy__")]
796+
fn copy(&self) -> Self {
797+
let default_factory = self.default_factory.read().clone();
798+
799+
Self {
800+
dict: self.dict.copy(),
801+
default_factory: PyRwLock::new(default_factory),
802+
}
803+
}
804+
805+
#[pymethod]
806+
fn __reduce__(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
807+
let cls = zelf.class().to_owned();
808+
809+
let factory_tuple = match &*zelf.default_factory.read() {
810+
Some(f) => vm.ctx.new_tuple(vec![f.clone()]),
811+
None => vm.ctx.new_tuple(vec![]),
812+
};
813+
814+
let items_fn = zelf.as_object().get_attr("items", vm)?;
815+
let items_iter = items_fn.call((), vm)?;
816+
let iter = items_iter.get_iter(vm)?;
817+
let none = vm.ctx.none();
818+
819+
Ok(vm
820+
.ctx
821+
.new_tuple(vec![
822+
cls.into(),
823+
factory_tuple.into(),
824+
none.clone(),
825+
none,
826+
iter.into(),
827+
])
828+
.into())
829+
}
830+
}
831+
832+
impl PyDefaultDict {
833+
fn __or__(lhs: PyObjectRef, rhs: PyObjectRef, vm: &VirtualMachine) -> PyResult {
834+
Ok(if let Some(zelf) = lhs.downcast_ref::<Self>() {
835+
if !rhs.fast_isinstance(vm.ctx.types.dict_type) {
836+
vm.ctx.not_implemented.clone().into()
837+
} else {
838+
let default_factory = zelf.default_factory.read().clone();
839+
let dict = zelf.dict.copy();
840+
841+
dict.update(rhs.into(), KwArgs::default(), vm)?;
842+
843+
Self {
844+
dict,
845+
default_factory: PyRwLock::new(default_factory),
846+
}
847+
.to_pyobject(vm)
848+
}
849+
} else if let Some(zelf) = rhs.downcast_ref::<Self>() {
850+
let default_factory = zelf.default_factory.read().clone();
851+
if let Some(dict) = lhs.downcast_ref::<PyDict>() {
852+
let dict = dict.copy();
853+
dict.update(rhs.into(), KwArgs::default(), vm)?;
854+
855+
Self {
856+
dict,
857+
default_factory: PyRwLock::new(default_factory),
858+
}
859+
.to_pyobject(vm)
860+
} else {
861+
vm.ctx.not_implemented.clone().into()
862+
}
863+
} else {
864+
return Err(vm.new_type_error(format!(
865+
"unsupported operand type(s) for |: '{}' and '{}'",
866+
lhs.class().name(),
867+
rhs.class().name()
868+
)));
869+
})
870+
}
871+
}
872+
873+
impl DefaultConstructor for PyDefaultDict {}
874+
875+
impl Initializer for PyDefaultDict {
876+
type Args = FuncArgs;
877+
878+
fn init(zelf: PyRef<Self>, mut args: Self::Args, vm: &VirtualMachine) -> PyResult<()> {
879+
let default_factory = args.take_positional().map_or(Ok(None), |factory| {
880+
let is_none = factory.is(&vm.ctx.none());
881+
882+
if !is_none && !factory.is_callable() {
883+
Err(vm.new_type_error("first argument must be callable or None"))
884+
} else if is_none {
885+
Ok(None)
886+
} else {
887+
Ok(Some(factory))
888+
}
889+
})?;
890+
891+
*zelf.default_factory.write() = default_factory;
892+
893+
zelf.dict.update(
894+
OptionalArg::from_option(args.take_positional()),
895+
args.kwargs,
896+
vm,
897+
)?;
898+
899+
Ok(())
900+
}
901+
}
902+
903+
impl Representable for PyDefaultDict {
904+
fn repr_str(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<String> {
905+
let default_factory = zelf.default_factory.read();
906+
907+
let factory_repr = match default_factory.as_ref() {
908+
Some(factory) => {
909+
if let Some(_guard) = ReprGuard::enter(vm, factory) {
910+
factory.repr(vm)?.to_string()
911+
} else {
912+
String::from("...")
913+
}
914+
}
915+
None => String::from("None"),
916+
};
917+
918+
let dict_repr = Representable::repr(&zelf.dict.copy().into_ref(&vm.ctx), vm)?;
919+
920+
Ok(format!(
921+
"{}({}, {})",
922+
zelf.class().name(),
923+
factory_repr,
924+
dict_repr
925+
))
926+
}
927+
}
928+
929+
impl AsMapping for PyDefaultDict {
930+
fn as_mapping() -> &'static PyMappingMethods {
931+
PyDict::as_mapping()
932+
}
933+
}
934+
935+
impl AsNumber for PyDefaultDict {
936+
fn as_number() -> &'static PyNumberMethods {
937+
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
938+
or: Some(|a, b, vm| {
939+
PyDefaultDict::__or__(a.to_pyobject(vm), b.to_pyobject(vm), vm)
940+
}),
941+
..PyNumberMethods::NOT_IMPLEMENTED
942+
};
943+
&AS_NUMBER
944+
}
945+
}
749946
}

0 commit comments

Comments
 (0)