@@ -20,6 +20,7 @@ mod decl {
2020 use crate :: builtins:: pystr:: { PyStr , PyStrRef } ;
2121 use crate :: builtins:: pytype:: PyTypeRef ;
2222 use crate :: builtins:: PyInt ;
23+ use crate :: builtins:: { PyByteArray , PyBytes } ;
2324 use crate :: byteslike:: PyBytesLike ;
2425 use crate :: common:: { hash:: PyHash , str:: to_ascii} ;
2526 #[ cfg( feature = "rustpython-compiler" ) ]
@@ -707,21 +708,22 @@ mod decl {
707708 }
708709 }
709710
710- #[ pyfunction]
711- fn round (
711+ #[ derive( FromArgs ) ]
712+ pub struct RoundArgs {
713+ #[ pyarg( any) ]
712714 number : PyObjectRef ,
713- ndigits : OptionalArg < Option < PyIntRef > > ,
714- vm : & VirtualMachine ,
715- ) -> PyResult {
716- let rounded = match ndigits {
717- OptionalArg :: Present ( ndigits ) => match ndigits {
718- Some ( int ) => {
719- let ndigits = vm . call_method ( int . as_object ( ) , "__int__" , ( ) ) ? ;
720- vm . call_method ( & number , "__round__" , ( ndigits , ) ) ?
721- }
722- None => vm. call_method ( & number, "__round__" , ( ) ) ?,
723- } ,
724- OptionalArg :: Missing => {
715+ # [ pyarg ( any , optional ) ]
716+ ndigits : OptionalOption < PyObjectRef > ,
717+ }
718+
719+ # [ pyfunction ]
720+ fn round ( RoundArgs { number , ndigits } : RoundArgs , vm : & VirtualMachine ) -> PyResult {
721+ let rounded = match ndigits . flatten ( ) {
722+ Some ( obj ) => {
723+ let ndigits = vm . to_index ( & obj ) ? ;
724+ vm. call_method ( & number, "__round__" , ( ndigits , ) ) ?
725+ }
726+ None => {
725727 // without a parameter, the result type is coerced to int
726728 vm. call_method ( & number, "__round__" , ( ) ) ?
727729 }
@@ -750,10 +752,35 @@ mod decl {
750752 Ok ( lst)
751753 }
752754
755+ #[ derive( FromArgs ) ]
756+ pub struct SumArgs {
757+ #[ pyarg( positional) ]
758+ iterable : PyIterable ,
759+ #[ pyarg( any, optional) ]
760+ start : OptionalArg < PyObjectRef > ,
761+ }
762+
753763 #[ pyfunction]
754- fn sum ( iterable : PyIterable , start : OptionalArg , vm : & VirtualMachine ) -> PyResult {
764+ fn sum ( SumArgs { iterable , start } : SumArgs , vm : & VirtualMachine ) -> PyResult {
755765 // Start with zero and add at will:
756766 let mut sum = start. into_option ( ) . unwrap_or_else ( || vm. ctx . new_int ( 0 ) ) ;
767+
768+ match_class ! ( match sum {
769+ PyStr =>
770+ return Err ( vm. new_type_error(
771+ "sum() can't sum strings [use ''.join(seq) instead]" . to_owned( )
772+ ) ) ,
773+ PyBytes =>
774+ return Err ( vm. new_type_error(
775+ "sum() can't sum bytes [use b''.join(seq) instead]" . to_owned( )
776+ ) ) ,
777+ PyByteArray =>
778+ return Err ( vm. new_type_error(
779+ "sum() can't sum bytearray [use b''.join(seq) instead]" . to_owned( )
780+ ) ) ,
781+ _ => ( ) ,
782+ } ) ;
783+
757784 for item in iterable. iter ( vm) ? {
758785 sum = vm. _add ( & sum, & item?) ?;
759786 }
0 commit comments