@@ -7,18 +7,19 @@ mod _collections {
77 atomic_func,
88 builtins:: {
99 IterStatus :: { Active , Exhausted } ,
10- PositionIterInternal , PyGenericAlias , PyInt , PyStr , PyType , PyTypeRef ,
10+ PositionIterInternal , PyDict , PyGenericAlias , PyInt , PyStr , PyType , PyTypeRef ,
1111 } ,
1212 common:: lock:: { PyMutex , PyRwLock , PyRwLockReadGuard , PyRwLockWriteGuard } ,
13- function:: { KwArgs , OptionalArg , PyComparisonValue } ,
13+ convert:: ToPyObject ,
14+ function:: { FuncArgs , KwArgs , OptionalArg , PyComparisonValue } ,
1415 iter:: PyExactSizeIterator ,
15- protocol:: { PyIterReturn , PyNumberMethods , PySequenceMethods } ,
16+ protocol:: { PyIterReturn , PyMappingMethods , PyNumberMethods , PySequenceMethods } ,
1617 recursion:: ReprGuard ,
1718 sequence:: { MutObjectSequenceOp , OptionalRangeArgs } ,
1819 sliceable:: SequenceIndexOp ,
1920 types:: {
20- AsNumber , AsSequence , Comparable , Constructor , DefaultConstructor , Initializer ,
21- IterNext , Iterable , PyComparisonOp , Representable , SelfIter ,
21+ AsMapping , AsNumber , AsSequence , Comparable , Constructor , DefaultConstructor ,
22+ Initializer , IterNext , Iterable , PyComparisonOp , Representable , SelfIter ,
2223 } ,
2324 utils:: collection_repr,
2425 } ;
@@ -746,4 +747,200 @@ mod _collections {
746747 } )
747748 }
748749 }
750+
751+ #[ pyattr]
752+ #[ pyclass(
753+ module = "collections" ,
754+ name = "defaultdict" ,
755+ base = PyDict ,
756+ unhashable = true
757+ ) ]
758+ #[ derive( Debug , Default ) ]
759+ struct PyDefaultDict {
760+ dict : PyDict ,
761+ default_factory : PyRwLock < Option < PyObjectRef > > ,
762+ }
763+
764+ #[ pyclass(
765+ with( AsMapping , AsNumber , Constructor , Initializer , Representable ) ,
766+ flags( BASETYPE , MAPPING , HAS_DICT )
767+ ) ]
768+ impl PyDefaultDict {
769+ #[ pygetset]
770+ fn default_factory ( & self ) -> Option < PyObjectRef > {
771+ self . default_factory . read ( ) . clone ( )
772+ }
773+
774+ #[ pygetset( name = "default_factory" , setter) ]
775+ fn default_factory_setter ( & self , value : PyObjectRef , vm : & VirtualMachine ) {
776+ * self . default_factory . write ( ) = if value. is ( & vm. ctx . none ( ) ) {
777+ None
778+ } else {
779+ Some ( value)
780+ } ;
781+ }
782+
783+ #[ pymethod]
784+ fn __missing__ ( & self , key : PyObjectRef , vm : & VirtualMachine ) -> PyResult {
785+ let factory = self . default_factory . read ( ) . clone ( ) ;
786+ if let Some ( f) = factory {
787+ let value = f. call ( ( ) , vm) ?;
788+ self . dict . setdefault ( key, value. into ( ) , vm)
789+ } else {
790+ Err ( vm. new_key_error ( key) )
791+ }
792+ }
793+
794+ #[ pymethod]
795+ #[ pymethod( name = "__copy__" ) ]
796+ fn copy ( & self ) -> Self {
797+ let default_factory = self . default_factory . read ( ) . clone ( ) ;
798+
799+ Self {
800+ dict : self . dict . copy ( ) ,
801+ default_factory : PyRwLock :: new ( default_factory) ,
802+ }
803+ }
804+
805+ #[ pymethod]
806+ fn __reduce__ ( zelf : PyRef < Self > , vm : & VirtualMachine ) -> PyResult {
807+ let cls = zelf. class ( ) . to_owned ( ) ;
808+
809+ let factory_tuple = match & * zelf. default_factory . read ( ) {
810+ Some ( f) => vm. ctx . new_tuple ( vec ! [ f. clone( ) ] ) ,
811+ None => vm. ctx . new_tuple ( vec ! [ ] ) ,
812+ } ;
813+
814+ let items_fn = zelf. as_object ( ) . get_attr ( "items" , vm) ?;
815+ let items_iter = items_fn. call ( ( ) , vm) ?;
816+ let iter = items_iter. get_iter ( vm) ?;
817+ let none = vm. ctx . none ( ) ;
818+
819+ Ok ( vm
820+ . ctx
821+ . new_tuple ( vec ! [
822+ cls. into( ) ,
823+ factory_tuple. into( ) ,
824+ none. clone( ) ,
825+ none,
826+ iter. into( ) ,
827+ ] )
828+ . into ( ) )
829+ }
830+ }
831+
832+ impl PyDefaultDict {
833+ fn __or__ ( lhs : PyObjectRef , rhs : PyObjectRef , vm : & VirtualMachine ) -> PyResult {
834+ Ok ( if let Some ( zelf) = lhs. downcast_ref :: < Self > ( ) {
835+ if !rhs. fast_isinstance ( vm. ctx . types . dict_type ) {
836+ vm. ctx . not_implemented . clone ( ) . into ( )
837+ } else {
838+ let default_factory = zelf. default_factory . read ( ) . clone ( ) ;
839+ let dict = zelf. dict . copy ( ) ;
840+
841+ dict. update ( rhs. into ( ) , KwArgs :: default ( ) , vm) ?;
842+
843+ Self {
844+ dict,
845+ default_factory : PyRwLock :: new ( default_factory) ,
846+ }
847+ . to_pyobject ( vm)
848+ }
849+ } else if let Some ( zelf) = rhs. downcast_ref :: < Self > ( ) {
850+ let default_factory = zelf. default_factory . read ( ) . clone ( ) ;
851+ if let Some ( dict) = lhs. downcast_ref :: < PyDict > ( ) {
852+ let dict = dict. copy ( ) ;
853+ dict. update ( rhs. into ( ) , KwArgs :: default ( ) , vm) ?;
854+
855+ Self {
856+ dict,
857+ default_factory : PyRwLock :: new ( default_factory) ,
858+ }
859+ . to_pyobject ( vm)
860+ } else {
861+ vm. ctx . not_implemented . clone ( ) . into ( )
862+ }
863+ } else {
864+ return Err ( vm. new_type_error ( format ! (
865+ "unsupported operand type(s) for |: '{}' and '{}'" ,
866+ lhs. class( ) . name( ) ,
867+ rhs. class( ) . name( )
868+ ) ) ) ;
869+ } )
870+ }
871+ }
872+
873+ impl DefaultConstructor for PyDefaultDict { }
874+
875+ impl Initializer for PyDefaultDict {
876+ type Args = FuncArgs ;
877+
878+ fn init ( zelf : PyRef < Self > , mut args : Self :: Args , vm : & VirtualMachine ) -> PyResult < ( ) > {
879+ let default_factory = args. take_positional ( ) . map_or ( Ok ( None ) , |factory| {
880+ let is_none = factory. is ( & vm. ctx . none ( ) ) ;
881+
882+ if !is_none && !factory. is_callable ( ) {
883+ Err ( vm. new_type_error ( "first argument must be callable or None" ) )
884+ } else if is_none {
885+ Ok ( None )
886+ } else {
887+ Ok ( Some ( factory) )
888+ }
889+ } ) ?;
890+
891+ * zelf. default_factory . write ( ) = default_factory;
892+
893+ zelf. dict . update (
894+ OptionalArg :: from_option ( args. take_positional ( ) ) ,
895+ args. kwargs ,
896+ vm,
897+ ) ?;
898+
899+ Ok ( ( ) )
900+ }
901+ }
902+
903+ impl Representable for PyDefaultDict {
904+ fn repr_str ( zelf : & Py < Self > , vm : & VirtualMachine ) -> PyResult < String > {
905+ let default_factory = zelf. default_factory . read ( ) ;
906+
907+ let factory_repr = match default_factory. as_ref ( ) {
908+ Some ( factory) => {
909+ if let Some ( _guard) = ReprGuard :: enter ( vm, factory) {
910+ factory. repr ( vm) ?. to_string ( )
911+ } else {
912+ String :: from ( "..." )
913+ }
914+ }
915+ None => String :: from ( "None" ) ,
916+ } ;
917+
918+ let dict_repr = Representable :: repr ( & zelf. dict . copy ( ) . into_ref ( & vm. ctx ) , vm) ?;
919+
920+ Ok ( format ! (
921+ "{}({}, {})" ,
922+ zelf. class( ) . name( ) ,
923+ factory_repr,
924+ dict_repr
925+ ) )
926+ }
927+ }
928+
929+ impl AsMapping for PyDefaultDict {
930+ fn as_mapping ( ) -> & ' static PyMappingMethods {
931+ PyDict :: as_mapping ( )
932+ }
933+ }
934+
935+ impl AsNumber for PyDefaultDict {
936+ fn as_number ( ) -> & ' static PyNumberMethods {
937+ static AS_NUMBER : PyNumberMethods = PyNumberMethods {
938+ or : Some ( |a, b, vm| {
939+ PyDefaultDict :: __or__ ( a. to_pyobject ( vm) , b. to_pyobject ( vm) , vm)
940+ } ) ,
941+ ..PyNumberMethods :: NOT_IMPLEMENTED
942+ } ;
943+ & AS_NUMBER
944+ }
945+ }
749946}
0 commit comments