Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Only convert channel IDs when requested.
  • Loading branch information
ericsnowcurrently committed May 16, 2018
commit bac0f1271015cf0484a9591bf89d0fa90dcbce20
38 changes: 28 additions & 10 deletions Lib/test/test__xxsubinterpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,18 @@ def _captured_script(script):
indented = script.replace('\n', '\n ')
wrapped = dedent(f"""
import contextlib
with open({w}, 'w') as chan:
with contextlib.redirect_stdout(chan):
with open({w}, 'w') as spipe:
with contextlib.redirect_stdout(spipe):
{indented}
""")
return wrapped, open(r)


def _run_output(interp, request, shared=None):
script, chan = _captured_script(request)
with chan:
script, rpipe = _captured_script(request)
with rpipe:
interpreters.run_string(interp, script, shared)
return chan.read()
return rpipe.read()


@contextlib.contextmanager
Expand All @@ -37,17 +37,17 @@ def _running(interp):
def run():
interpreters.run_string(interp, dedent(f"""
# wait for "signal"
with open({r}) as chan:
chan.read()
with open({r}) as rpipe:
rpipe.read()
"""))

t = threading.Thread(target=run)
t.start()

yield

with open(w, 'w') as chan:
chan.write('done')
with open(w, 'w') as spipe:
spipe.write('done')
t.join()


Expand Down Expand Up @@ -1209,7 +1209,7 @@ def test_recv_empty(self):
with self.assertRaises(interpreters.ChannelEmptyError):
interpreters.channel_recv(cid)

def test_run_string_arg(self):
def test_run_string_arg_unresolved(self):
cid = interpreters.channel_create()
interp = interpreters.create()

Expand All @@ -1224,6 +1224,24 @@ def test_run_string_arg(self):
self.assertEqual(obj, b'spam')
self.assertEqual(out.strip(), 'send')

def test_run_string_arg_resolved(self):
cid = interpreters.channel_create()
cid = interpreters._channel_id(cid, _resolve=True)
interp = interpreters.create()

out = _run_output(interp, dedent("""
import _xxsubinterpreters as _interpreters
print(chan.end)
_interpreters.channel_send(chan, b'spam')
#print(chan.id.end)
#_interpreters.channel_send(chan.id, b'spam')
"""),
dict(chan=cid.send))
obj = interpreters.channel_recv(cid)

self.assertEqual(obj, b'spam')
self.assertEqual(out.strip(), 'send')


if __name__ == '__main__':
unittest.main()
28 changes: 19 additions & 9 deletions Modules/_xxsubinterpretersmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -1304,19 +1304,21 @@ typedef struct channelid {
PyObject_HEAD
int64_t id;
int end;
int resolve;
_channels *channels;
} channelid;

static channelid *
newchannelid(PyTypeObject *cls, int64_t cid, int end, _channels *channels,
int force)
int force, int resolve)
{
channelid *self = PyObject_New(channelid, cls);
if (self == NULL) {
return NULL;
}
self->id = cid;
self->end = end;
self->resolve = resolve;
self->channels = channels;

if (_channels_add_id_object(channels, cid) != 0) {
Expand All @@ -1337,14 +1339,15 @@ static _channels * _global_channels(void);
static PyObject *
channelid_new(PyTypeObject *cls, PyObject *args, PyObject *kwds)
{
static char *kwlist[] = {"id", "send", "recv", "force", NULL};
static char *kwlist[] = {"id", "send", "recv", "force", "_resolve", NULL};
PyObject *id;
int send = -1;
int recv = -1;
int force = 0;
int resolve = 0;
if (!PyArg_ParseTupleAndKeywords(args, kwds,
"O|$ppp:ChannelID.__init__", kwlist,
&id, &send, &recv, &force))
"O|$pppp:ChannelID.__new__", kwlist,
&id, &send, &recv, &force, &resolve))
return NULL;

// Coerce and check the ID.
Expand Down Expand Up @@ -1376,7 +1379,8 @@ channelid_new(PyTypeObject *cls, PyObject *args, PyObject *kwds)
end = CHANNEL_RECV;
}

return (PyObject *)newchannelid(cls, cid, end, _global_channels(), force);
return (PyObject *)newchannelid(cls, cid, end, _global_channels(),
force, resolve);
}

static void
Expand Down Expand Up @@ -1519,17 +1523,22 @@ channelid_richcompare(PyObject *self, PyObject *other, int op)
struct _channelid_xid {
int64_t id;
int end;
int resolve;
};

static PyObject *
_channelid_from_xid(_PyCrossInterpreterData *data)
{
struct _channelid_xid *xid = (struct _channelid_xid *)data->data;
// Note that we do not preserve the "resolve" flag.
PyObject *cid = (PyObject *)newchannelid(&ChannelIDtype, xid->id, xid->end,
_global_channels(), 0);
_global_channels(), 0, 0);
if (xid->end == 0) {
return cid;
}
if (!xid->resolve) {
return cid;
}

/* Try returning a high-level channel end but fall back to the ID. */
PyObject *highlevel = PyImport_ImportModule("interpreters");
Expand Down Expand Up @@ -1568,6 +1577,7 @@ _channelid_shared(PyObject *obj, _PyCrossInterpreterData *data)
}
xid->id = ((channelid *)obj)->id;
xid->end = ((channelid *)obj)->end;
xid->resolve = ((channelid *)obj)->resolve;

data->data = xid;
data->obj = obj;
Expand All @@ -1583,7 +1593,7 @@ channelid_end(PyObject *self, void *end)
channelid *cid = (channelid *)self;
if (end != NULL) {
return (PyObject *)newchannelid(Py_TYPE(self), cid->id, *(int *)end,
cid->channels, force);
cid->channels, force, cid->resolve);
}

if (cid->end == CHANNEL_SEND) {
Expand Down Expand Up @@ -2378,7 +2388,7 @@ channel_create(PyObject *self, PyObject *Py_UNUSED(ignored))
return NULL;
}
PyObject *id = (PyObject *)newchannelid(&ChannelIDtype, cid, 0,
&_globals.channels, 0);
&_globals.channels, 0, 0);
if (id == NULL) {
if (_channel_destroy(&_globals.channels, cid) != 0) {
// XXX issue a warning?
Expand Down Expand Up @@ -2436,7 +2446,7 @@ channel_list_all(PyObject *self, PyObject *Py_UNUSED(ignored))
int64_t *cur = cids;
for (int64_t i=0; i < count; cur++, i++) {
PyObject *id = (PyObject *)newchannelid(&ChannelIDtype, *cur, 0,
&_globals.channels, 0);
&_globals.channels, 0, 0);
if (id == NULL) {
Py_DECREF(ids);
ids = NULL;
Expand Down