Skip to content

Commit 93ec575

Browse files
committed
Merge pull request RustPython#2341 from carbotanium/test
Change sum and round to take keywords
1 parent 7c32f8c commit 93ec575

2 files changed

Lines changed: 43 additions & 22 deletions

File tree

Lib/test/test_builtin.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import builtins
55
import collections
66
import decimal
7-
# import fractions XXX RustPython
7+
import fractions
88
import io
99
import locale
1010
import os
@@ -1200,8 +1200,6 @@ def test_repr(self):
12001200
a[0] = a
12011201
self.assertEqual(repr(a), '{0: {...}}')
12021202

1203-
# TODO: RUSTPYTHON
1204-
@unittest.expectedFailure
12051203
def test_round(self):
12061204
self.assertEqual(round(0.0), 0.0)
12071205
self.assertEqual(type(round(0.0)), int)
@@ -1297,8 +1295,6 @@ def test_round_large(self):
12971295
self.assertEqual(round(5e15+2), 5e15+2)
12981296
self.assertEqual(round(5e15+3), 5e15+3)
12991297

1300-
# TODO: RUSTPYTHON
1301-
@unittest.expectedFailure
13021298
def test_bug_27936(self):
13031299
# Verify that ndigits=None means the same as passing in no argument
13041300
for x in [1234,
@@ -1316,8 +1312,6 @@ def test_setattr(self):
13161312

13171313
# test_str(): see test_unicode.py and test_bytes.py for str() tests.
13181314

1319-
# TODO: RUSTPYTHON
1320-
@unittest.expectedFailure
13211315
def test_sum(self):
13221316
self.assertEqual(sum([]), 0)
13231317
self.assertEqual(sum(list(range(2,8))), 27)

vm/src/builtins/make_module.rs

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ mod decl {
2020
use crate::builtins::pystr::{PyStr, PyStrRef};
2121
use crate::builtins::pytype::PyTypeRef;
2222
use crate::builtins::PyInt;
23+
use crate::builtins::{PyByteArray, PyBytes};
2324
use crate::byteslike::PyBytesLike;
2425
use crate::common::{hash::PyHash, str::to_ascii};
2526
#[cfg(feature = "rustpython-compiler")]
@@ -707,21 +708,22 @@ mod decl {
707708
}
708709
}
709710

710-
#[pyfunction]
711-
fn round(
711+
#[derive(FromArgs)]
712+
pub struct RoundArgs {
713+
#[pyarg(any)]
712714
number: PyObjectRef,
713-
ndigits: OptionalArg<Option<PyIntRef>>,
714-
vm: &VirtualMachine,
715-
) -> PyResult {
716-
let rounded = match ndigits {
717-
OptionalArg::Present(ndigits) => match ndigits {
718-
Some(int) => {
719-
let ndigits = vm.call_method(int.as_object(), "__int__", ())?;
720-
vm.call_method(&number, "__round__", (ndigits,))?
721-
}
722-
None => vm.call_method(&number, "__round__", ())?,
723-
},
724-
OptionalArg::Missing => {
715+
#[pyarg(any, optional)]
716+
ndigits: OptionalOption<PyObjectRef>,
717+
}
718+
719+
#[pyfunction]
720+
fn round(RoundArgs { number, ndigits }: RoundArgs, vm: &VirtualMachine) -> PyResult {
721+
let rounded = match ndigits.flatten() {
722+
Some(obj) => {
723+
let ndigits = vm.to_index(&obj)?;
724+
vm.call_method(&number, "__round__", (ndigits,))?
725+
}
726+
None => {
725727
// without a parameter, the result type is coerced to int
726728
vm.call_method(&number, "__round__", ())?
727729
}
@@ -750,10 +752,35 @@ mod decl {
750752
Ok(lst)
751753
}
752754

755+
#[derive(FromArgs)]
756+
pub struct SumArgs {
757+
#[pyarg(positional)]
758+
iterable: PyIterable,
759+
#[pyarg(any, optional)]
760+
start: OptionalArg<PyObjectRef>,
761+
}
762+
753763
#[pyfunction]
754-
fn sum(iterable: PyIterable, start: OptionalArg, vm: &VirtualMachine) -> PyResult {
764+
fn sum(SumArgs { iterable, start }: SumArgs, vm: &VirtualMachine) -> PyResult {
755765
// Start with zero and add at will:
756766
let mut sum = start.into_option().unwrap_or_else(|| vm.ctx.new_int(0));
767+
768+
match_class!(match sum {
769+
PyStr =>
770+
return Err(vm.new_type_error(
771+
"sum() can't sum strings [use ''.join(seq) instead]".to_owned()
772+
)),
773+
PyBytes =>
774+
return Err(vm.new_type_error(
775+
"sum() can't sum bytes [use b''.join(seq) instead]".to_owned()
776+
)),
777+
PyByteArray =>
778+
return Err(vm.new_type_error(
779+
"sum() can't sum bytearray [use b''.join(seq) instead]".to_owned()
780+
)),
781+
_ => (),
782+
});
783+
757784
for item in iterable.iter(vm)? {
758785
sum = vm._add(&sum, &item?)?;
759786
}

0 commit comments

Comments
 (0)