Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions Lib/test/test_itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,43 @@ def keyfunc(element):
items = list(grouper_iter)
self.assertEqual(len(items), 1)

@threading_helper.requires_working_threading()
def test_groupby_concurrent_next_does_not_crash(self):
    # Regression test for gh-150791.
    # Several threads call next() on one shared groupby object; this
    # must neither race nor corrupt the iterator's internal state.

    # Key type whose equality is implemented in Python, so every key
    # comparison done by groupby executes Python-level code.
    class Key:
        __slots__ = ("v",)

        def __init__(self, v):
            self.v = v

        def __eq__(self, other):
            if isinstance(other, Key):
                return self.v == other.v
            return NotImplemented

        def __hash__(self):
            return hash(self.v)

    shared = itertools.groupby([Key(i) for i in range(5_000)])
    failures = []

    def drain():
        # Pull (key, group) pairs until the shared iterator is exhausted.
        # StopIteration is the normal way out; any other exception is
        # recorded so the main thread can report it.
        try:
            while True:
                _, _ = next(shared)
        except StopIteration:
            pass
        except Exception as exc:
            failures.append(exc)

    workers = [threading.Thread(target=drain) for _ in range(8)]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    # Must also pass when run under ThreadSanitizer.
    self.assertEqual(failures, [])

def test_filter(self):
self.assertEqual(list(filter(isEven, range(6))), [0,2,4])
self.assertEqual(list(filter(None, [0,1,0,2,0])), [1,2])
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a data race in :func:`itertools.groupby` on free-threaded builds, where concurrent calls to :func:`next` on a shared ``groupby`` object or on its group iterators could corrupt the iterator's internal state.
24 changes: 22 additions & 2 deletions Modules/itertoolsmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ groupby_step(groupbyobject *gbo)
}

static PyObject *
groupby_next(PyObject *op)
groupby_next_lock_held(PyObject *op)
{
PyObject *grouper;
groupbyobject *gbo = groupbyobject_CAST(op);
Expand Down Expand Up @@ -574,6 +574,16 @@ groupby_next(PyObject *op)
return _PyTuple_FromPairSteal(Py_NewRef(gbo->currkey), grouper);
}

/* Thread-safe front end for groupby_next_lock_held().
 *
 * Enters a critical section on the groupby object `op`, delegates to
 * groupby_next_lock_held(op) while it is held, and returns whatever the
 * helper returned.  The critical section is always left before
 * returning, whatever the helper's result.
 */
static PyObject *
groupby_next(PyObject *op)
{
    PyObject *pair = NULL;

    Py_BEGIN_CRITICAL_SECTION(op);
    pair = groupby_next_lock_held(op);
    Py_END_CRITICAL_SECTION();

    return pair;
}

static PyType_Slot groupby_slots[] = {
{Py_tp_dealloc, groupby_dealloc},
{Py_tp_getattro, PyObject_GenericGetAttr},
Expand Down Expand Up @@ -659,7 +669,7 @@ _grouper_traverse(PyObject *op, visitproc visit, void *arg)
}

static PyObject *
_grouper_next(PyObject *op)
_grouper_next_lock_held(PyObject *op)
{
_grouperobject *igo = _grouperobject_CAST(op);
groupbyobject *gbo = groupbyobject_CAST(igo->parent);
Expand Down Expand Up @@ -695,6 +705,16 @@ _grouper_next(PyObject *op)
return r;
}

/* Thread-safe front end for _grouper_next_lock_held().
 *
 * The critical section is taken on the grouper's *parent* object, not on
 * the grouper `op` itself, because _grouper_next_lock_held() works on the
 * parent groupby's state.  groupby_next() enters a critical section on
 * that same groupby object, so both entry points are serialised against
 * each other.
 */
static PyObject *
_grouper_next(PyObject *op)
{
    PyObject *item = NULL;

    Py_BEGIN_CRITICAL_SECTION(_grouperobject_CAST(op)->parent);
    item = _grouper_next_lock_held(op);
    Py_END_CRITICAL_SECTION();

    return item;
}

static PyType_Slot _grouper_slots[] = {
{Py_tp_dealloc, _grouper_dealloc},
{Py_tp_getattro, PyObject_GenericGetAttr},
Expand Down
Loading