Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Implement waiting for _channels.send().
  • Loading branch information
ericsnowcurrently committed Oct 9, 2023
commit c61736198c911f50c4771a7ad03f4e7d0cdda017
12 changes: 4 additions & 8 deletions Lib/test/support/interpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,7 @@ def send(self, obj):

This blocks until the object is received.
"""
_channels.send(self._id, obj)
# XXX We are missing a low-level channel_send_wait().
# See bpo-32604 and gh-19829.
# Until that shows up we fake it:
time.sleep(2)
_channels.send(self._id, obj, wait=True)

def send_nowait(self, obj):
"""Send the object to the channel's receiving end.
Expand All @@ -223,22 +219,22 @@ def send_nowait(self, obj):
# XXX Note that at the moment channel_send() only ever returns
# None. This should be fixed when channel_send_wait() is added.
# See bpo-32604 and gh-19829.
return _channels.send(self._id, obj)
return _channels.send(self._id, obj, wait=False)

def send_buffer(self, obj):
"""Send the object's buffer to the channel's receiving end.

This blocks until the object is received.
"""
_channels.send_buffer(self._id, obj)
_channels.send_buffer(self._id, obj, wait=True)

def send_buffer_nowait(self, obj):
"""Send the object's buffer to the channel's receiving end.

If the object is immediately received then return True
(else False). Otherwise this is the same as send().
"""
return _channels.send_buffer(self._id, obj)
return _channels.send_buffer(self._id, obj, wait=False)

def close(self):
_channels.close(self._id, send=True)
Expand Down
2 changes: 1 addition & 1 deletion Lib/test/test_interpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,8 +964,8 @@ def f():

orig = b'spam'
s.send(orig)
t.join()
obj = r.recv()
t.join()

self.assertEqual(obj, orig)
self.assertIsNot(obj, orig)
Expand Down
114 changes: 92 additions & 22 deletions Modules/_xxinterpchannelsmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,14 @@ add_new_type(PyObject *mod, PyType_Spec *spec, crossinterpdatafunc shared,
return cls;
}

static void
wait_for_lock(PyThread_type_lock mutex)
{
// XXX Handle eintr, etc.
PyThread_acquire_lock(mutex, WAIT_LOCK);
PyThread_release_lock(mutex);
}


/* Cross-interpreter Buffer Views *******************************************/

Expand Down Expand Up @@ -567,6 +575,7 @@ struct _channelitem;

typedef struct _channelitem {
_PyCrossInterpreterData *data;
PyThread_type_lock recv_mutex;
struct _channelitem *next;
} _channelitem;

Expand Down Expand Up @@ -612,10 +621,11 @@ _channelitem_free_all(_channelitem *item)
}

static _PyCrossInterpreterData *
_channelitem_popped(_channelitem *item)
_channelitem_popped(_channelitem *item, PyThread_type_lock *recv_mutex)
{
_PyCrossInterpreterData *data = item->data;
item->data = NULL;
*recv_mutex = item->recv_mutex;
_channelitem_free(item);
return data;
}
Expand Down Expand Up @@ -657,13 +667,15 @@ _channelqueue_free(_channelqueue *queue)
}

static int
_channelqueue_put(_channelqueue *queue, _PyCrossInterpreterData *data)
_channelqueue_put(_channelqueue *queue, _PyCrossInterpreterData *data,
PyThread_type_lock recv_mutex)
{
_channelitem *item = _channelitem_new();
if (item == NULL) {
return -1;
}
item->data = data;
item->recv_mutex = recv_mutex;

queue->count += 1;
if (queue->first == NULL) {
Expand All @@ -677,7 +689,7 @@ _channelqueue_put(_channelqueue *queue, _PyCrossInterpreterData *data)
}

static _PyCrossInterpreterData *
_channelqueue_get(_channelqueue *queue)
_channelqueue_get(_channelqueue *queue, PyThread_type_lock *recv_mutex)
{
_channelitem *item = queue->first;
if (item == NULL) {
Expand All @@ -689,7 +701,7 @@ _channelqueue_get(_channelqueue *queue)
}
queue->count -= 1;

return _channelitem_popped(item);
return _channelitem_popped(item, recv_mutex);
}

static void
Expand Down Expand Up @@ -1006,7 +1018,7 @@ _channel_free(_PyChannelState *chan)

static int
_channel_add(_PyChannelState *chan, int64_t interp,
_PyCrossInterpreterData *data)
_PyCrossInterpreterData *data, PyThread_type_lock recv_mutex)
{
int res = -1;
PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
Expand All @@ -1020,7 +1032,7 @@ _channel_add(_PyChannelState *chan, int64_t interp,
goto done;
}

if (_channelqueue_put(chan->queue, data) != 0) {
if (_channelqueue_put(chan->queue, data, recv_mutex) != 0) {
goto done;
}

Expand All @@ -1046,12 +1058,17 @@ _channel_next(_PyChannelState *chan, int64_t interp,
goto done;
}

_PyCrossInterpreterData *data = _channelqueue_get(chan->queue);
PyThread_type_lock recv_mutex = NULL;
_PyCrossInterpreterData *data = _channelqueue_get(chan->queue, &recv_mutex);
if (data == NULL && !PyErr_Occurred() && chan->closing != NULL) {
chan->open = 0;
}
*res = data;

if (recv_mutex != NULL) {
PyThread_release_lock(recv_mutex);
}

done:
PyThread_release_lock(chan->mutex);
if (chan->queue->count == 0) {
Expand Down Expand Up @@ -1571,7 +1588,8 @@ _channel_destroy(_channels *channels, int64_t id)
}

static int
_channel_send(_channels *channels, int64_t id, PyObject *obj)
_channel_send(_channels *channels, int64_t id, PyObject *obj,
PyThread_type_lock recv_mutex)
{
PyInterpreterState *interp = _get_current_interp();
if (interp == NULL) {
Expand Down Expand Up @@ -1606,7 +1624,8 @@ _channel_send(_channels *channels, int64_t id, PyObject *obj)
}

// Add the data to the channel.
int res = _channel_add(chan, PyInterpreterState_GetID(interp), data);
int res = _channel_add(chan, PyInterpreterState_GetID(interp), data,
recv_mutex);
PyThread_release_lock(mutex);
if (res != 0) {
// We may chain an exception here:
Expand Down Expand Up @@ -2489,22 +2508,47 @@ receive end.");
static PyObject *
channel_send(PyObject *self, PyObject *args, PyObject *kwds)
{
static char *kwlist[] = {"cid", "obj", NULL};
// XXX Add a timeout arg.
static char *kwlist[] = {"cid", "obj", "wait", NULL};
int64_t cid;
struct channel_id_converter_data cid_data = {
.module = self,
};
PyObject *obj;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O:channel_send", kwlist,
channel_id_converter, &cid_data, &obj)) {
int wait = 1;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O|$p:channel_send", kwlist,
channel_id_converter, &cid_data, &obj,
&wait)) {
return NULL;
}
cid = cid_data.cid;

int err = _channel_send(&_globals.channels, cid, obj);
if (handle_channel_error(err, self, cid)) {
return NULL;
if (wait) {
PyThread_type_lock mutex = PyThread_allocate_lock();
if (mutex == NULL) {
PyErr_NoMemory();
return NULL;
}
PyThread_acquire_lock(mutex, WAIT_LOCK);

/* Queue up the object. */
int err = _channel_send(&_globals.channels, cid, obj, mutex);
if (handle_channel_error(err, self, cid)) {
PyThread_release_lock(mutex);
return NULL;
}

/* Wait until the object is received. */
wait_for_lock(mutex);
}
else {
/* Queue up the object. */
int err = _channel_send(&_globals.channels, cid, obj, NULL);
if (handle_channel_error(err, self, cid)) {
return NULL;
}
}

Py_RETURN_NONE;
}

Expand All @@ -2516,15 +2560,17 @@ Add the object's data to the channel's queue.");
static PyObject *
channel_send_buffer(PyObject *self, PyObject *args, PyObject *kwds)
{
static char *kwlist[] = {"cid", "obj", NULL};
static char *kwlist[] = {"cid", "obj", "wait", NULL};
int64_t cid;
struct channel_id_converter_data cid_data = {
.module = self,
};
PyObject *obj;
int wait = 1;
if (!PyArg_ParseTupleAndKeywords(args, kwds,
"O&O:channel_send_buffer", kwlist,
channel_id_converter, &cid_data, &obj)) {
"O&O|$p:channel_send_buffer", kwlist,
channel_id_converter, &cid_data, &obj,
&wait)) {
return NULL;
}
cid = cid_data.cid;
Expand All @@ -2534,11 +2580,35 @@ channel_send_buffer(PyObject *self, PyObject *args, PyObject *kwds)
return NULL;
}

int err = _channel_send(&_globals.channels, cid, tempobj);
Py_DECREF(tempobj);
if (handle_channel_error(err, self, cid)) {
return NULL;
if (wait) {
PyThread_type_lock mutex = PyThread_allocate_lock();
if (mutex == NULL) {
Py_DECREF(tempobj);
PyErr_NoMemory();
return NULL;
}
PyThread_acquire_lock(mutex, WAIT_LOCK);

/* Queue up the buffer. */
int err = _channel_send(&_globals.channels, cid, tempobj, mutex);
Py_DECREF(tempobj);
if (handle_channel_error(err, self, cid)) {
PyThread_acquire_lock(mutex, WAIT_LOCK);
return NULL;
}

/* Wait until the buffer is received. */
wait_for_lock(mutex);
}
else {
/* Queue up the buffer. */
int err = _channel_send(&_globals.channels, cid, tempobj, NULL);
Py_DECREF(tempobj);
if (handle_channel_error(err, self, cid)) {
return NULL;
}
}

Py_RETURN_NONE;
}

Expand Down