Skip to content
Next Next commit
Use PyMutex for writes to asyncio state
  • Loading branch information
Fidget-Spinner committed Jul 11, 2024
commit 482219028d32794fe2d727a7a6059068c63377b7
24 changes: 24 additions & 0 deletions Modules/_asynciomodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ typedef struct futureiterobject futureiterobject;

/* State of the _asyncio module */
typedef struct {
PyMutex mutex;
PyTypeObject *FutureIterType;
PyTypeObject *TaskStepMethWrapper_Type;
PyTypeObject *FutureType;
Expand Down Expand Up @@ -341,8 +342,10 @@ get_running_loop(asyncio_state *state, PyObject **loop)
}
}

PyMutex_Lock(&state->mutex);
state->cached_running_loop = rl;
state->cached_running_loop_tsid = ts_id;
PyMutex_Unlock(&state->mutex);
}


Expand Down Expand Up @@ -384,8 +387,10 @@ set_running_loop(asyncio_state *state, PyObject *loop)
return -1;
}

PyMutex_Lock(&state->mutex);
state->cached_running_loop = loop; // borrowed, kept alive by ts_dict
state->cached_running_loop_tsid = PyThreadState_GetID(tstate);
PyMutex_Unlock(&state->mutex);

return 0;
}
Expand Down Expand Up @@ -1668,9 +1673,11 @@ FutureIter_dealloc(futureiterobject *it)
}

if (state && state->fi_freelist_len < FI_FREELIST_MAXLEN) {
PyMutex_Lock(&state->mutex);
state->fi_freelist_len++;
it->future = (FutureObj*) state->fi_freelist;
state->fi_freelist = it;
PyMutex_Unlock(&state->mutex);
}
else {
PyObject_GC_Del(it);
Expand Down Expand Up @@ -1877,9 +1884,11 @@ future_new_iter(PyObject *fut)
ENSURE_FUTURE_ALIVE(state, fut)

if (state->fi_freelist_len) {
PyMutex_Lock(&state->mutex);
state->fi_freelist_len--;
it = state->fi_freelist;
state->fi_freelist = (futureiterobject*) it->future;
PyMutex_Unlock(&state->mutex);
it->future = NULL;
_Py_NewReference((PyObject*) it);
}
Expand Down Expand Up @@ -2028,8 +2037,10 @@ register_task(asyncio_state *state, TaskObj *task)
assert(state->asyncio_tasks.head != NULL);

task->next = state->asyncio_tasks.head;
PyMutex_Lock(&state->mutex);
state->asyncio_tasks.head->prev = task;
state->asyncio_tasks.head = task;
PyMutex_Unlock(&state->mutex);
}

static int
Expand All @@ -2052,7 +2063,9 @@ unregister_task(asyncio_state *state, TaskObj *task)
task->next->prev = task->prev;
if (task->prev == NULL) {
assert(state->asyncio_tasks.head == task);
PyMutex_Lock(&state->mutex);
state->asyncio_tasks.head = task->next;
PyMutex_Unlock(&state->mutex);
} else {
task->prev->next = task->next;
}
Expand Down Expand Up @@ -2213,7 +2226,9 @@ _asyncio_Task___init___impl(TaskObj *self, PyObject *coro, PyObject *loop,
// optimization: defer task name formatting
// store the task counter as PyLong in the name
// for deferred formatting in get_name
PyMutex_Lock(&state->mutex);
name = PyLong_FromUnsignedLongLong(++state->task_name_counter);
PyMutex_Unlock(&state->mutex);
} else if (!PyUnicode_CheckExact(name)) {
name = PyObject_Str(name);
} else {
Expand Down Expand Up @@ -3750,6 +3765,7 @@ module_free_freelists(asyncio_state *state)
PyObject *current;

next = (PyObject*) state->fi_freelist;
PyMutex_Lock(&state->mutex);
while (next != NULL) {
assert(state->fi_freelist_len > 0);
state->fi_freelist_len--;
Expand All @@ -3760,6 +3776,7 @@ module_free_freelists(asyncio_state *state)
}
assert(state->fi_freelist_len == 0);
state->fi_freelist = NULL;
PyMutex_Unlock(&state->mutex);
}

static int
Expand Down Expand Up @@ -3844,6 +3861,7 @@ module_init(asyncio_state *state)
{
PyObject *module = NULL;

PyMutex_Lock(&state->mutex);
state->asyncio_mod = PyImport_ImportModule("asyncio");
if (state->asyncio_mod == NULL) {
goto fail;
Expand Down Expand Up @@ -3913,10 +3931,13 @@ module_init(asyncio_state *state)
goto fail;
}

PyMutex_Unlock(&state->mutex);

Py_DECREF(module);
return 0;

fail:
PyMutex_Unlock(&state->mutex);
Py_CLEAR(module);
return -1;

Expand Down Expand Up @@ -3947,9 +3968,12 @@ static int
module_exec(PyObject *mod)
{
asyncio_state *state = get_asyncio_state(mod);

PyMutex_Lock(&state->mutex);
Py_SET_TYPE(&state->asyncio_tasks.tail, state->TaskType);
_Py_SetImmortalUntracked((PyObject *)&state->asyncio_tasks.tail);
state->asyncio_tasks.head = &state->asyncio_tasks.tail;
PyMutex_Unlock(&state->mutex);

#define CREATE_TYPE(m, tp, spec, base) \
do { \
Expand Down