Skip to content
Next Next commit
Acquire strong references after PyDict_Next
  • Loading branch information
sweeneyde committed May 18, 2022
commit 9b707556e04ba81b6f52607cf27565d4d395def2
39 changes: 39 additions & 0 deletions Lib/test/pickletester.py
Original file line number Diff line number Diff line change
Expand Up @@ -3035,6 +3035,45 @@ def check_array(arr):
# 2-D, non-contiguous
check_array(arr[::2])

def test_evil_class_mutating_dict(self):
from random import getrandbits

global Bad
class Bad:
def __eq__(self, other):
if not ENABLED:
return False
return getrandbits(4) == 0
def __hash__(self):
return getrandbits(1)
def __reduce__(self):
break_things()
return (Bad, (), ())
def __setstate__(self, *args):
break_things()
def __del__(self):
break_things()
def __getattr__(self):
break_things()

def break_things():
if ENABLED and getrandbits(6) == 0:
collection.clear()

for proto in protocols:
for _ in range(20):
ENABLED = False
collection = {Bad(): Bad() for _ in range(50)}
for bad in collection:
bad.bad = bad
bad.collection = collection
ENABLED = True
try:
self.loads(self.dumps(collection, proto))
except RuntimeError as e:
expected = "changed size during iteration"
self.assertIn(expected, str(e))


class BigmemPickleTests:

Expand Down
32 changes: 24 additions & 8 deletions Modules/_pickle.c
Original file line number Diff line number Diff line change
Expand Up @@ -3259,10 +3259,16 @@ batch_dict_exact(PicklerObject *self, PyObject *obj)
/* Special-case len(d) == 1 to save space. */
if (dict_size == 1) {
PyDict_Next(obj, &ppos, &key, &value);
if (save(self, key, 0) < 0)
return -1;
if (save(self, value, 0) < 0)
return -1;
Py_INCREF(key);
Py_INCREF(value);
if (save(self, key, 0) < 0) {
goto error;
}
if (save(self, value, 0) < 0) {
goto error;
}
Py_CLEAR(key);
Py_CLEAR(value);
if (_Pickler_Write(self, &setitem_op, 1) < 0)
return -1;
return 0;
Expand All @@ -3274,10 +3280,16 @@ batch_dict_exact(PicklerObject *self, PyObject *obj)
if (_Pickler_Write(self, &mark_op, 1) < 0)
return -1;
while (PyDict_Next(obj, &ppos, &key, &value)) {
if (save(self, key, 0) < 0)
return -1;
if (save(self, value, 0) < 0)
return -1;
Py_INCREF(key);
Py_INCREF(value);
if (save(self, key, 0) < 0) {
goto error;
}
if (save(self, value, 0) < 0) {
goto error;
}
Py_CLEAR(key);
Py_CLEAR(value);
if (++i == BATCHSIZE)
break;
}
Expand All @@ -3292,6 +3304,10 @@ batch_dict_exact(PicklerObject *self, PyObject *obj)

} while (i == BATCHSIZE);
return 0;
error:
Py_XDECREF(key);
Py_XDECREF(value);
return -1;
}

static int
Expand Down