Skip to content

Commit 07241cd

Browse files
committed
py/objstringio: If created from immutable object, follow copy on write policy.
Don't create copy of immutable object's contents until .write() is called on BytesIO.
1 parent b24ccfc commit 07241cd

File tree

5 files changed

+70
-3
lines changed

5 files changed

+70
-3
lines changed

py/objstringio.c

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,23 @@ STATIC mp_uint_t stringio_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *er
6868
return size;
6969
}
7070

71+
STATIC void stringio_copy_on_write(mp_obj_stringio_t *o) {
72+
const void *buf = o->vstr->buf;
73+
o->vstr->buf = m_new(char, o->vstr->len);
74+
memcpy(o->vstr->buf, buf, o->vstr->len);
75+
o->vstr->fixed_buf = false;
76+
o->ref_obj = MP_OBJ_NULL;
77+
}
78+
7179
STATIC mp_uint_t stringio_write(mp_obj_t o_in, const void *buf, mp_uint_t size, int *errcode) {
7280
(void)errcode;
7381
mp_obj_stringio_t *o = MP_OBJ_TO_PTR(o_in);
7482
check_stringio_is_open(o);
83+
84+
if (o->vstr->fixed_buf) {
85+
stringio_copy_on_write(o);
86+
}
87+
7588
mp_uint_t new_pos = o->pos + size;
7689
if (new_pos < size) {
7790
// Writing <size> bytes will overflow o->pos beyond limit of mp_uint_t.
@@ -155,11 +168,11 @@ STATIC mp_obj_t stringio___exit__(size_t n_args, const mp_obj_t *args) {
155168
}
156169
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(stringio___exit___obj, 4, 4, stringio___exit__);
157170

158-
STATIC mp_obj_stringio_t *stringio_new(const mp_obj_type_t *type, mp_uint_t alloc) {
171+
STATIC mp_obj_stringio_t *stringio_new(const mp_obj_type_t *type) {
159172
mp_obj_stringio_t *o = m_new_obj(mp_obj_stringio_t);
160173
o->base.type = type;
161-
o->vstr = vstr_new(alloc);
162174
o->pos = 0;
175+
o->ref_obj = MP_OBJ_NULL;
163176
return o;
164177
}
165178

@@ -170,17 +183,28 @@ STATIC mp_obj_t stringio_make_new(const mp_obj_type_t *type_in, size_t n_args, s
170183
bool initdata = false;
171184
mp_buffer_info_t bufinfo;
172185

186+
mp_obj_stringio_t *o = stringio_new(type_in);
187+
173188
if (n_args > 0) {
174189
if (MP_OBJ_IS_INT(args[0])) {
175190
sz = mp_obj_get_int(args[0]);
176191
} else {
177192
mp_get_buffer_raise(args[0], &bufinfo, MP_BUFFER_READ);
193+
194+
if (MP_OBJ_IS_STR_OR_BYTES(args[0])) {
195+
o->vstr = m_new_obj(vstr_t);
196+
vstr_init_fixed_buf(o->vstr, bufinfo.len, bufinfo.buf);
197+
o->vstr->len = bufinfo.len;
198+
o->ref_obj = args[0];
199+
return MP_OBJ_FROM_PTR(o);
200+
}
201+
178202
sz = bufinfo.len;
179203
initdata = true;
180204
}
181205
}
182206

183-
mp_obj_stringio_t *o = stringio_new(type_in, sz);
207+
o->vstr = vstr_new(sz);
184208

185209
if (initdata) {
186210
stringio_write(MP_OBJ_FROM_PTR(o), bufinfo.buf, bufinfo.len, NULL);

py/objstringio.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ typedef struct _mp_obj_stringio_t {
3333
vstr_t *vstr;
3434
// StringIO has single pointer used for both reading and writing
3535
mp_uint_t pos;
36+
// Underlying object buffered by this StringIO
37+
mp_obj_t ref_obj;
3638
} mp_obj_stringio_t;
3739

3840
#endif // MICROPY_INCLUDED_PY_OBJSTRINGIO_H

tests/io/bytesio_cow.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Make sure that write operations on io.BytesIO don't
2+
# change original object it was constructed from.
3+
try:
4+
import uio as io
5+
except ImportError:
6+
import io
7+
8+
b = b"foobar"
9+
10+
a = io.BytesIO(b)
11+
a.write(b"1")
12+
print(b)
13+
print(a.getvalue())
14+
15+
b = bytearray(b"foobar")
16+
17+
a = io.BytesIO(b)
18+
a.write(b"1")
19+
print(b)
20+
print(a.getvalue())
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Creating BytesIO from immutable object should not immediately
2+
# copy its content.
3+
try:
4+
import uio
5+
import micropython
6+
micropython.mem_total
7+
except (ImportError, AttributeError):
8+
print("SKIP")
9+
raise SystemExit
10+
11+
12+
data = b"1234" * 256
13+
14+
before = micropython.mem_total()
15+
16+
buf = uio.BytesIO(data)
17+
18+
after = micropython.mem_total()
19+
20+
print(after - before < len(data))
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
True

0 commit comments

Comments
 (0)