diff --git a/Lib/test/test_free_threading/test_collections.py b/Lib/test/test_free_threading/test_collections.py index 849b0480e232fc2..f9788654940aac3 100644 --- a/Lib/test/test_free_threading/test_collections.py +++ b/Lib/test/test_free_threading/test_collections.py @@ -1,5 +1,6 @@ +import threading import unittest -from collections import deque +from collections import Counter, deque from copy import copy from test.support import threading_helper @@ -49,5 +50,21 @@ def mutate(): ) +class TestCounter(unittest.TestCase): + def test_update_concurrent(self): + # gh-151633: concurrent Counter.update calls must not cause use-after-free + # under free-threading. + NTHREADS = 4 + PER_THREAD = 5000 + c = Counter() + data = ['x'] * PER_THREAD + threads = [threading.Thread(target=c.update, args=(data,)) + for _ in range(NTHREADS)] + for t in threads: + t.start() + for t in threads: + t.join() + + if __name__ == "__main__": unittest.main() diff --git a/Misc/NEWS.d/next/Library/2026-06-18-11-26-00.gh-issue-151633.uJzHdc.rst b/Misc/NEWS.d/next/Library/2026-06-18-11-26-00.gh-issue-151633.uJzHdc.rst new file mode 100644 index 000000000000000..89fdf33965179ef --- /dev/null +++ b/Misc/NEWS.d/next/Library/2026-06-18-11-26-00.gh-issue-151633.uJzHdc.rst @@ -0,0 +1 @@ +Fix a race condition under free-threading when multiple threads update the same :class:`~collections.Counter` concurrently. diff --git a/Modules/_collectionsmodule.c b/Modules/_collectionsmodule.c index 5ca6362406a78b9..8e4525570ea1365 100644 --- a/Modules/_collectionsmodule.c +++ b/Modules/_collectionsmodule.c @@ -1,6 +1,6 @@ #include "Python.h" #include "pycore_call.h" // _PyObject_CallNoArgs() -#include "pycore_dict.h" // _PyDict_GetItem_KnownHash() +#include "pycore_dict.h" // _PyDict_GetItemRef_KnownHash_LockHeld() #include "pycore_long.h" // _PyLong_GetZero() #include "pycore_moduleobject.h" // _PyModule_GetState() #include "pycore_pyatomic_ft_wrappers.h" @@ -2595,24 +2595,35 @@ _collections__count_elements_impl(PyObject *module, PyObject *mapping, goto done; } - oldval = _PyDict_GetItem_KnownHash(mapping, key, hash); - if (oldval == NULL) { - if (PyErr_Occurred()) - goto done; - if (_PyDict_SetItem_KnownHash(mapping, key, one, hash) < 0) - goto done; - } else { - /* oldval is a borrowed reference. Keep it alive across - PyNumber_Add(), which can execute arbitrary user code and - mutate (or even clear) the underlying dict. */ - Py_INCREF(oldval); + int found; + int cs_err = 0; + Py_BEGIN_CRITICAL_SECTION(mapping); + found = _PyDict_GetItemRef_KnownHash_LockHeld( + (PyDictObject *)mapping, key, hash, &oldval); + if (found < 0) { + cs_err = -1; + } + else if (found == 0) { + if (_PyDict_SetItem_KnownHash_LockHeld( + (PyDictObject *)mapping, key, one, hash) < 0) { + cs_err = -1; + } + } + else { newval = PyNumber_Add(oldval, one); Py_DECREF(oldval); - if (newval == NULL) - goto done; - if (_PyDict_SetItem_KnownHash(mapping, key, newval, hash) < 0) - goto done; - Py_CLEAR(newval); + if (newval == NULL) { + cs_err = -1; + } + else if (_PyDict_SetItem_KnownHash_LockHeld( + (PyDictObject *)mapping, key, newval, hash) < 0) { + cs_err = -1; + } + } + Py_END_CRITICAL_SECTION(); + Py_CLEAR(newval); + if (cs_err < 0) { + goto done; } Py_DECREF(key); }