Skip to content
Prev Previous commit
Next Next commit
Make RecvChannel and SendChannel shareable.
  • Loading branch information
ericsnowcurrently committed Sep 29, 2023
commit 996596ba9fe0902703a47c0f91151608aa69f616
122 changes: 122 additions & 0 deletions Modules/_xxinterpchannelsmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,9 @@ _release_xid_data(_PyCrossInterpreterData *data, int flags)
/* module state *************************************************************/

typedef struct {
PyTypeObject *send_channel_type;
PyTypeObject *recv_channel_type;

/* heap types */
PyTypeObject *ChannelIDType;

Expand Down Expand Up @@ -252,6 +255,9 @@ traverse_module_state(module_state *state, visitproc visit, void *arg)
static int
clear_module_state(module_state *state)
{
Py_CLEAR(state->send_channel_type);
Py_CLEAR(state->recv_channel_type);

/* heap types */
if (state->ChannelIDType != NULL) {
(void)_PyCrossInterpreterData_UnregisterClass(state->ChannelIDType);
Expand Down Expand Up @@ -1967,6 +1973,91 @@ static PyType_Spec ChannelIDType_spec = {
};


/* SendChannel and RecvChannel classes */

// XXX Use a new __xid__ protocol instead?

static PyTypeObject *
_get_current_channel_end_type(int end)
{
module_state *state = _get_current_module_state();
if (state == NULL) {
return NULL;
}
PyTypeObject *cls;
if (end == CHANNEL_SEND) {
cls = state->send_channel_type;
}
else {
assert(end == CHANNEL_RECV);
cls = state->recv_channel_type;
}
if (cls == NULL) {
// XXX Use some other exception type?
PyErr_SetString(PyExc_RuntimeError, "interpreters module not imported yet");
return NULL;
}
return cls;
}

static PyObject *
_channel_end_from_xid(_PyCrossInterpreterData *data)
{
channelid *cid = (channelid *)_channelid_from_xid(data);
if (cid == NULL) {
return NULL;
}
PyTypeObject *cls = _get_current_channel_end_type(cid->end);
if (cls == NULL) {
return NULL;
}
PyObject *obj = PyObject_CallOneArg((PyObject *)cls, (PyObject *)cid);
Py_DECREF(cid);
return obj;
}

static int
_channel_end_shared(PyThreadState *tstate, PyObject *obj,
_PyCrossInterpreterData *data)
{
PyObject *cidobj = PyObject_GetAttrString(obj, "_id");
if (cidobj == NULL) {
return -1;
}
if (_channelid_shared(tstate, cidobj, data) < 0) {
return -1;
}
data->new_object = _channel_end_from_xid;
return 0;
}

static int
set_channel_end_types(PyObject *mod, PyTypeObject *send, PyTypeObject *recv)
{
module_state *state = get_module_state(mod);
if (state == NULL) {
return -1;
}

if (state->send_channel_type != NULL
|| state->recv_channel_type != NULL)
{
PyErr_SetString(PyExc_TypeError, "already registered");
return -1;
}
state->send_channel_type = (PyTypeObject *)Py_NewRef(send);
state->recv_channel_type = (PyTypeObject *)Py_NewRef(recv);

if (_PyCrossInterpreterData_RegisterClass(send, _channel_end_shared)) {
return -1;
}
if (_PyCrossInterpreterData_RegisterClass(recv, _channel_end_shared)) {
return -1;
}

return 0;
}

/* module level code ********************************************************/

/* globals is the process-global state for the module. It holds all
Expand Down Expand Up @@ -2375,6 +2466,35 @@ channel__channel_id(PyObject *self, PyObject *args, PyObject *kwds)
return _channelid_new(self, cls, args, kwds);
}

static PyObject *
channel__register_end_types(PyObject *self, PyObject *args, PyObject *kwds)
{
static char *kwlist[] = {"send", "recv", NULL};
PyObject *send;
PyObject *recv;
if (!PyArg_ParseTupleAndKeywords(args, kwds,
"OO:_register_end_types", kwlist,
&send, &recv)) {
return NULL;
}
if (!PyType_Check(send)) {
PyErr_SetString(PyExc_TypeError, "expected a type for 'send'");
return NULL;
}
if (!PyType_Check(recv)) {
PyErr_SetString(PyExc_TypeError, "expected a type for 'recv'");
return NULL;
}
PyTypeObject *cls_send = (PyTypeObject *)send;
PyTypeObject *cls_recv = (PyTypeObject *)recv;

if (set_channel_end_types(self, cls_send, cls_recv) < 0) {
return NULL;
}

Py_RETURN_NONE;
}

static PyMethodDef module_functions[] = {
{"create", channel_create,
METH_NOARGS, channel_create_doc},
Expand All @@ -2394,6 +2514,8 @@ static PyMethodDef module_functions[] = {
METH_VARARGS | METH_KEYWORDS, channel_release_doc},
{"_channel_id", _PyCFunction_CAST(channel__channel_id),
METH_VARARGS | METH_KEYWORDS, NULL},
{"_register_end_types", _PyCFunction_CAST(channel__register_end_types),
METH_VARARGS | METH_KEYWORDS, NULL},

{NULL, NULL} /* sentinel */
};
Expand Down