diff --git a/py/objlist.c b/py/objlist.c index 41c920511d2d4..4059a223e46ee 100644 --- a/py/objlist.c +++ b/py/objlist.c @@ -132,6 +132,9 @@ static mp_obj_t list_binary_op(mp_binary_op_t op, mp_obj_t lhs, mp_obj_t rhs) { if (n < 0) { n = 0; } + if (o->len != 0 && (size_t)n > SIZE_MAX / o->len) { + mp_raise_msg(&mp_type_OverflowError, MP_ERROR_TEXT("repeated list is too long")); + } mp_obj_list_t *s = list_new(o->len * n); mp_seq_multiply(o->items, sizeof(*o->items), o->len, n, s->items); return MP_OBJ_FROM_PTR(s); diff --git a/py/objstr.c b/py/objstr.c index 06afb91fc7f73..18740b96eb3b7 100644 --- a/py/objstr.c +++ b/py/objstr.c @@ -393,6 +393,9 @@ mp_obj_t mp_obj_str_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs_i if (n <= 0) { return make_empty_str_of_type(lhs_type); } + if (lhs_len != 0 && (size_t)n > SIZE_MAX / lhs_len) { + mp_raise_msg(&mp_type_OverflowError, MP_ERROR_TEXT("repeated string is too long")); + } vstr_t vstr; vstr_init_len(&vstr, lhs_len * n); mp_seq_multiply(lhs_data, sizeof(*lhs_data), lhs_len, n, vstr.buf); diff --git a/py/objtuple.c b/py/objtuple.c index c0ba1fd0d11e4..9a1798636b3f0 100644 --- a/py/objtuple.c +++ b/py/objtuple.c @@ -169,6 +169,9 @@ mp_obj_t mp_obj_tuple_binary_op(mp_binary_op_t op, mp_obj_t lhs, mp_obj_t rhs) { if (n <= 0) { return mp_const_empty_tuple; } + if (o->len != 0 && (size_t)n > SIZE_MAX / o->len) { + mp_raise_msg(&mp_type_OverflowError, MP_ERROR_TEXT("repeated tuple is too long")); + } mp_obj_tuple_t *s = MP_OBJ_TO_PTR(mp_obj_new_tuple(o->len * n, NULL)); mp_seq_multiply(o->items, sizeof(*o->items), o->len, n, s->items); return MP_OBJ_FROM_PTR(s); diff --git a/tests/basics/sequence_repeat_overflow.py b/tests/basics/sequence_repeat_overflow.py new file mode 100644 index 0000000000000..21f12e773f9c6 --- /dev/null +++ b/tests/basics/sequence_repeat_overflow.py @@ -0,0 +1,14 @@ +# Repeating a sequence by a huge count must raise, not overflow the size +# calculation and corrupt memory. +import sys + +# count chosen so that (length-8 sequence) * count wraps a size_t to a small value +word = 64 if sys.maxsize > 2**32 else 32 +count = (1 << (word - 3)) + 1 + +for x in (b"01234567", "01234567", [0, 1, 2, 3, 4, 5, 6, 7], (0, 1, 2, 3, 4, 5, 6, 7)): + try: + x * count + print("no exception") + except (OverflowError, MemoryError): + print("caught")