Skip to content

Commit 744e767

Browse files
committed
py: Make list.sort keep stack usage within O(log(N)) bound.
Also fix list.sort so it works with user-defined types, and parse the keyword arguments properly. Addresses issue adafruit#338.
1 parent ae3150c commit 744e767

2 files changed

Lines changed: 54 additions & 22 deletions

File tree

py/objlist.c

Lines changed: 33 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@
3131
#include "py/objlist.h"
3232
#include "py/runtime0.h"
3333
#include "py/runtime.h"
34+
#include "py/stackctrl.h"
3435

3536
STATIC mp_obj_t mp_obj_new_list_iterator(mp_obj_list_t *list, mp_uint_t cur);
3637
STATIC mp_obj_list_t *list_new(mp_uint_t n);
@@ -284,16 +285,15 @@ STATIC mp_obj_t list_pop(mp_uint_t n_args, const mp_obj_t *args) {
284285
return ret;
285286
}
286287

287-
// TODO make this conform to CPython's definition of sort
288-
STATIC void mp_quicksort(mp_obj_t *head, mp_obj_t *tail, mp_obj_t key_fn, bool reversed) {
289-
mp_uint_t op = reversed ? MP_BINARY_OP_MORE : MP_BINARY_OP_LESS;
288+
STATIC void mp_quicksort(mp_obj_t *head, mp_obj_t *tail, mp_obj_t key_fn, mp_obj_t binop_less_result) {
289+
MP_STACK_CHECK();
290290
while (head < tail) {
291291
mp_obj_t *h = head - 1;
292292
mp_obj_t *t = tail;
293-
mp_obj_t v = key_fn == NULL ? tail[0] : mp_call_function_1(key_fn, tail[0]); // get pivot using key_fn
293+
mp_obj_t v = key_fn == MP_OBJ_NULL ? tail[0] : mp_call_function_1(key_fn, tail[0]); // get pivot using key_fn
294294
for (;;) {
295-
do ++h; while (mp_binary_op(op, key_fn == NULL ? h[0] : mp_call_function_1(key_fn, h[0]), v) == mp_const_true);
296-
do --t; while (h < t && mp_binary_op(op, v, key_fn == NULL ? t[0] : mp_call_function_1(key_fn, t[0])) == mp_const_true);
295+
do ++h; while (h < t && mp_binary_op(MP_BINARY_OP_LESS, key_fn == MP_OBJ_NULL ? h[0] : mp_call_function_1(key_fn, h[0]), v) == binop_less_result);
296+
do --t; while (h < t && mp_binary_op(MP_BINARY_OP_LESS, v, key_fn == MP_OBJ_NULL ? t[0] : mp_call_function_1(key_fn, t[0])) == binop_less_result);
297297
if (h >= t) break;
298298
mp_obj_t x = h[0];
299299
h[0] = t[0];
@@ -302,27 +302,38 @@ STATIC void mp_quicksort(mp_obj_t *head, mp_obj_t *tail, mp_obj_t key_fn, bool r
302302
mp_obj_t x = h[0];
303303
h[0] = tail[0];
304304
tail[0] = x;
305-
mp_quicksort(head, t, key_fn, reversed);
306-
head = h + 1;
305+
// do the smaller recursive call first, to keep stack within O(log(N))
306+
if (t - head < tail - h - 1) {
307+
mp_quicksort(head, t, key_fn, binop_less_result);
308+
head = h + 1;
309+
} else {
310+
mp_quicksort(h + 1, tail, key_fn, binop_less_result);
311+
tail = t;
312+
}
307313
}
308314
}
309315

310-
mp_obj_t mp_obj_list_sort(mp_uint_t n_args, const mp_obj_t *args, mp_map_t *kwargs) {
311-
assert(n_args >= 1);
312-
assert(MP_OBJ_IS_TYPE(args[0], &mp_type_list));
313-
if (n_args > 1) {
314-
nlr_raise(mp_obj_new_exception_msg(&mp_type_TypeError,
315-
"list.sort takes no positional arguments"));
316-
}
317-
mp_obj_list_t *self = args[0];
316+
// TODO Python defines sort to be stable but ours is not
317+
mp_obj_t mp_obj_list_sort(mp_uint_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
318+
static const mp_arg_t allowed_args[] = {
319+
{ MP_QSTR_key, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_obj = mp_const_none} },
320+
{ MP_QSTR_reverse, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = false} },
321+
};
322+
323+
// parse args
324+
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
325+
mp_arg_parse_all(n_args - 1, pos_args + 1, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
326+
327+
mp_obj_list_t *self = pos_args[0];
328+
assert(MP_OBJ_IS_TYPE(self, &mp_type_list));
329+
318330
if (self->len > 1) {
319-
mp_map_elem_t *keyfun = mp_map_lookup(kwargs, MP_OBJ_NEW_QSTR(MP_QSTR_key), MP_MAP_LOOKUP);
320-
mp_map_elem_t *reverse = mp_map_lookup(kwargs, MP_OBJ_NEW_QSTR(MP_QSTR_reverse), MP_MAP_LOOKUP);
321331
mp_quicksort(self->items, self->items + self->len - 1,
322-
keyfun ? keyfun->value : NULL,
323-
reverse && reverse->value ? mp_obj_is_true(reverse->value) : false);
332+
args[0].u_obj == mp_const_none ? MP_OBJ_NULL : args[0].u_obj,
333+
args[1].u_bool ? mp_const_false : mp_const_true);
324334
}
325-
return mp_const_none; // return None, as per CPython
335+
336+
return mp_const_none;
326337
}
327338

328339
STATIC mp_obj_t list_clear(mp_obj_t self_in) {
@@ -412,7 +423,7 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_3(list_insert_obj, list_insert);
412423
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(list_pop_obj, 1, 2, list_pop);
413424
STATIC MP_DEFINE_CONST_FUN_OBJ_2(list_remove_obj, list_remove);
414425
STATIC MP_DEFINE_CONST_FUN_OBJ_1(list_reverse_obj, list_reverse);
415-
STATIC MP_DEFINE_CONST_FUN_OBJ_KW(list_sort_obj, 0, mp_obj_list_sort);
426+
STATIC MP_DEFINE_CONST_FUN_OBJ_KW(list_sort_obj, 1, mp_obj_list_sort);
416427

417428
STATIC const mp_map_elem_t list_locals_dict_table[] = {
418429
{ MP_OBJ_NEW_QSTR(MP_QSTR_append), (mp_obj_t)&list_append_obj },

tests/basics/list_sort.py

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -26,3 +26,24 @@
2626
print(l)
2727
print(l == sorted(l, reverse=False))
2828

29+
# test large lists (should not stack overflow)
30+
l = list(range(2000))
31+
l.sort()
32+
print(l[0], l[-1])
33+
l.sort(reverse=True)
34+
print(l[0], l[-1])
35+
36+
# test user-defined ordering
37+
class A:
38+
def __init__(self, x):
39+
self.x = x
40+
def __lt__(self, other):
41+
return self.x > other.x
42+
def __repr__(self):
43+
return str(self.x)
44+
l = [A(5), A(2), A(1), A(3), A(4)]
45+
print(l)
46+
l.sort()
47+
print(l)
48+
l.sort(reverse=True)
49+
print(l)

0 commit comments

Comments (0)