Skip to content

Commit a65c03c

Browse files
committed
py: Allow +, in, and compare ops between bytes and bytearray/array.
E.g. b"123" + bytearray(2) now works. This patch actually decreases code size while adding functionality: 32-bit unix down by 128 bytes, stmhal down by 84 bytes.
1 parent 346aacf commit a65c03c

File tree

3 files changed

+103
-81
lines changed

3 files changed

+103
-81
lines changed

py/objstr.c

Lines changed: 86 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -285,99 +285,107 @@ STATIC const byte *find_subbytes(const byte *haystack, mp_uint_t hlen, const byt
285285
// works because both those types use it as their binary_op method. Revisit
286286
// MP_OBJ_IS_STR_OR_BYTES if this fact changes.
287287
mp_obj_t mp_obj_str_binary_op(mp_uint_t op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
288-
GET_STR_DATA_LEN(lhs_in, lhs_data, lhs_len);
288+
// check for modulo
289+
if (op == MP_BINARY_OP_MODULO) {
290+
mp_obj_t *args;
291+
mp_uint_t n_args;
292+
mp_obj_t dict = MP_OBJ_NULL;
293+
if (MP_OBJ_IS_TYPE(rhs_in, &mp_type_tuple)) {
294+
// TODO: Support tuple subclasses?
295+
mp_obj_tuple_get(rhs_in, &n_args, &args);
296+
} else if (MP_OBJ_IS_TYPE(rhs_in, &mp_type_dict)) {
297+
args = NULL;
298+
n_args = 0;
299+
dict = rhs_in;
300+
} else {
301+
args = &rhs_in;
302+
n_args = 1;
303+
}
304+
return str_modulo_format(lhs_in, n_args, args, dict);
305+
}
306+
307+
// from now on we need lhs type and data, so extract them
289308
mp_obj_type_t *lhs_type = mp_obj_get_type(lhs_in);
290-
mp_obj_type_t *rhs_type = mp_obj_get_type(rhs_in);
291-
switch (op) {
292-
case MP_BINARY_OP_ADD:
293-
case MP_BINARY_OP_INPLACE_ADD:
294-
if (lhs_type == rhs_type) {
295-
// add 2 strings or bytes
296-
297-
GET_STR_DATA_LEN(rhs_in, rhs_data, rhs_len);
298-
mp_uint_t alloc_len = lhs_len + rhs_len;
299-
300-
/* code for making qstr
301-
byte *q_ptr;
302-
byte *val = qstr_build_start(alloc_len, &q_ptr);
303-
memcpy(val, lhs_data, lhs_len);
304-
memcpy(val + lhs_len, rhs_data, rhs_len);
305-
return MP_OBJ_NEW_QSTR(qstr_build_end(q_ptr));
306-
*/
307-
308-
// code for non-qstr
309-
byte *data;
310-
mp_obj_t s = mp_obj_str_builder_start(lhs_type, alloc_len, &data);
311-
memcpy(data, lhs_data, lhs_len);
312-
memcpy(data + lhs_len, rhs_data, rhs_len);
313-
return mp_obj_str_builder_end(s);
314-
}
315-
break;
309+
GET_STR_DATA_LEN(lhs_in, lhs_data, lhs_len);
316310

317-
case MP_BINARY_OP_IN:
318-
/* NOTE `a in b` is `b.__contains__(a)` */
319-
if (lhs_type == rhs_type) {
320-
GET_STR_DATA_LEN(rhs_in, rhs_data, rhs_len);
321-
return MP_BOOL(find_subbytes(lhs_data, lhs_len, rhs_data, rhs_len, 1) != NULL);
311+
// check for multiply
312+
if (op == MP_BINARY_OP_MULTIPLY) {
313+
mp_int_t n;
314+
if (!mp_obj_get_int_maybe(rhs_in, &n)) {
315+
return MP_OBJ_NULL; // op not supported
316+
}
317+
if (n <= 0) {
318+
if (lhs_type == &mp_type_str) {
319+
return MP_OBJ_NEW_QSTR(MP_QSTR_); // empty str
320+
} else {
321+
return mp_const_empty_bytes;
322322
}
323-
break;
323+
}
324+
byte *data;
325+
mp_obj_t s = mp_obj_str_builder_start(lhs_type, lhs_len * n, &data);
326+
mp_seq_multiply(lhs_data, sizeof(*lhs_data), lhs_len, n, data);
327+
return mp_obj_str_builder_end(s);
328+
}
329+
330+
// From now on all operations allow:
331+
// - str with str
332+
// - bytes with bytes
333+
// - bytes with bytearray
334+
// - bytes with array.array
335+
// To do this efficiently we use the buffer protocol to extract the raw
336+
// data for the rhs, but only if the lhs is a bytes object.
337+
//
338+
// NOTE: CPython does not allow comparison between bytes and array.array
339+
// (even if the array is of type 'b'), even though it allows addition of
340+
// such types. We are not compatible with this (we do allow comparison
341+
// of bytes with anything that has the buffer protocol). It would be
342+
// easy to "fix" this with a bit of extra logic below, but it costs code
343+
// size and execution time so we don't.
344+
345+
const byte *rhs_data;
346+
mp_uint_t rhs_len;
347+
if (lhs_type == mp_obj_get_type(rhs_in)) {
348+
GET_STR_DATA_LEN(rhs_in, rhs_data_, rhs_len_);
349+
rhs_data = rhs_data_;
350+
rhs_len = rhs_len_;
351+
} else if (lhs_type == &mp_type_bytes) {
352+
mp_buffer_info_t bufinfo;
353+
if (!mp_get_buffer(rhs_in, &bufinfo, MP_BUFFER_READ)) {
354+
goto incompatible;
355+
}
356+
rhs_data = bufinfo.buf;
357+
rhs_len = bufinfo.len;
358+
} else {
359+
// incompatible types
360+
incompatible:
361+
if (op == MP_BINARY_OP_EQUAL) {
362+
return mp_const_false; // can check for equality against every type
363+
}
364+
return MP_OBJ_NULL; // op not supported
365+
}
324366

325-
case MP_BINARY_OP_MULTIPLY: {
326-
mp_int_t n;
327-
if (!mp_obj_get_int_maybe(rhs_in, &n)) {
328-
return MP_OBJ_NULL; // op not supported
329-
}
330-
if (n <= 0) {
331-
if (lhs_type == &mp_type_str) {
332-
return MP_OBJ_NEW_QSTR(MP_QSTR_); // empty str
333-
}
334-
n = 0;
335-
}
367+
switch (op) {
368+
case MP_BINARY_OP_ADD:
369+
case MP_BINARY_OP_INPLACE_ADD: {
370+
mp_uint_t alloc_len = lhs_len + rhs_len;
336371
byte *data;
337-
mp_obj_t s = mp_obj_str_builder_start(lhs_type, lhs_len * n, &data);
338-
mp_seq_multiply(lhs_data, sizeof(*lhs_data), lhs_len, n, data);
372+
mp_obj_t s = mp_obj_str_builder_start(lhs_type, alloc_len, &data);
373+
memcpy(data, lhs_data, lhs_len);
374+
memcpy(data + lhs_len, rhs_data, rhs_len);
339375
return mp_obj_str_builder_end(s);
340376
}
341377

342-
case MP_BINARY_OP_MODULO: {
343-
mp_obj_t *args;
344-
mp_uint_t n_args;
345-
mp_obj_t dict = MP_OBJ_NULL;
346-
if (MP_OBJ_IS_TYPE(rhs_in, &mp_type_tuple)) {
347-
// TODO: Support tuple subclasses?
348-
mp_obj_tuple_get(rhs_in, &n_args, &args);
349-
} else if (MP_OBJ_IS_TYPE(rhs_in, &mp_type_dict)) {
350-
args = NULL;
351-
n_args = 0;
352-
dict = rhs_in;
353-
} else {
354-
args = &rhs_in;
355-
n_args = 1;
356-
}
357-
return str_modulo_format(lhs_in, n_args, args, dict);
358-
}
378+
case MP_BINARY_OP_IN:
379+
/* NOTE `a in b` is `b.__contains__(a)` */
380+
return MP_BOOL(find_subbytes(lhs_data, lhs_len, rhs_data, rhs_len, 1) != NULL);
359381

360382
//case MP_BINARY_OP_NOT_EQUAL: // This is never passed here
361383
case MP_BINARY_OP_EQUAL: // This will be passed only for bytes, str is dealt with in mp_obj_equal()
362384
case MP_BINARY_OP_LESS:
363385
case MP_BINARY_OP_LESS_EQUAL:
364386
case MP_BINARY_OP_MORE:
365387
case MP_BINARY_OP_MORE_EQUAL:
366-
if (lhs_type == rhs_type) {
367-
GET_STR_DATA_LEN(rhs_in, rhs_data, rhs_len);
368-
return MP_BOOL(mp_seq_cmp_bytes(op, lhs_data, lhs_len, rhs_data, rhs_len));
369-
}
370-
if (lhs_type == &mp_type_bytes) {
371-
mp_buffer_info_t bufinfo;
372-
if (!mp_get_buffer(rhs_in, &bufinfo, MP_BUFFER_READ)) {
373-
goto uncomparable;
374-
}
375-
return MP_BOOL(mp_seq_cmp_bytes(op, lhs_data, lhs_len, bufinfo.buf, bufinfo.len));
376-
}
377-
uncomparable:
378-
if (op == MP_BINARY_OP_EQUAL) {
379-
return mp_const_false;
380-
}
388+
return MP_BOOL(mp_seq_cmp_bytes(op, lhs_data, lhs_len, rhs_data, rhs_len));
381389
}
382390

383391
return MP_OBJ_NULL; // op not supported

tests/basics/bytes_add.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# test bytes + other
2+
3+
print(b"123" + b"456")
4+
print(b"123" + bytearray(2))
5+
6+
import array
7+
8+
print(b"123" + array.array('i', [1]))
9+
print(b"\x01\x02" + array.array('b', [1, 2]))

tests/basics/bytes_compare2.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1-
import array
2-
31
print(b"1" == 1)
42
print(b"123" == bytearray(b"123"))
53
print(b"123" == "123")
6-
# CPyhon gives False here
4+
print(b'123' < bytearray(b"124"))
5+
print(b'123' > bytearray(b"122"))
6+
print(bytearray(b"23") in b"1234")
7+
8+
import array
9+
10+
print(array.array('b', [1, 2]) in b'\x01\x02\x03')
11+
# CPython gives False here
712
#print(b"\x01\x02\x03" == array.array("B", [1, 2, 3]))

0 commit comments

Comments
 (0)