Skip to content
Prev Previous commit
Next Next commit
Stop waiting when the channel is closed.
  • Loading branch information
ericsnowcurrently committed Oct 16, 2023
commit f904a52a1a60f3c1f41fb1fa64cd3819d3b357c0
26 changes: 26 additions & 0 deletions Lib/test/test__xxinterpchannels.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,32 @@ def test_channel_list_interpreters_closed_send_end(self):

####################

def test_send_closed_while_waiting(self):
obj = b'spam'
cid = channels.create()
def f():
# sleep() isn't a great for this, but definitely simple.
time.sleep(1)
channels.close(cid, force=True)
t = threading.Thread(target=f)
t.start()
with self.assertRaises(channels.ChannelClosedError):
channels.send(cid, obj, blocking=True)
t.join()

def test_send_buffer_closed_while_waiting(self):
obj = bytearray(b'spam')
cid = channels.create()
def f():
# sleep() isn't a great for this, but definitely simple.
time.sleep(1)
channels.close(cid, force=True)
t = threading.Thread(target=f)
t.start()
with self.assertRaises(channels.ChannelClosedError):
channels.send_buffer(cid, obj, blocking=True)
t.join()

def test_send_recv_main(self):
cid = channels.create()
orig = b'spam'
Expand Down
58 changes: 49 additions & 9 deletions Modules/_xxinterpchannelsmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,7 @@ _get_current_xibufferview_type(void)
#define ERR_CHANNEL_MUTEX_INIT -7
#define ERR_CHANNELS_MUTEX_INIT -8
#define ERR_NO_NEXT_CHANNEL_ID -9
#define ERR_CHANNEL_CLOSED_WAITING -10

static int
exceptions_init(PyObject *mod)
Expand Down Expand Up @@ -550,6 +551,10 @@ handle_channel_error(int err, PyObject *mod, int64_t cid)
PyErr_Format(state->ChannelClosedError,
"channel %" PRId64 " is closed", cid);
}
else if (err == ERR_CHANNEL_CLOSED_WAITING) {
PyErr_Format(state->ChannelClosedError,
"channel %" PRId64 " has closed", cid);
}
else if (err == ERR_CHANNEL_INTERP_CLOSED) {
PyErr_Format(state->ChannelClosedError,
"channel %" PRId64 " is already closed", cid);
Expand Down Expand Up @@ -589,6 +594,7 @@ struct _channelitem;
typedef struct _channelitem {
_PyCrossInterpreterData *data;
PyThread_type_lock recv_mutex;
int *received;
struct _channelitem *next;
} _channelitem;

Expand All @@ -601,6 +607,8 @@ _channelitem_new(void)
return NULL;
}
item->data = NULL;
item->recv_mutex = NULL;
item->received = NULL;
item->next = NULL;
return item;
}
Expand All @@ -620,6 +628,10 @@ static void
_channelitem_free(_channelitem *item)
{
_channelitem_clear(item);
if (item->recv_mutex != NULL) {
// The code that sent the object must free the item.
return;
}
GLOBAL_FREE(item);
}

Expand All @@ -629,6 +641,9 @@ _channelitem_free_all(_channelitem *item)
while (item != NULL) {
_channelitem *last = item;
item = item->next;
if (last->recv_mutex != NULL) {
PyThread_release_lock(last->recv_mutex);
}
_channelitem_free(last);
}
}
Expand All @@ -639,6 +654,10 @@ _channelitem_popped(_channelitem *item, PyThread_type_lock *recv_mutex)
_PyCrossInterpreterData *data = item->data;
item->data = NULL;
*recv_mutex = item->recv_mutex;
if (item->received != NULL) {
assert(*item->received == 0);
*item->received = 1;
}
_channelitem_free(item);
return data;
}
Expand Down Expand Up @@ -679,16 +698,27 @@ _channelqueue_free(_channelqueue *queue)
GLOBAL_FREE(queue);
}

struct wait_info {
PyThread_type_lock mutex;
int received;
};

static int
_channelqueue_put(_channelqueue *queue, _PyCrossInterpreterData *data,
PyThread_type_lock recv_mutex)
struct wait_info *wait)
{
_channelitem *item = _channelitem_new();
if (item == NULL) {
return -1;
}
item->data = data;
item->recv_mutex = recv_mutex;
*item = (_channelitem){
.data = data,
};
if (wait != NULL) {
item->recv_mutex = wait->mutex;
wait->received = 0;
item->received = &wait->received;
}

queue->count += 1;
if (queue->first == NULL) {
Expand All @@ -698,6 +728,7 @@ _channelqueue_put(_channelqueue *queue, _PyCrossInterpreterData *data,
queue->last->next = item;
}
queue->last = item;

return 0;
}

Expand Down Expand Up @@ -732,6 +763,9 @@ _channelqueue_drop_interpreter(_channelqueue *queue, int64_t interp)
else {
prev->next = item->next;
}
if (item->recv_mutex != NULL) {
PyThread_release_lock(item->recv_mutex);
}
_channelitem_free(item);
queue->count -= 1;
}
Expand Down Expand Up @@ -1031,7 +1065,7 @@ _channel_free(_PyChannelState *chan)

static int
_channel_add(_PyChannelState *chan, int64_t interp,
_PyCrossInterpreterData *data, PyThread_type_lock recv_mutex)
_PyCrossInterpreterData *data, struct wait_info *wait)
{
int res = -1;
PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
Expand All @@ -1045,7 +1079,7 @@ _channel_add(_PyChannelState *chan, int64_t interp,
goto done;
}

if (_channelqueue_put(chan->queue, data, recv_mutex) != 0) {
if (_channelqueue_put(chan->queue, data, wait) != 0) {
goto done;
}

Expand Down Expand Up @@ -1602,7 +1636,7 @@ _channel_destroy(_channels *channels, int64_t id)

static int
_channel_send(_channels *channels, int64_t id, PyObject *obj,
PyThread_type_lock recv_mutex)
struct wait_info *wait)
{
PyInterpreterState *interp = _get_current_interp();
if (interp == NULL) {
Expand Down Expand Up @@ -1637,8 +1671,7 @@ _channel_send(_channels *channels, int64_t id, PyObject *obj,
}

// Add the data to the channel.
int res = _channel_add(chan, PyInterpreterState_GetID(interp), data,
recv_mutex);
int res = _channel_add(chan, PyInterpreterState_GetID(interp), data, wait);
PyThread_release_lock(mutex);
if (res != 0) {
// We may chain an exception here:
Expand All @@ -1661,7 +1694,10 @@ _channel_send_wait(_channels *channels, int64_t cid, PyObject *obj)
PyThread_acquire_lock(mutex, NOWAIT_LOCK);

/* Queue up the object. */
int res = _channel_send(channels, cid, obj, mutex);
struct wait_info wait = (struct wait_info){
.mutex = mutex,
};
int res = _channel_send(channels, cid, obj, &wait);
if (res < 0) {
PyThread_release_lock(mutex);
goto finally;
Expand All @@ -1674,6 +1710,10 @@ _channel_send_wait(_channels *channels, int64_t cid, PyObject *obj)
res = -1;
goto finally;
}
if (!wait.received) {
res = ERR_CHANNEL_CLOSED_WAITING;
goto finally;
}

/* success! */
res = 0;
Expand Down