11use std:: cell:: { Cell , RefCell } ;
22use std:: cmp:: Ordering ;
33use std:: ops:: { AddAssign , SubAssign } ;
4+ use std:: rc:: Rc ;
45
56use num_bigint:: BigInt ;
67use num_traits:: ToPrimitive ;
@@ -10,9 +11,12 @@ use crate::obj::objbool;
1011use crate :: obj:: objint;
1112use crate :: obj:: objint:: { PyInt , PyIntRef } ;
1213use crate :: obj:: objiter:: { call_next, get_iter, new_stop_iteration} ;
14+ use crate :: obj:: objtuple:: PyTuple ;
1315use crate :: obj:: objtype;
1416use crate :: obj:: objtype:: PyClassRef ;
15- use crate :: pyobject:: { IdProtocol , PyCallable , PyClassImpl , PyObjectRef , PyRef , PyResult , PyValue } ;
17+ use crate :: pyobject:: {
18+ IdProtocol , PyCallable , PyClassImpl , PyObjectRef , PyRef , PyResult , PyValue , TypeProtocol ,
19+ } ;
1620use crate :: vm:: VirtualMachine ;
1721
1822#[ pyclass( name = "chain" ) ]
@@ -629,6 +633,114 @@ impl PyItertoolsAccumulate {
629633 }
630634}
631635
636+ #[ derive( Debug ) ]
637+ struct PyItertoolsTeeData {
638+ iterable : PyObjectRef ,
639+ values : RefCell < Vec < PyObjectRef > > ,
640+ }
641+
642+ impl PyItertoolsTeeData {
643+ fn new (
644+ iterable : PyObjectRef ,
645+ vm : & VirtualMachine ,
646+ ) -> Result < Rc < PyItertoolsTeeData > , PyObjectRef > {
647+ Ok ( Rc :: new ( PyItertoolsTeeData {
648+ iterable : get_iter ( vm, & iterable) ?,
649+ values : RefCell :: new ( vec ! [ ] ) ,
650+ } ) )
651+ }
652+
653+ fn get_item ( & self , vm : & VirtualMachine , index : usize ) -> PyResult {
654+ if self . values . borrow ( ) . len ( ) == index {
655+ let result = call_next ( vm, & self . iterable ) ?;
656+ self . values . borrow_mut ( ) . push ( result) ;
657+ }
658+ Ok ( self . values . borrow ( ) [ index] . clone ( ) )
659+ }
660+ }
661+
662+ #[ pyclass]
663+ #[ derive( Debug ) ]
664+ struct PyItertoolsTee {
665+ tee_data : Rc < PyItertoolsTeeData > ,
666+ index : Cell < usize > ,
667+ }
668+
669+ impl PyValue for PyItertoolsTee {
670+ fn class ( vm : & VirtualMachine ) -> PyClassRef {
671+ vm. class ( "itertools" , "tee" )
672+ }
673+ }
674+
675+ #[ pyimpl]
676+ impl PyItertoolsTee {
677+ fn from_iter ( iterable : PyObjectRef , vm : & VirtualMachine ) -> PyResult < PyObjectRef > {
678+ let it = get_iter ( vm, & iterable) ?;
679+ if it. class ( ) . is ( & PyItertoolsTee :: class ( vm) ) {
680+ return vm. call_method ( & it, "__copy__" , PyFuncArgs :: from ( vec ! [ ] ) ) ;
681+ }
682+ Ok ( PyItertoolsTee {
683+ tee_data : PyItertoolsTeeData :: new ( it, vm) ?,
684+ index : Cell :: from ( 0 ) ,
685+ }
686+ . into_ref_with_type ( vm, PyItertoolsTee :: class ( vm) ) ?
687+ . into_object ( ) )
688+ }
689+
690+ #[ pymethod( name = "__new__" ) ]
691+ #[ allow( clippy:: new_ret_no_self) ]
692+ fn new (
693+ _cls : PyClassRef ,
694+ iterable : PyObjectRef ,
695+ n : OptionalArg < PyIntRef > ,
696+ vm : & VirtualMachine ,
697+ ) -> PyResult < PyRef < PyTuple > > {
698+ let n = match n {
699+ OptionalArg :: Present ( x) => match x. as_bigint ( ) . to_usize ( ) {
700+ Some ( y) => y,
701+ None => return Err ( vm. new_overflow_error ( String :: from ( "n is too big" ) ) ) ,
702+ } ,
703+ OptionalArg :: Missing => 2 ,
704+ } ;
705+
706+ let copyable = if objtype:: class_has_attr ( & iterable. class ( ) , "__copy__" ) {
707+ vm. call_method ( & iterable, "__copy__" , PyFuncArgs :: from ( vec ! [ ] ) ) ?
708+ } else {
709+ PyItertoolsTee :: from_iter ( iterable, vm) ?
710+ } ;
711+
712+ let mut tee_vec: Vec < PyObjectRef > = Vec :: with_capacity ( n) ;
713+ for _ in 0 ..n {
714+ let no_args = PyFuncArgs :: from ( vec ! [ ] ) ;
715+ tee_vec. push ( vm. call_method ( & copyable, "__copy__" , no_args) ?) ;
716+ }
717+
718+ Ok ( PyTuple :: from ( tee_vec) . into_ref ( vm) )
719+ }
720+
721+ #[ pymethod( name = "__copy__" ) ]
722+ fn copy ( & self , vm : & VirtualMachine ) -> PyResult {
723+ Ok ( PyItertoolsTee {
724+ tee_data : Rc :: clone ( & self . tee_data ) ,
725+ index : self . index . clone ( ) ,
726+ }
727+ . into_ref_with_type ( vm, Self :: class ( vm) ) ?
728+ . into_object ( ) )
729+ }
730+
731+ #[ pymethod( name = "__next__" ) ]
732+ fn next ( & self , vm : & VirtualMachine ) -> PyResult {
733+ let value = self . tee_data . get_item ( vm, self . index . get ( ) ) ?;
734+ self . index . set ( self . index . get ( ) + 1 ) ;
735+ Ok ( value)
736+ }
737+
738+ #[ pymethod( name = "__iter__" ) ]
739+ fn iter ( zelf : PyRef < Self > , _vm : & VirtualMachine ) -> PyRef < Self > {
740+ zelf
741+ }
742+ }
743+
632744pub fn make_module ( vm : & VirtualMachine ) -> PyObjectRef {
633745 let ctx = & vm. ctx ;
634746
@@ -658,6 +770,9 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
658770 let accumulate = ctx. new_class ( "accumulate" , ctx. object ( ) ) ;
659771 PyItertoolsAccumulate :: extend_class ( ctx, & accumulate) ;
660772
773+ let tee = ctx. new_class ( "tee" , ctx. object ( ) ) ;
774+ PyItertoolsTee :: extend_class ( ctx, & tee) ;
775+
661776 py_module ! ( vm, "itertools" , {
662777 "chain" => chain,
663778 "compress" => compress,
@@ -669,5 +784,6 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
669784 "islice" => islice,
670785 "filterfalse" => filterfalse,
671786 "accumulate" => accumulate,
787+ "tee" => tee,
672788 } )
673789}
0 commit comments