Skip to content

Commit 7b0f9a7

Browse files
committed
bytes: Implement comparison and other binary operations.
Should support everything supported by strings.
1 parent 070c78a commit 7b0f9a7

3 files changed

Lines changed: 65 additions & 9 deletions

File tree

py/objstr.c

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -251,11 +251,13 @@ STATIC const byte *find_subbytes(const byte *haystack, machine_uint_t hlen, cons
251251

252252
STATIC mp_obj_t str_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
253253
GET_STR_DATA_LEN(lhs_in, lhs_data, lhs_len);
254+
mp_obj_type_t *lhs_type = mp_obj_get_type(lhs_in);
255+
mp_obj_type_t *rhs_type = mp_obj_get_type(rhs_in);
254256
switch (op) {
255257
case MP_BINARY_OP_ADD:
256258
case MP_BINARY_OP_INPLACE_ADD:
257-
if (MP_OBJ_IS_STR(rhs_in)) {
258-
// add 2 strings
259+
if (lhs_type == rhs_type) {
260+
// add 2 strings or bytes
259261

260262
GET_STR_DATA_LEN(rhs_in, rhs_data, rhs_len);
261263
int alloc_len = lhs_len + rhs_len;
@@ -270,7 +272,7 @@ STATIC mp_obj_t str_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
270272

271273
// code for non-qstr
272274
byte *data;
273-
mp_obj_t s = mp_obj_str_builder_start(mp_obj_get_type(lhs_in), alloc_len, &data);
275+
mp_obj_t s = mp_obj_str_builder_start(lhs_type, alloc_len, &data);
274276
memcpy(data, lhs_data, lhs_len);
275277
memcpy(data + lhs_len, rhs_data, rhs_len);
276278
return mp_obj_str_builder_end(s);
@@ -279,7 +281,7 @@ STATIC mp_obj_t str_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
279281

280282
case MP_BINARY_OP_IN:
281283
/* NOTE `a in b` is `b.__contains__(a)` */
282-
if (MP_OBJ_IS_STR(rhs_in)) {
284+
if (lhs_type == rhs_type) {
283285
GET_STR_DATA_LEN(rhs_in, rhs_data, rhs_len);
284286
return MP_BOOL(find_subbytes(lhs_data, lhs_len, rhs_data, rhs_len, 1) != NULL);
285287
}
@@ -292,7 +294,7 @@ STATIC mp_obj_t str_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
292294
}
293295
int n = MP_OBJ_SMALL_INT_VALUE(rhs_in);
294296
byte *data;
295-
mp_obj_t s = mp_obj_str_builder_start(mp_obj_get_type(lhs_in), lhs_len * n, &data);
297+
mp_obj_t s = mp_obj_str_builder_start(lhs_type, lhs_len * n, &data);
296298
mp_seq_multiply(lhs_data, sizeof(*lhs_data), lhs_len, n, data);
297299
return mp_obj_str_builder_end(s);
298300
}
@@ -310,14 +312,13 @@ STATIC mp_obj_t str_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
310312
return str_modulo_format(lhs_in, n_args, args);
311313
}
312314

313-
// These 2 are never passed here, dealt with as a special case in mp_binary_op().
314-
//case MP_BINARY_OP_EQUAL:
315-
//case MP_BINARY_OP_NOT_EQUAL:
315+
//case MP_BINARY_OP_NOT_EQUAL: // This is never passed here
316+
case MP_BINARY_OP_EQUAL: // This will be passed only for bytes, str is dealt with in mp_obj_equal()
316317
case MP_BINARY_OP_LESS:
317318
case MP_BINARY_OP_LESS_EQUAL:
318319
case MP_BINARY_OP_MORE:
319320
case MP_BINARY_OP_MORE_EQUAL:
320-
if (MP_OBJ_IS_STR(rhs_in)) {
321+
if (lhs_type == rhs_type) {
321322
GET_STR_DATA_LEN(rhs_in, rhs_data, rhs_len);
322323
return MP_BOOL(mp_seq_cmp_bytes(op, lhs_data, lhs_len, rhs_data, rhs_len));
323324
}

py/sequence.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,10 @@ bool m_seq_get_fast_slice_indexes(machine_uint_t len, mp_obj_t slice, machine_ui
8383
// Special-case comparison function for sequences of bytes
8484
// Don't pass MP_BINARY_OP_NOT_EQUAL here
8585
bool mp_seq_cmp_bytes(int op, const byte *data1, uint len1, const byte *data2, uint len2) {
86+
if (op == MP_BINARY_OP_EQUAL && len1 != len2) {
87+
return false;
88+
}
89+
8690
// Let's deal only with > & >=
8791
if (op == MP_BINARY_OP_LESS || op == MP_BINARY_OP_LESS_EQUAL) {
8892
SWAP(const byte*, data1, data2);

tests/basics/bytes_compare.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
print(b"" == b"")
2+
print(b"" > b"")
3+
print(b"" < b"")
4+
print(b"" == b"1")
5+
print(b"1" == b"")
6+
print("==")
7+
print(b"" > b"1")
8+
print(b"1" > b"")
9+
print(b"" < b"1")
10+
print(b"1" < b"")
11+
print(b"" >= b"1")
12+
print(b"1" >= b"")
13+
print(b"" <= b"1")
14+
print(b"1" <= b"")
15+
16+
print(b"1" == b"1")
17+
print(b"1" != b"1")
18+
print(b"1" == b"2")
19+
print(b"1" == b"10")
20+
21+
print(b"1" > b"1")
22+
print(b"1" > b"2")
23+
print(b"2" > b"1")
24+
print(b"10" > b"1")
25+
print(b"1/" > b"1")
26+
print(b"1" > b"10")
27+
print(b"1" > b"1/")
28+
29+
print(b"1" < b"1")
30+
print(b"2" < b"1")
31+
print(b"1" < b"2")
32+
print(b"1" < b"10")
33+
print(b"1" < b"1/")
34+
print(b"10" < b"1")
35+
print(b"1/" < b"1")
36+
37+
print(b"1" >= b"1")
38+
print(b"1" >= b"2")
39+
print(b"2" >= b"1")
40+
print(b"10" >= b"1")
41+
print(b"1/" >= b"1")
42+
print(b"1" >= b"10")
43+
print(b"1" >= b"1/")
44+
45+
print(b"1" <= b"1")
46+
print(b"2" <= b"1")
47+
print(b"1" <= b"2")
48+
print(b"1" <= b"10")
49+
print(b"1" <= b"1/")
50+
print(b"10" <= b"1")
51+
print(b"1/" <= b"1")

0 commit comments

Comments
 (0)