Skip to content
Merged
Prev Previous commit
Next Next commit
fix UBSan failures for localobject
  • Loading branch information
picnixz committed Jan 25, 2025
commit a96a296863e74c4500d19817034d8369cc6a133f
49 changes: 25 additions & 24 deletions Modules/_threadmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -1359,6 +1359,8 @@ typedef struct {
PyObject *thread_watchdogs;
} localobject;

#define _localobject_CAST(op) ((localobject *)(op))

/* Forward declaration */
static int create_localsdict(localobject *self, thread_module_state *state,
PyObject **localsdict, PyObject **sentinel_wr);
Expand All @@ -1369,7 +1371,7 @@ static PyObject *
create_sentinel_wr(localobject *self)
{
static PyMethodDef wr_callback_def = {
"clear_locals", (PyCFunction) clear_locals, METH_O
"clear_locals", clear_locals, METH_O
};

PyThreadState *tstate = PyThreadState_Get();
Expand Down Expand Up @@ -1461,8 +1463,9 @@ local_new(PyTypeObject *type, PyObject *args, PyObject *kw)
}

static int
local_traverse(localobject *self, visitproc visit, void *arg)
local_traverse(PyObject *op, visitproc visit, void *arg)
{
localobject *self = _localobject_CAST(op);
Py_VISIT(Py_TYPE(self));
Py_VISIT(self->args);
Py_VISIT(self->kw);
Expand All @@ -1472,8 +1475,9 @@ local_traverse(localobject *self, visitproc visit, void *arg)
}

static int
local_clear(localobject *self)
local_clear(PyObject *op)
{
localobject *self = _localobject_CAST(op);
Py_CLEAR(self->args);
Py_CLEAR(self->kw);
Py_CLEAR(self->localdicts);
Expand All @@ -1482,20 +1486,18 @@ local_clear(localobject *self)
}

static void
local_dealloc(localobject *self)
local_dealloc(PyObject *op)
{
localobject *self = _localobject_CAST(op);
/* Weakrefs must be invalidated right now, otherwise they can be used
from code called below, which is very dangerous since Py_REFCNT(self) == 0 */
if (self->weakreflist != NULL) {
PyObject_ClearWeakRefs((PyObject *) self);
PyObject_ClearWeakRefs(op);
}

PyObject_GC_UnTrack(self);

local_clear(self);

(void)local_clear(op);
PyTypeObject *tp = Py_TYPE(self);
tp->tp_free((PyObject*)self);
tp->tp_free(self);
Py_DECREF(tp);
}

Expand Down Expand Up @@ -1634,8 +1636,9 @@ _ldict(localobject *self, thread_module_state *state)
}

static int
local_setattro(localobject *self, PyObject *name, PyObject *v)
local_setattro(PyObject *op, PyObject *name, PyObject *v)
{
localobject *self = _localobject_CAST(op);
PyObject *module = PyType_GetModuleByDef(Py_TYPE(self), &thread_module);
assert(module != NULL);
thread_module_state *state = get_thread_state(module);
Expand All @@ -1656,8 +1659,7 @@ local_setattro(localobject *self, PyObject *name, PyObject *v)
goto err;
}

int st =
_PyObject_GenericSetAttrWithDict((PyObject *)self, name, v, ldict);
int st = _PyObject_GenericSetAttrWithDict(op, name, v, ldict);
Py_DECREF(ldict);
return st;

Expand All @@ -1666,20 +1668,20 @@ local_setattro(localobject *self, PyObject *name, PyObject *v)
return -1;
}

static PyObject *local_getattro(localobject *, PyObject *);
static PyObject *local_getattro(PyObject *, PyObject *);

static PyMemberDef local_type_members[] = {
{"__weaklistoffset__", Py_T_PYSSIZET, offsetof(localobject, weakreflist), Py_READONLY},
{NULL},
};

static PyType_Slot local_type_slots[] = {
{Py_tp_dealloc, (destructor)local_dealloc},
{Py_tp_getattro, (getattrofunc)local_getattro},
{Py_tp_setattro, (setattrofunc)local_setattro},
{Py_tp_dealloc, local_dealloc},
{Py_tp_getattro, local_getattro},
{Py_tp_setattro, local_setattro},
{Py_tp_doc, "_local()\n--\n\nThread-local data"},
{Py_tp_traverse, (traverseproc)local_traverse},
{Py_tp_clear, (inquiry)local_clear},
{Py_tp_traverse, local_traverse},
{Py_tp_clear, local_clear},
{Py_tp_new, local_new},
{Py_tp_members, local_type_members},
{0, 0}
Expand All @@ -1694,8 +1696,9 @@ static PyType_Spec local_type_spec = {
};

static PyObject *
local_getattro(localobject *self, PyObject *name)
local_getattro(PyObject *op, PyObject *name)
{
localobject *self = _localobject_CAST(op);
PyObject *module = PyType_GetModuleByDef(Py_TYPE(self), &thread_module);
assert(module != NULL);
thread_module_state *state = get_thread_state(module);
Expand All @@ -1715,8 +1718,7 @@ local_getattro(localobject *self, PyObject *name)

if (!Py_IS_TYPE(self, state->local_type)) {
/* use generic lookup for subtypes */
PyObject *res =
_PyObject_GenericGetAttrWithDict((PyObject *)self, name, ldict, 0);
PyObject *res = _PyObject_GenericGetAttrWithDict(op, name, ldict, 0);
Py_DECREF(ldict);
return res;
}
Expand All @@ -1730,8 +1732,7 @@ local_getattro(localobject *self, PyObject *name)
}

/* Fall back on generic to get __class__ and __dict__ */
PyObject *res =
_PyObject_GenericGetAttrWithDict((PyObject *)self, name, ldict, 0);
PyObject *res = _PyObject_GenericGetAttrWithDict(op, name, ldict, 0);
Py_DECREF(ldict);
return res;
}
Expand Down