Skip to content
Prev Previous commit
Next Next commit
fix UBSan failures for PicklerObject
  • Loading branch information
picnixz committed Jan 25, 2025
commit 61da4f215446cc767b4ebb7df57fa5072ae75213
31 changes: 19 additions & 12 deletions Modules/_pickle.c
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,8 @@ typedef struct {
UnpicklerObject *unpickler;
} UnpicklerMemoProxyObject;

#define _PicklerObject_CAST(op) ((PicklerObject *)(op))

/* Forward declarations */
static int save(PickleState *state, PicklerObject *, PyObject *, int);
static int save_reduce(PickleState *, PicklerObject *, PyObject *, PyObject *);
Expand Down Expand Up @@ -4723,8 +4725,9 @@ static struct PyMethodDef Pickler_methods[] = {
};

static int
Pickler_clear(PicklerObject *self)
Pickler_clear(PyObject *op)
{
PicklerObject *self = _PicklerObject_CAST(op);
Py_CLEAR(self->output_buffer);
Py_CLEAR(self->write);
Py_CLEAR(self->persistent_id);
Expand All @@ -4742,18 +4745,19 @@ Pickler_clear(PicklerObject *self)
}

static void
Pickler_dealloc(PicklerObject *self)
Pickler_dealloc(PyObject *self)
{
PyTypeObject *tp = Py_TYPE(self);
PyObject_GC_UnTrack(self);
(void)Pickler_clear(self);
tp->tp_free((PyObject *)self);
tp->tp_free(self);
Py_DECREF(tp);
}

static int
Pickler_traverse(PicklerObject *self, visitproc visit, void *arg)
Pickler_traverse(PyObject *op, visitproc visit, void *arg)
{
PicklerObject *self = _PicklerObject_CAST(op);
Py_VISIT(Py_TYPE(self));
Py_VISIT(self->write);
Py_VISIT(self->persistent_id);
Expand Down Expand Up @@ -4823,7 +4827,7 @@ _pickle_Pickler___init___impl(PicklerObject *self, PyObject *file,
{
/* In case of multiple __init__() calls, clear previous content. */
if (self->write != NULL)
(void)Pickler_clear(self);
(void)Pickler_clear((PyObject *)self);

if (_Pickler_SetProtocol(self, protocol, fix_imports) < 0)
return -1;
Expand Down Expand Up @@ -5033,15 +5037,17 @@ PicklerMemoProxy_New(PicklerObject *pickler)
/*****************************************************************************/

static PyObject *
Pickler_get_memo(PicklerObject *self, void *Py_UNUSED(ignored))
Pickler_get_memo(PyObject *op, void *Py_UNUSED(ignored))
{
PicklerObject *self = _PicklerObject_CAST(op);
return PicklerMemoProxy_New(self);
}

static int
Pickler_set_memo(PicklerObject *self, PyObject *obj, void *Py_UNUSED(ignored))
Pickler_set_memo(PyObject *op, PyObject *obj, void *Py_UNUSED(ignored))
{
PyMemoTable *new_memo = NULL;
PicklerObject *self = _PicklerObject_CAST(op);

if (obj == NULL) {
PyErr_SetString(PyExc_TypeError,
Expand Down Expand Up @@ -5104,11 +5110,12 @@ Pickler_set_memo(PicklerObject *self, PyObject *obj, void *Py_UNUSED(ignored))
static PyObject *
Pickler_getattr(PyObject *self, PyObject *name)
{
PicklerObject *po = _PicklerObject_CAST(self);
if (PyUnicode_Check(name)
&& PyUnicode_EqualToUTF8(name, "persistent_id")
&& ((PicklerObject *)self)->persistent_id_attr)
&& po->persistent_id_attr)
{
return Py_NewRef(((PicklerObject *)self)->persistent_id_attr);
return Py_NewRef(po->persistent_id_attr);
}

return PyObject_GenericGetAttr(self, name);
Expand All @@ -5120,8 +5127,9 @@ Pickler_setattr(PyObject *self, PyObject *name, PyObject *value)
if (PyUnicode_Check(name)
&& PyUnicode_EqualToUTF8(name, "persistent_id"))
{
PicklerObject *po = _PicklerObject_CAST(self);
Py_XINCREF(value);
Py_XSETREF(((PicklerObject *)self)->persistent_id_attr, value);
Py_XSETREF(po->persistent_id_attr, value);
return 0;
}

Expand All @@ -5136,8 +5144,7 @@ static PyMemberDef Pickler_members[] = {
};

static PyGetSetDef Pickler_getsets[] = {
{"memo", (getter)Pickler_get_memo,
(setter)Pickler_set_memo},
{"memo", Pickler_get_memo, Pickler_set_memo},
{NULL}
};

Expand Down