Skip to content

Commit 45da6b8

Browse files
committed
Add list.sort, sorted()
1 parent c33abe9 commit 45da6b8

File tree

4 files changed

+151
-7
lines changed

4 files changed

+151
-7
lines changed

tests/snippets/list.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,42 @@ def __eq__(self, x):
117117
assert a == b
118118

119119
assert [foo] == [foo]
120+
121+
for size in [1, 2, 3, 4, 5, 8, 10, 100, 1000]:
122+
lst = list(range(size))
123+
orig = lst[:]
124+
lst.sort()
125+
assert lst == orig
126+
assert sorted(lst) == orig
127+
assert_raises(ZeroDivisionError, lambda: sorted(lst, key=lambda x: 1/x))
128+
lst.reverse()
129+
assert sorted(lst) == orig
130+
assert sorted(lst, reverse=True) == lst
131+
assert sorted(lst, key=lambda x: -x) == lst
132+
assert sorted(lst, key=lambda x: -x, reverse=True) == orig
133+
134+
assert sorted([(1, 2, 3), (0, 3, 6)]) == [(0, 3, 6), (1, 2, 3)]
135+
assert sorted([(1, 2, 3), (0, 3, 6)], key=lambda x: x[0]) == [(0, 3, 6), (1, 2, 3)]
136+
assert sorted([(1, 2, 3), (0, 3, 6)], key=lambda x: x[1]) == [(1, 2, 3), (0, 3, 6)]
137+
assert sorted([(1, 2), (), (5,)], key=len) == [(), (5,), (1, 2)]
138+
139+
lst = [3, 1, 5, 2, 4]
140+
class C:
141+
def __init__(self, x): self.x = x
142+
def __lt__(self, other): return self.x < other.x
143+
lst.sort(key=C)
144+
assert lst == [1, 2, 3, 4, 5]
145+
146+
lst = [3, 1, 5, 2, 4]
147+
class C:
148+
def __init__(self, x): self.x = x
149+
def __gt__(self, other): return self.x > other.x
150+
lst.sort(key=C)
151+
assert lst == [1, 2, 3, 4, 5]
152+
153+
lst = [5, 1, 2, 3, 4]
154+
def f(x):
155+
lst.append(1)
156+
return x
157+
assert_raises(ValueError, lambda: lst.sort(key=f)) # "list modified during sort"
158+
assert lst == [1, 2, 3, 4, 5]

vm/src/builtins.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -689,7 +689,16 @@ fn builtin_setattr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
689689
}
690690

691691
// builtin_slice
692-
// builtin_sorted
692+
693+
fn builtin_sorted(vm: &mut VirtualMachine, mut args: PyFuncArgs) -> PyResult {
694+
arg_check!(vm, args, required = [(iterable, None)]);
695+
let items = vm.extract_elements(iterable)?;
696+
let lst = vm.ctx.new_list(items);
697+
698+
args.shift();
699+
vm.call_method_pyargs(&lst, "sort", args)?;
700+
Ok(lst)
701+
}
693702

694703
fn builtin_sum(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
695704
arg_check!(vm, args, required = [(iterable, None)]);
@@ -763,6 +772,7 @@ pub fn make_module(ctx: &PyContext) -> PyObjectRef {
763772
"round" => ctx.new_rustfunc(builtin_round),
764773
"set" => ctx.set_type(),
765774
"setattr" => ctx.new_rustfunc(builtin_setattr),
775+
"sorted" => ctx.new_rustfunc(builtin_sorted),
766776
"slice" => ctx.slice_type(),
767777
"staticmethod" => ctx.staticmethod_type(),
768778
"str" => ctx.str_type(),

vm/src/obj/objlist.rs

Lines changed: 93 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ use std::cell::{Cell, RefCell};
33
use super::objbool;
44
use super::objint;
55
use super::objsequence::{
6-
get_elements, get_item, get_mut_elements, seq_equal, seq_ge, seq_gt, seq_le, seq_lt, seq_mul,
7-
PySliceableSequence,
6+
get_elements, get_elements_cell, get_item, get_mut_elements, seq_equal, seq_ge, seq_gt, seq_le,
7+
seq_lt, seq_mul, PySliceableSequence,
88
};
99
use super::objstr;
1010
use super::objtype;
@@ -334,12 +334,99 @@ fn list_reverse(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
334334
Ok(vm.get_none())
335335
}
336336

337+
fn quicksort(
338+
vm: &mut VirtualMachine,
339+
keys: &mut [PyObjectRef],
340+
values: &mut [PyObjectRef],
341+
) -> PyResult<()> {
342+
let len = values.len();
343+
if len >= 2 {
344+
let pivot = partition(vm, keys, values)?;
345+
quicksort(vm, &mut keys[0..pivot], &mut values[0..pivot])?;
346+
quicksort(vm, &mut keys[pivot + 1..len], &mut values[pivot + 1..len])?;
347+
}
348+
Ok(())
349+
}
350+
351+
fn partition(
352+
vm: &mut VirtualMachine,
353+
keys: &mut [PyObjectRef],
354+
values: &mut [PyObjectRef],
355+
) -> PyResult<usize> {
356+
let len = values.len();
357+
let pivot = len / 2;
358+
359+
values.swap(pivot, len - 1);
360+
keys.swap(pivot, len - 1);
361+
362+
let mut store_idx = 0;
363+
for i in 0..len - 1 {
364+
let result = vm._lt(keys[i].clone(), keys[len - 1].clone())?;
365+
let boolval = objbool::boolval(vm, result)?;
366+
if boolval {
367+
values.swap(i, store_idx);
368+
keys.swap(i, store_idx);
369+
store_idx += 1;
370+
}
371+
}
372+
373+
values.swap(store_idx, len - 1);
374+
keys.swap(store_idx, len - 1);
375+
Ok(store_idx)
376+
}
377+
378+
fn do_sort(
379+
vm: &mut VirtualMachine,
380+
values: &mut Vec<PyObjectRef>,
381+
key_func: Option<PyObjectRef>,
382+
reverse: bool,
383+
) -> PyResult<()> {
384+
// build a list of keys. If no keyfunc is provided, it's a copy of the list.
385+
let mut keys: Vec<PyObjectRef> = vec![];
386+
for x in values.iter() {
387+
keys.push(match &key_func {
388+
None => x.clone(),
389+
Some(ref func) => vm.invoke(
390+
(*func).clone(),
391+
PyFuncArgs {
392+
args: vec![x.clone()],
393+
kwargs: vec![],
394+
},
395+
)?,
396+
});
397+
}
398+
399+
quicksort(vm, &mut keys, values)?;
400+
401+
if reverse {
402+
values.reverse();
403+
}
404+
405+
Ok(())
406+
}
407+
337408
fn list_sort(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
338409
arg_check!(vm, args, required = [(list, Some(vm.ctx.list_type()))]);
339-
let mut _elements = get_mut_elements(list);
340-
unimplemented!("TODO: figure out how to invoke `sort_by` on a Vec");
341-
// elements.sort_by();
342-
// Ok(vm.get_none())
410+
let key_func = args.get_optional_kwarg("key");
411+
let reverse = args.get_optional_kwarg("reverse");
412+
let reverse = match reverse {
413+
None => false,
414+
Some(val) => objbool::boolval(vm, val)?,
415+
};
416+
417+
let elements_cell = get_elements_cell(list);
418+
// replace list contents with [] for duration of sort.
419+
// this prevents keyfunc from messing with the list and makes it easy to
420+
// check if it tries to append elements to it.
421+
let mut elements = elements_cell.replace(vec![]);
422+
do_sort(vm, &mut elements, key_func, reverse)?;
423+
let temp_elements = elements_cell.replace(elements);
424+
425+
if !temp_elements.is_empty() {
426+
return Err(vm.new_value_error("list modified during sort".to_string()));
427+
}
428+
429+
Ok(vm.get_none())
343430
}
344431

345432
fn list_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {

vm/src/obj/objsequence.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,14 @@ pub fn seq_mul(elements: &[PyObjectRef], product: &PyObjectRef) -> Vec<PyObjectR
302302
new_elements
303303
}
304304

305+
pub fn get_elements_cell<'a>(obj: &'a PyObjectRef) -> &'a RefCell<Vec<PyObjectRef>> {
306+
if let PyObjectPayload::Sequence { ref elements } = obj.payload {
307+
elements
308+
} else {
309+
panic!("Cannot extract elements from non-sequence");
310+
}
311+
}
312+
305313
pub fn get_elements<'a>(obj: &'a PyObjectRef) -> impl Deref<Target = Vec<PyObjectRef>> + 'a {
306314
if let PyObjectPayload::Sequence { ref elements } = obj.payload {
307315
elements.borrow()

0 commit comments

Comments
 (0)