Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add a regression test.
  • Loading branch information
ericsnowcurrently committed May 31, 2023
commit b6801173ad01cf025ed49774d849ef7be0d29ff5
31 changes: 30 additions & 1 deletion Lib/test/test_capi/test_misc.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# Run the _testcapi module tests (tests for the Python/C API): by defn,
# these are all functions _testcapi exports whose name begins with 'test_'.

from collections import OrderedDict
import _thread
from collections import OrderedDict
import contextlib
import importlib.machinery
import importlib.util
import os
Expand Down Expand Up @@ -1626,6 +1627,34 @@ def test_tp_mro_is_set(self):
self.assertIsNot(mro, None)


class TestStaticTypes(unittest.TestCase):

def test_pytype_ready_always_sets_tp_type(self):
# The point of this test is to prevent something like
# https://github.com/python/cpython/issues/104614
# from happening again.

@contextlib.contextmanager
def basic_static_type(*args):
cls = _testcapi.get_basic_static_type(*args)
try:
yield cls
finally:
_testcapi.clear_basic_static_type(cls)

# First check when tp_base/tp_bases is *not* set before PyType_Ready().
with basic_static_type() as cls:
self.assertIs(cls.__base__, object);
self.assertEqual(cls.__bases__, (object,));
self.assertIs(type(cls), type(object));

# Then check when we *do* set tp_base/tp_bases first.
with basic_static_type(object) as cls:
self.assertIs(cls.__base__, object);
self.assertEqual(cls.__bases__, (object,));
self.assertIs(type(cls), type(object));


class TestThreadState(unittest.TestCase):

@threading_helper.reap_threads
Expand Down
56 changes: 56 additions & 0 deletions Modules/_testcapimodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -2627,6 +2627,60 @@ type_get_tp_mro(PyObject *self, PyObject *type)
}


static PyTypeObject BasicStaticType = {
PyVarObject_HEAD_INIT(NULL, 0)
.tp_name = "BasicStaticType",
.tp_basicsize = sizeof(PyObject),
};

static PyObject * clear_basic_static_type(PyObject *, PyObject *);

static PyObject *
get_basic_static_type(PyObject *self, PyObject *args)
{
PyObject *base = NULL;
if (!PyArg_ParseTuple(args, "|O", &base)) {
return NULL;
}
assert(base == NULL || PyType_Check(base));

PyTypeObject *cls = &BasicStaticType;
assert(!(cls->tp_flags & Py_TPFLAGS_READY));

if (base != NULL) {
cls->tp_base = (PyTypeObject *)Py_NewRef(base);
cls->tp_bases = Py_BuildValue("(O)", base);
if (cls->tp_bases == NULL) {
clear_basic_static_type(self, (PyObject *)cls);
return NULL;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should decref tp_base and tp_bases before returning.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Opened gh-105225.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for following up after-the-fact!

}
}
if (PyType_Ready(cls) < 0) {
clear_basic_static_type(self, (PyObject *)cls);
return NULL;
}
Py_INCREF(cls);
return (PyObject *)cls;
}

static PyObject *
clear_basic_static_type(PyObject *self, PyObject *clsobj)
{
// Reset it back to the statically initialized state.
PyTypeObject *cls = (PyTypeObject *)clsobj;
Py_CLEAR(cls->ob_base.ob_base.ob_type);
Py_CLEAR(cls->tp_base);
Py_CLEAR(cls->tp_bases);
Py_CLEAR(cls->tp_mro);
Py_CLEAR(cls->tp_subclasses);
Py_CLEAR(cls->tp_dict);
cls->tp_flags &= ~Py_TPFLAGS_READY;
cls->tp_flags &= ~Py_TPFLAGS_VALID_VERSION_TAG;
cls->tp_version_tag = 0;
Py_RETURN_NONE;
}


// Test PyThreadState C API
static PyObject *
test_tstate_capi(PyObject *self, PyObject *Py_UNUSED(args))
Expand Down Expand Up @@ -3384,6 +3438,8 @@ static PyMethodDef TestMethods[] = {
{"type_assign_version", type_assign_version, METH_O, PyDoc_STR("PyUnstable_Type_AssignVersionTag")},
{"type_get_tp_bases", type_get_tp_bases, METH_O},
{"type_get_tp_mro", type_get_tp_mro, METH_O},
{"get_basic_static_type", get_basic_static_type, METH_VARARGS, NULL},
{"clear_basic_static_type", clear_basic_static_type, METH_O, NULL},
{"test_tstate_capi", test_tstate_capi, METH_NOARGS, NULL},
{"frame_getlocals", frame_getlocals, METH_O, NULL},
{"frame_getglobals", frame_getglobals, METH_O, NULL},
Expand Down