Skip to content

Commit d2166dd

Browse files
authored
Fix sorted() to use __lt__ instead of __gt__ (#6887)
* test * Fix sorted() to use __lt__ instead of __gt__ CPython's sort uses __lt__ for comparisons, but RustPython was using __gt__. This caused issues when only __lt__ was overridden on a subclass (e.g., NamedTuple with custom __lt__), as it would fall back to the parent class's comparison instead of using the overridden method.
1 parent 44c3179 commit d2166dd

File tree

2 files changed

+51
-5
lines changed

2 files changed

+51
-5
lines changed

crates/vm/src/builtins/list.rs

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -522,12 +522,17 @@ fn do_sort(
522522
key_func: Option<PyObjectRef>,
523523
reverse: bool,
524524
) -> PyResult<()> {
525-
let op = if reverse {
526-
PyComparisonOp::Lt
527-
} else {
528-
PyComparisonOp::Gt
525+
// CPython uses __lt__ for all comparisons in sort.
526+
// try_sort_by_gt expects is_gt(a, b) = true when a should come AFTER b.
527+
let cmp = |a: &PyObjectRef, b: &PyObjectRef| {
528+
if reverse {
529+
// Descending: a comes after b when a < b
530+
a.rich_compare_bool(b, PyComparisonOp::Lt, vm)
531+
} else {
532+
// Ascending: a comes after b when b < a
533+
b.rich_compare_bool(a, PyComparisonOp::Lt, vm)
534+
}
529535
};
530-
let cmp = |a: &PyObjectRef, b: &PyObjectRef| a.rich_compare_bool(b, op, vm);
531536

532537
if let Some(ref key_func) = key_func {
533538
let mut items = values

extra_tests/snippets/builtin_list.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,47 @@ def __gt__(self, other):
270270
lst.sort(key=C)
271271
assert lst == [1, 2, 3, 4, 5]
272272

273+
# Test that sorted() uses __lt__ (not __gt__) for comparisons.
274+
# Track which comparison method is actually called during sort.
275+
class TrackComparison:
276+
lt_calls = 0
277+
gt_calls = 0
278+
279+
def __init__(self, value):
280+
self.value = value
281+
282+
def __lt__(self, other):
283+
TrackComparison.lt_calls += 1
284+
return self.value < other.value
285+
286+
def __gt__(self, other):
287+
TrackComparison.gt_calls += 1
288+
return self.value > other.value
289+
290+
# Reset and test sorted()
291+
TrackComparison.lt_calls = 0
292+
TrackComparison.gt_calls = 0
293+
items = [TrackComparison(3), TrackComparison(1), TrackComparison(2)]
294+
sorted(items)
295+
assert TrackComparison.lt_calls > 0, "sorted() should call __lt__"
296+
assert TrackComparison.gt_calls == 0, f"sorted() should not call __gt__, but it was called {TrackComparison.gt_calls} times"
297+
298+
# Reset and test list.sort()
299+
TrackComparison.lt_calls = 0
300+
TrackComparison.gt_calls = 0
301+
items = [TrackComparison(3), TrackComparison(1), TrackComparison(2)]
302+
items.sort()
303+
assert TrackComparison.lt_calls > 0, "list.sort() should call __lt__"
304+
assert TrackComparison.gt_calls == 0, f"list.sort() should not call __gt__, but it was called {TrackComparison.gt_calls} times"
305+
306+
# Reset and test sorted(reverse=True) - should still use __lt__, not __gt__
307+
TrackComparison.lt_calls = 0
308+
TrackComparison.gt_calls = 0
309+
items = [TrackComparison(3), TrackComparison(1), TrackComparison(2)]
310+
sorted(items, reverse=True)
311+
assert TrackComparison.lt_calls > 0, "sorted(reverse=True) should call __lt__"
312+
assert TrackComparison.gt_calls == 0, f"sorted(reverse=True) should not call __gt__, but it was called {TrackComparison.gt_calls} times"
313+
273314
lst = [5, 1, 2, 3, 4]
274315

275316

0 commit comments

Comments
 (0)