@@ -7,18 +7,19 @@ mod _collections {
77 atomic_func,
88 builtins:: {
99 IterStatus :: { Active , Exhausted } ,
10- PositionIterInternal , PyGenericAlias , PyInt , PyStr , PyType , PyTypeRef ,
10+ PositionIterInternal , PyDict , PyGenericAlias , PyInt , PyStr , PyType , PyTypeRef ,
1111 } ,
1212 common:: lock:: { PyMutex , PyRwLock , PyRwLockReadGuard , PyRwLockWriteGuard } ,
13- function:: { KwArgs , OptionalArg , PyComparisonValue } ,
13+ convert:: ToPyObject ,
14+ function:: { FuncArgs , KwArgs , OptionalArg , PyComparisonValue } ,
1415 iter:: PyExactSizeIterator ,
15- protocol:: { PyIterReturn , PyNumberMethods , PySequenceMethods } ,
16+ protocol:: { PyIterReturn , PyMappingMethods , PyNumberMethods , PySequenceMethods } ,
1617 recursion:: ReprGuard ,
1718 sequence:: { MutObjectSequenceOp , OptionalRangeArgs } ,
1819 sliceable:: SequenceIndexOp ,
1920 types:: {
20- AsNumber , AsSequence , Comparable , Constructor , DefaultConstructor , Initializer ,
21- IterNext , Iterable , PyComparisonOp , Representable , SelfIter ,
21+ AsMapping , AsNumber , AsSequence , Comparable , Constructor , DefaultConstructor ,
22+ Initializer , IterNext , Iterable , PyComparisonOp , Representable , SelfIter ,
2223 } ,
2324 utils:: collection_repr,
2425 } ;
@@ -746,4 +747,207 @@ mod _collections {
746747 } )
747748 }
748749 }
750+
751+ #[ pyattr]
752+ #[ pyclass(
753+ module = "collections" ,
754+ name = "defaultdict" ,
755+ base = PyDict ,
756+ unhashable = true
757+ ) ]
758+ #[ derive( Debug , Default ) ]
759+ struct PyDefaultDict {
760+ dict : PyDict ,
761+ default_factory : PyRwLock < Option < PyObjectRef > > ,
762+ }
763+
764+ #[ pyclass(
765+ with(
766+ AsMapping ,
767+ AsNumber ,
768+ Constructor ,
769+ Initializer ,
770+ Representable
771+ // Comparable,
772+ ) ,
773+ flags( BASETYPE , MAPPING , HAS_DICT )
774+ ) ]
775+ impl PyDefaultDict {
776+ #[ pygetset]
777+ fn default_factory ( & self ) -> Option < PyObjectRef > {
778+ self . default_factory . read ( ) . clone ( )
779+ }
780+
781+ #[ pygetset( name = "default_factory" , setter) ]
782+ fn default_factory_setter ( & self , value : PyObjectRef , vm : & VirtualMachine ) {
783+ * self . default_factory . write ( ) = if value. is ( & vm. ctx . none ( ) ) {
784+ None
785+ } else {
786+ Some ( value)
787+ } ;
788+ }
789+
790+ #[ pymethod]
791+ fn __missing__ ( & self , key : PyObjectRef , vm : & VirtualMachine ) -> PyResult {
792+ let factory = self . default_factory . read ( ) . clone ( ) ;
793+ if let Some ( f) = factory {
794+ let value = f. call ( ( ) , vm) ?;
795+ self . dict . setdefault ( key, value. into ( ) , vm)
796+ } else {
797+ Err ( vm. new_key_error ( key) )
798+ }
799+ }
800+
801+ #[ pymethod]
802+ #[ pymethod( name = "__copy__" ) ]
803+ fn copy ( & self ) -> Self {
804+ let default_factory = self . default_factory . read ( ) . clone ( ) ;
805+
806+ Self {
807+ dict : self . dict . copy ( ) ,
808+ default_factory : PyRwLock :: new ( default_factory) ,
809+ }
810+ }
811+
812+ #[ pymethod]
813+ fn __reduce__ ( zelf : PyRef < Self > , vm : & VirtualMachine ) -> PyResult {
814+ let cls = zelf. class ( ) . to_owned ( ) ;
815+
816+ let factory_tuple = match & * zelf. default_factory . read ( ) {
817+ Some ( f) => vm. ctx . new_tuple ( vec ! [ f. clone( ) ] ) ,
818+ None => vm. ctx . new_tuple ( vec ! [ ] ) ,
819+ } ;
820+
821+ let items_fn = zelf. as_object ( ) . get_attr ( "items" , vm) ?;
822+ let items_iter = items_fn. call ( ( ) , vm) ?;
823+ let iter = items_iter. get_iter ( vm) ?;
824+ let none = vm. ctx . none ( ) ;
825+
826+ Ok ( vm
827+ . ctx
828+ . new_tuple ( vec ! [
829+ cls. into( ) ,
830+ factory_tuple. into( ) ,
831+ none. clone( ) ,
832+ none,
833+ iter. into( ) ,
834+ ] )
835+ . into ( ) )
836+ }
837+ }
838+
839+ impl PyDefaultDict {
840+ fn __or__ ( lhs : PyObjectRef , rhs : PyObjectRef , vm : & VirtualMachine ) -> PyResult {
841+ Ok ( if let Some ( zelf) = lhs. downcast_ref :: < Self > ( ) {
842+ if !rhs. fast_isinstance ( vm. ctx . types . dict_type ) {
843+ vm. ctx . not_implemented . clone ( ) . into ( )
844+ } else {
845+ let default_factory = zelf. default_factory . read ( ) . clone ( ) ;
846+ let dict = zelf. dict . copy ( ) ;
847+
848+ dict. update ( rhs. into ( ) , KwArgs :: default ( ) , vm) ?;
849+
850+ Self {
851+ dict,
852+ default_factory : PyRwLock :: new ( default_factory) ,
853+ }
854+ . to_pyobject ( vm)
855+ }
856+ } else if let Some ( zelf) = rhs. downcast_ref :: < Self > ( ) {
857+ let default_factory = zelf. default_factory . read ( ) . clone ( ) ;
858+ if let Some ( dict) = lhs. downcast_ref :: < PyDict > ( ) {
859+ let dict = dict. copy ( ) ;
860+ dict. update ( rhs. into ( ) , KwArgs :: default ( ) , vm) ?;
861+
862+ Self {
863+ dict,
864+ default_factory : PyRwLock :: new ( default_factory) ,
865+ }
866+ . to_pyobject ( vm)
867+ } else {
868+ vm. ctx . not_implemented . clone ( ) . into ( )
869+ }
870+ } else {
871+ return Err ( vm. new_type_error ( format ! (
872+ "unsupported operand type(s) for |: '{}' and '{}'" ,
873+ lhs. class( ) . name( ) ,
874+ rhs. class( ) . name( )
875+ ) ) ) ;
876+ } )
877+ }
878+ }
879+
880+ impl DefaultConstructor for PyDefaultDict { }
881+
882+ impl Initializer for PyDefaultDict {
883+ type Args = FuncArgs ;
884+
885+ fn init ( zelf : PyRef < Self > , mut args : Self :: Args , vm : & VirtualMachine ) -> PyResult < ( ) > {
886+ let default_factory = args. take_positional ( ) . map_or ( Ok ( None ) , |factory| {
887+ let is_none = factory. is ( & vm. ctx . none ( ) ) ;
888+
889+ if !is_none && !factory. is_callable ( ) {
890+ Err ( vm. new_type_error ( "first argument must be callable or None" ) )
891+ } else if is_none {
892+ Ok ( None )
893+ } else {
894+ Ok ( Some ( factory) )
895+ }
896+ } ) ?;
897+
898+ * zelf. default_factory . write ( ) = default_factory;
899+
900+ zelf. dict . update (
901+ OptionalArg :: from_option ( args. take_positional ( ) ) ,
902+ args. kwargs ,
903+ vm,
904+ ) ?;
905+
906+ Ok ( ( ) )
907+ }
908+ }
909+
910+ impl Representable for PyDefaultDict {
911+ fn repr_str ( zelf : & Py < Self > , vm : & VirtualMachine ) -> PyResult < String > {
912+ let default_factory = zelf. default_factory . read ( ) ;
913+
914+ let factory_repr = match default_factory. as_ref ( ) {
915+ Some ( factory) => {
916+ if let Some ( _guard) = ReprGuard :: enter ( vm, factory) {
917+ factory. repr ( vm) ?. to_string ( )
918+ } else {
919+ String :: from ( "..." )
920+ }
921+ }
922+ None => String :: from ( "None" ) ,
923+ } ;
924+
925+ let dict_repr = Representable :: repr ( & zelf. dict . copy ( ) . into_ref ( & vm. ctx ) , vm) ?;
926+
927+ Ok ( format ! (
928+ "{}({}, {})" ,
929+ zelf. class( ) . name( ) ,
930+ factory_repr,
931+ dict_repr
932+ ) )
933+ }
934+ }
935+
936+ impl AsMapping for PyDefaultDict {
937+ fn as_mapping ( ) -> & ' static PyMappingMethods {
938+ PyDict :: as_mapping ( )
939+ }
940+ }
941+
942+ impl AsNumber for PyDefaultDict {
943+ fn as_number ( ) -> & ' static PyNumberMethods {
944+ static AS_NUMBER : PyNumberMethods = PyNumberMethods {
945+ or : Some ( |a, b, vm| {
946+ PyDefaultDict :: __or__ ( a. to_pyobject ( vm) , b. to_pyobject ( vm) , vm)
947+ } ) ,
948+ ..PyNumberMethods :: NOT_IMPLEMENTED
949+ } ;
950+ & AS_NUMBER
951+ }
952+ }
749953}
0 commit comments