Skip to content

Commit 779794a

Browse files
committed
py: Add dispatch for user defined ==, >, <=, >=.
Addresses issue adafruit#827.
1 parent fa1a9bc commit 779794a

3 files changed

Lines changed: 41 additions & 5 deletions

File tree

py/objtype.c

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -381,11 +381,12 @@ STATIC const qstr binary_op_method_name[] = {
381381
MP_BINARY_OP_INPLACE_MODULO,
382382
MP_BINARY_OP_INPLACE_POWER,*/
383383
[MP_BINARY_OP_LESS] = MP_QSTR___lt__,
384-
/*MP_BINARY_OP_MORE,
385-
MP_BINARY_OP_EQUAL,
386-
MP_BINARY_OP_LESS_EQUAL,
387-
MP_BINARY_OP_MORE_EQUAL,
388-
MP_BINARY_OP_NOT_EQUAL,
384+
[MP_BINARY_OP_MORE] = MP_QSTR___gt__,
385+
[MP_BINARY_OP_EQUAL] = MP_QSTR___eq__,
386+
[MP_BINARY_OP_LESS_EQUAL] = MP_QSTR___le__,
387+
[MP_BINARY_OP_MORE_EQUAL] = MP_QSTR___ge__,
388+
/*
389+
MP_BINARY_OP_NOT_EQUAL, // a != b calls a == b and inverts result
389390
*/
390391
[MP_BINARY_OP_IN] = MP_QSTR___contains__,
391392
/*

py/qstrdefs.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ Q(__getattr__)
6464
Q(__del__)
6565
Q(__call__)
6666
Q(__lt__)
67+
Q(__gt__)
68+
Q(__eq__)
69+
Q(__le__)
70+
Q(__ge__)
6771

6872
Q(micropython)
6973
Q(bytecode)

tests/basics/class_binop.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
class foo(object):
2+
def __init__(self, value):
3+
self.x = value
4+
5+
def __eq__(self, other):
6+
print('eq')
7+
return self.x == other.x
8+
9+
def __lt__(self, other):
10+
print('lt')
11+
return self.x < other.x
12+
13+
def __gt__(self, other):
14+
print('gt')
15+
return self.x > other.x
16+
17+
def __le__(self, other):
18+
print('le')
19+
return self.x <= other.x
20+
21+
def __ge__(self, other):
22+
print('ge')
23+
return self.x >= other.x
24+
25+
for i in range(3):
26+
for j in range(3):
27+
print(foo(i) == foo(j))
28+
print(foo(i) < foo(j))
29+
print(foo(i) > foo(j))
30+
print(foo(i) <= foo(j))
31+
print(foo(i) >= foo(j))

0 commit comments

Comments
 (0)