Skip to content
Prev Previous commit
Next Next commit
fix UBSan failures for UnpicklerObject
  • Loading branch information
picnixz committed Jan 25, 2025
commit d1e782c76f31d3e1798d326307b4c2fb70b99155
31 changes: 19 additions & 12 deletions Modules/_pickle.c
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,7 @@ typedef struct {
} UnpicklerMemoProxyObject;

#define _PicklerObject_CAST(op) ((PicklerObject *)(op))
#define _UnpicklerObject_CAST(op) ((UnpicklerObject *)(op))

/* Forward declarations */
static int save(PickleState *state, PicklerObject *, PyObject *, int);
Expand Down Expand Up @@ -7229,8 +7230,9 @@ static struct PyMethodDef Unpickler_methods[] = {
};

static int
Unpickler_clear(UnpicklerObject *self)
Unpickler_clear(PyObject *op)
{
UnpicklerObject *self = _UnpicklerObject_CAST(op);
Py_CLEAR(self->readline);
Py_CLEAR(self->readinto);
Py_CLEAR(self->read);
Expand All @@ -7257,18 +7259,19 @@ Unpickler_clear(UnpicklerObject *self)
}

static void
Unpickler_dealloc(UnpicklerObject *self)
Unpickler_dealloc(PyObject *self)
{
PyTypeObject *tp = Py_TYPE(self);
PyObject_GC_UnTrack((PyObject *)self);
PyObject_GC_UnTrack(self);
(void)Unpickler_clear(self);
tp->tp_free((PyObject *)self);
tp->tp_free(self);
Py_DECREF(tp);
}

static int
Unpickler_traverse(UnpicklerObject *self, visitproc visit, void *arg)
Unpickler_traverse(PyObject *op, visitproc visit, void *arg)
{
UnpicklerObject *self = _UnpicklerObject_CAST(op);
Py_VISIT(Py_TYPE(self));
Py_VISIT(self->readline);
Py_VISIT(self->readinto);
Expand Down Expand Up @@ -7328,7 +7331,7 @@ _pickle_Unpickler___init___impl(UnpicklerObject *self, PyObject *file,
{
/* In case of multiple __init__() calls, clear previous content. */
if (self->read != NULL)
(void)Unpickler_clear(self);
(void)Unpickler_clear((PyObject *)self);

if (_Unpickler_SetInputStream(self, file) < 0)
return -1;
Expand Down Expand Up @@ -7527,15 +7530,17 @@ UnpicklerMemoProxy_New(UnpicklerObject *unpickler)


static PyObject *
Unpickler_get_memo(UnpicklerObject *self, void *Py_UNUSED(ignored))
Unpickler_get_memo(PyObject *op, void *Py_UNUSED(ignored))
{
UnpicklerObject *self = _UnpicklerObject_CAST(op);
return UnpicklerMemoProxy_New(self);
}

static int
Unpickler_set_memo(UnpicklerObject *self, PyObject *obj, void *Py_UNUSED(ignored))
Unpickler_set_memo(PyObject *op, PyObject *obj, void *Py_UNUSED(ignored))
{
PyObject **new_memo;
UnpicklerObject *self = _UnpicklerObject_CAST(op);
size_t new_memo_size = 0;

if (obj == NULL) {
Expand Down Expand Up @@ -7612,11 +7617,12 @@ Unpickler_set_memo(UnpicklerObject *self, PyObject *obj, void *Py_UNUSED(ignored
static PyObject *
Unpickler_getattr(PyObject *self, PyObject *name)
{
UnpicklerObject *obj = _UnpicklerObject_CAST(self);
if (PyUnicode_Check(name)
&& PyUnicode_EqualToUTF8(name, "persistent_load")
&& ((UnpicklerObject *)self)->persistent_load_attr)
&& obj->persistent_load_attr)
{
return Py_NewRef(((UnpicklerObject *)self)->persistent_load_attr);
return Py_NewRef(obj->persistent_load_attr);
}

return PyObject_GenericGetAttr(self, name);
Expand All @@ -7628,16 +7634,17 @@ Unpickler_setattr(PyObject *self, PyObject *name, PyObject *value)
if (PyUnicode_Check(name)
&& PyUnicode_EqualToUTF8(name, "persistent_load"))
{
UnpicklerObject *obj = _UnpicklerObject_CAST(self);
Py_XINCREF(value);
Py_XSETREF(((UnpicklerObject *)self)->persistent_load_attr, value);
Py_XSETREF(obj->persistent_load_attr, value);
return 0;
}

return PyObject_GenericSetAttr(self, name, value);
}

static PyGetSetDef Unpickler_getsets[] = {
{"memo", (getter)Unpickler_get_memo, (setter)Unpickler_set_memo},
{"memo", Unpickler_get_memo, Unpickler_set_memo},
{NULL}
};

Expand Down