Skip to content
Merged
Prev Previous commit
Next Next commit
Clean up arg parsing.
  • Loading branch information
ericsnowcurrently committed Oct 4, 2023
commit b8e32fefb862c243d8051b5088ed425686940a4b
151 changes: 119 additions & 32 deletions Modules/_xxsubinterpretersmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -335,39 +335,48 @@ _sharedexception_apply(_sharedexception *exc, PyObject *wrapperclass)

/* Python code **************************************************************/

static int
validate_code_object(PyCodeObject *code)
static const char *
check_code_str(PyUnicodeObject *text)
{
assert(text != NULL);
if (PyUnicode_GET_LENGTH(text) == 0) {
return "too short";
}

// XXX Verify that it parses?

return NULL;
}

static const char *
check_code_object(PyCodeObject *code)
{
assert(code != NULL);
if (code->co_argcount > 0
|| code->co_posonlyargcount > 0
|| code->co_kwonlyargcount > 0
|| code->co_flags & (CO_VARARGS | CO_VARKEYWORDS))
{
PyErr_SetString(PyExc_ValueError, "arguments not supported");
return -1;
return "arguments not supported";
}
if (code->co_ncellvars > 0) {
PyErr_SetString(PyExc_ValueError, "closures not supported");
return -1;
return "closures not supported";
}
// We trust that no code objects under co_consts have unbound cell vars.

if (code->co_executors != NULL
|| code->_co_instrumentation_version > 0)
{
PyErr_SetString(PyExc_ValueError, "only basic functions are supported");
return -1;
return "only basic functions are supported";
}
if (code->_co_monitoring != NULL) {
PyErr_SetString(PyExc_ValueError, "only basic functions are supported");
return -1;
return "only basic functions are supported";
}
if (code->co_extra != NULL) {
PyErr_SetString(PyExc_ValueError, "only basic functions are supported");
return -1;
return "only basic functions are supported";
}

return 0;
return NULL;
}

#define RUN_TEXT 1
Expand All @@ -382,7 +391,8 @@ get_code_str(PyObject *arg, Py_ssize_t *len_p, PyObject **bytes_p, int *flags_p)
int flags = 0;

if (PyUnicode_Check(arg)) {
// XXX Validate that it parses?
assert(PyUnicode_CheckExact(arg)
&& (check_code_str((PyUnicodeObject *)arg) == NULL));
codestr = PyUnicode_AsUTF8AndSize(arg, &len);
if (codestr == NULL) {
return NULL;
Expand All @@ -395,26 +405,12 @@ get_code_str(PyObject *arg, Py_ssize_t *len_p, PyObject **bytes_p, int *flags_p)
flags = RUN_TEXT;
}
else {
PyObject *code = arg;
if (PyFunction_Check(arg)) {
if (PyFunction_GetClosure(arg)) {
PyErr_SetString(PyExc_ValueError, "closures not supported");
return NULL;
}
code = PyFunction_GetCode(arg);
}
else if (!PyCode_Check(arg)) {
PyErr_SetString(PyExc_TypeError, "unsupported type");
return NULL;
}
assert(PyCode_Check(arg)
&& (check_code_object((PyCodeObject *)arg) == NULL));
flags = RUN_CODE;

if (validate_code_object((PyCodeObject *)code) < 0) {
return NULL;
}

// Serialize the code object.
bytes_obj = PyMarshal_WriteObjectToString(code, Py_MARSHAL_VERSION);
bytes_obj = PyMarshal_WriteObjectToString(arg, Py_MARSHAL_VERSION);
if (bytes_obj == NULL) {
return NULL;
}
Expand Down Expand Up @@ -777,6 +773,82 @@ PyDoc_STRVAR(get_main_doc,
Return the ID of main interpreter.");


static PyUnicodeObject *
convert_script_arg(PyObject *arg, const char *fname, const char *displayname,
const char *expected)
{
PyUnicodeObject *str = NULL;
if (PyUnicode_CheckExact(arg)) {
str = (PyUnicodeObject *)Py_NewRef(arg);
}
else if (PyUnicode_Check(arg)) {
// XXX str = PyUnicode_FromObject(arg);
str = (PyUnicodeObject *)Py_NewRef(arg);
}
else {
_PyArg_BadArgument(fname, displayname, expected, arg);
return NULL;
}

const char *err = check_code_str(str);
if (err != NULL) {
Py_DECREF(str);
PyErr_Format(PyExc_ValueError,
"%.200s(): bad script text (%s)", fname, err);
return NULL;
}

return str;
}

static PyCodeObject *
convert_code_arg(PyObject *arg, const char *fname, const char *displayname,
const char *expected)
{
const char *kind = NULL;
PyCodeObject *code = NULL;
if (PyFunction_Check(arg)) {
if (PyFunction_GetClosure(arg) != NULL) {
PyErr_Format(PyExc_ValueError,
"%.200s(): closures not supported", fname);
return NULL;
}
code = (PyCodeObject *)PyFunction_GetCode(arg);
if (code == NULL) {
if (PyErr_Occurred()) {
// This chains.
PyErr_Format(PyExc_ValueError,
"%.200s(): bad func", fname);
}
else {
PyErr_Format(PyExc_ValueError,
"%.200s(): func.__code__ missing", fname);
}
return NULL;
}
Py_INCREF(code);
kind = "func";
}
else if (PyCode_Check(arg)) {
code = (PyCodeObject *)Py_NewRef(arg);
kind = "code object";
}
else {
_PyArg_BadArgument(fname, displayname, expected, arg);
return NULL;
}

const char *err = check_code_object(code);
if (err != NULL) {
Py_DECREF(code);
PyErr_Format(PyExc_ValueError,
"%.200s(): bad %s (%s)", fname, kind, err);
return NULL;
}

return code;
}

static int
_interp_exec(PyObject *self,
PyObject *id_arg, PyObject *code_arg, PyObject *shared_arg)
Expand Down Expand Up @@ -820,7 +892,22 @@ interp_exec(PyObject *self, PyObject *args, PyObject *kwds)
return NULL;
}

if (_interp_exec(self, id, code, shared) < 0) {
const char *expected = "a string, a function, or a code object";
if (PyUnicode_Check(code)) {
code = (PyObject *)convert_script_arg(code, MODULE_NAME ".exec",
"argument 2", expected);
}
else {
code = (PyObject *)convert_code_arg(code, MODULE_NAME ".exec",
"argument 2", expected);
}
if (code == NULL) {
return NULL;
}

int res = _interp_exec(self, id, code, shared);
Py_DECREF(code);
if (res < 0) {
return NULL;
}
Py_RETURN_NONE;
Expand Down