11use std:: cell:: { Cell , RefCell } ;
22use std:: cmp:: Ordering ;
3+ use std:: iter;
34use std:: ops:: { AddAssign , SubAssign } ;
45use std:: rc:: Rc ;
56
67use num_bigint:: BigInt ;
78use num_traits:: ToPrimitive ;
89
9- use crate :: function:: { OptionalArg , PyFuncArgs } ;
10+ use crate :: function:: { Args , OptionalArg , PyFuncArgs } ;
1011use crate :: obj:: objbool;
1112use crate :: obj:: objint:: { self , PyInt , PyIntRef } ;
12- use crate :: obj:: objiter:: { call_next, get_iter, new_stop_iteration} ;
13+ use crate :: obj:: objiter:: { call_next, get_all , get_iter, new_stop_iteration} ;
1314use crate :: obj:: objtuple:: PyTuple ;
1415use crate :: obj:: objtype:: { self , PyClassRef } ;
1516use crate :: pyobject:: {
@@ -736,6 +737,123 @@ impl PyItertoolsTee {
736737 }
737738}
738739
740+ #[ pyclass]
741+ #[ derive( Debug ) ]
742+ struct PyIterToolsProduct {
743+ pools : Vec < Vec < PyObjectRef > > ,
744+ idxs : RefCell < Vec < usize > > ,
745+ cur : Cell < usize > ,
746+ stop : Cell < bool > ,
747+ }
748+
749+ impl PyValue for PyIterToolsProduct {
750+ fn class ( vm : & VirtualMachine ) -> PyClassRef {
751+ vm. class ( "itertools" , "product" )
752+ }
753+ }
754+
755+ #[ derive( FromArgs ) ]
756+ struct ProductArgs {
757+ #[ pyarg( keyword_only, optional = true ) ]
758+ repeat : OptionalArg < usize > ,
759+ }
760+
761+ #[ pyimpl]
762+ impl PyIterToolsProduct {
763+ #[ pyslot( new) ]
764+ fn tp_new (
765+ cls : PyClassRef ,
766+ iterables : Args < PyObjectRef > ,
767+ args : ProductArgs ,
768+ vm : & VirtualMachine ,
769+ ) -> PyResult < PyRef < Self > > {
770+ let repeat = match args. repeat . into_option ( ) {
771+ Some ( i) => i,
772+ None => 1 ,
773+ } ;
774+
775+ let mut pools = Vec :: new ( ) ;
776+ for arg in iterables. into_iter ( ) {
777+ let it = get_iter ( vm, & arg) ?;
778+ let pool = get_all ( vm, & it) ?;
779+
780+ pools. push ( pool) ;
781+ }
782+ let pools = iter:: repeat ( pools)
783+ . take ( repeat)
784+ . flatten ( )
785+ . collect :: < Vec < Vec < PyObjectRef > > > ( ) ;
786+
787+ let l = pools. len ( ) ;
788+
789+ PyIterToolsProduct {
790+ pools,
791+ idxs : RefCell :: new ( vec ! [ 0 ; l] ) ,
792+ cur : Cell :: new ( l - 1 ) ,
793+ stop : Cell :: new ( false ) ,
794+ }
795+ . into_ref_with_type ( vm, cls)
796+ }
797+
798+ #[ pymethod( name = "__next__" ) ]
799+ fn next ( & self , vm : & VirtualMachine ) -> PyResult {
800+ // stop signal
801+ if self . stop . get ( ) {
802+ return Err ( new_stop_iteration ( vm) ) ;
803+ }
804+
805+ let pools = & self . pools ;
806+
807+ for p in pools {
808+ if p. is_empty ( ) {
809+ return Err ( new_stop_iteration ( vm) ) ;
810+ }
811+ }
812+
813+ let res = PyTuple :: from (
814+ pools
815+ . iter ( )
816+ . zip ( self . idxs . borrow ( ) . iter ( ) )
817+ . map ( |( pool, idx) | pool[ * idx] . clone ( ) )
818+ . collect :: < Vec < PyObjectRef > > ( ) ,
819+ ) ;
820+
821+ self . update_idxs ( ) ;
822+
823+ if self . is_end ( ) {
824+ self . stop . set ( true ) ;
825+ }
826+
827+ Ok ( res. into_ref ( vm) . into_object ( ) )
828+ }
829+
830+ fn is_end ( & self ) -> bool {
831+ ( self . idxs . borrow ( ) [ self . cur . get ( ) ] == & self . pools [ self . cur . get ( ) ] . len ( ) - 1
832+ && self . cur . get ( ) == 0 )
833+ }
834+
835+ fn update_idxs ( & self ) {
836+ let lst_idx = & self . pools [ self . cur . get ( ) ] . len ( ) - 1 ;
837+
838+ if self . idxs . borrow ( ) [ self . cur . get ( ) ] == lst_idx {
839+ if self . is_end ( ) {
840+ return ;
841+ }
842+ self . idxs . borrow_mut ( ) [ self . cur . get ( ) ] = 0 ;
843+ self . cur . set ( self . cur . get ( ) - 1 ) ;
844+ self . update_idxs ( ) ;
845+ } else {
846+ self . idxs . borrow_mut ( ) [ self . cur . get ( ) ] += 1 ;
847+ self . cur . set ( self . idxs . borrow ( ) . len ( ) - 1 ) ;
848+ }
849+ }
850+
851+ #[ pymethod( name = "__iter__" ) ]
852+ fn iter ( zelf : PyRef < Self > , _vm : & VirtualMachine ) -> PyRef < Self > {
853+ zelf
854+ }
855+ }
856+
739857pub fn make_module ( vm : & VirtualMachine ) -> PyObjectRef {
740858 let ctx = & vm. ctx ;
741859
@@ -767,6 +885,8 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
767885
768886 let tee = ctx. new_class ( "tee" , ctx. object ( ) ) ;
769887 PyItertoolsTee :: extend_class ( ctx, & tee) ;
888+ let product = ctx. new_class ( "product" , ctx. object ( ) ) ;
889+ PyIterToolsProduct :: extend_class ( ctx, & product) ;
770890
771891 py_module ! ( vm, "itertools" , {
772892 "chain" => chain,
@@ -780,5 +900,6 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
780900 "filterfalse" => filterfalse,
781901 "accumulate" => accumulate,
782902 "tee" => tee,
903+ "product" => product,
783904 } )
784905}
0 commit comments