Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add Interpreter.bind().
  • Loading branch information
ericsnowcurrently committed Nov 6, 2023
commit c8c2edd57d9878a27c7b411d65bed9cfc804e783
5 changes: 5 additions & 0 deletions Lib/test/support/interpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,11 @@ def close(self):
"""
return _interpreters.destroy(self._id)

def bind(self, ns=None, /, **kwargs):
"""Bind the given values into the interpreter's __main__."""
ns = dict(ns, **kwargs) if ns is not None else kwargs
_interpreters.bind(self._id, ns)

# XXX Rename "run" to "exec"?
# XXX Do not allow init to overwrite (by default)?
def run(self, src_str, /, *, init=None):
Expand Down
4 changes: 2 additions & 2 deletions Lib/test/test__xxinterpchannels.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,12 +587,12 @@ def test_run_string_arg_unresolved(self):
cid = channels.create()
interp = interpreters.create()

interpreters.bind(interp, dict(cid=cid.send))
out = _run_output(interp, dedent("""
import _xxinterpchannels as _channels
print(cid.end)
_channels.send(cid, b'spam', blocking=False)
"""),
dict(cid=cid.send))
"""))
obj = channels.recv(cid)

self.assertEqual(obj, b'spam')
Expand Down
24 changes: 15 additions & 9 deletions Lib/test/test__xxsubinterpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ def _captured_script(script):
return wrapped, open(r, encoding="utf-8")


def _run_output(interp, request, shared=None):
def _run_output(interp, request):
script, rpipe = _captured_script(request)
with rpipe:
interpreters.run_string(interp, script, shared)
interpreters.run_string(interp, script)
return rpipe.read()


Expand Down Expand Up @@ -659,10 +659,10 @@ def test_shareable_types(self):
]
for obj in objects:
with self.subTest(obj):
interpreters.bind(interp, dict(obj=obj))
interpreters.run_string(
interp,
f'assert(obj == {obj!r})',
shared=dict(obj=obj),
)

def test_os_exec(self):
Expand Down Expand Up @@ -790,7 +790,8 @@ def test_with_shared(self):
with open({w}, 'wb') as chan:
pickle.dump(ns, chan)
""")
interpreters.run_string(self.id, script, shared)
interpreters.bind(self.id, shared)
interpreters.run_string(self.id, script)
with open(r, 'rb') as chan:
ns = pickle.load(chan)

Expand All @@ -811,7 +812,8 @@ def test_shared_overwrites(self):
ns2 = dict(vars())
del ns2['__builtins__']
""")
interpreters.run_string(self.id, script, shared)
interpreters.bind(self.id, shared)
interpreters.run_string(self.id, script)

r, w = os.pipe()
script = dedent(f"""
Expand Down Expand Up @@ -842,7 +844,8 @@ def test_shared_overwrites_default_vars(self):
with open({w}, 'wb') as chan:
pickle.dump(ns, chan)
""")
interpreters.run_string(self.id, script, shared)
interpreters.bind(self.id, shared)
interpreters.run_string(self.id, script)
with open(r, 'rb') as chan:
ns = pickle.load(chan)

Expand Down Expand Up @@ -948,7 +951,8 @@ def script():
with open(w, 'w', encoding="utf-8") as spipe:
with contextlib.redirect_stdout(spipe):
print('it worked!', end='')
interpreters.run_func(self.id, script, shared=dict(w=w))
interpreters.bind(self.id, dict(w=w))
interpreters.run_func(self.id, script)

with open(r, encoding="utf-8") as outfile:
out = outfile.read()
Expand All @@ -964,7 +968,8 @@ def script():
with contextlib.redirect_stdout(spipe):
print('it worked!', end='')
def f():
interpreters.run_func(self.id, script, shared=dict(w=w))
interpreters.bind(self.id, dict(w=w))
interpreters.run_func(self.id, script)
t = threading.Thread(target=f)
t.start()
t.join()
Expand All @@ -984,7 +989,8 @@ def script():
with contextlib.redirect_stdout(spipe):
print('it worked!', end='')
code = script.__code__
interpreters.run_func(self.id, code, shared=dict(w=w))
interpreters.bind(self.id, dict(w=w))
interpreters.run_func(self.id, code)

with open(r, encoding="utf-8") as outfile:
out = outfile.read()
Expand Down
61 changes: 60 additions & 1 deletion Lib/test/test_interpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def clean_up_interpreters():
def _run_output(interp, request, init=None):
script, rpipe = _captured_script(request)
with rpipe:
interp.run(script, init=init)
if init:
interp.bind(init)
interp.run(script)
return rpipe.read()


Expand Down Expand Up @@ -467,6 +469,63 @@ def task():
self.assertEqual(os.read(r_interp, 1), FINISHED)


class TestInterpreterBind(TestBase):

def test_empty(self):
interp = interpreters.create()
with self.assertRaises(ValueError):
interp.bind()

def test_dict(self):
values = {'spam': 42, 'eggs': 'ham'}
interp = interpreters.create()
interp.bind(values)
out = _run_output(interp, dedent("""
print(spam, eggs)
"""))
self.assertEqual(out.strip(), '42 ham')

def test_tuple(self):
values = {'spam': 42, 'eggs': 'ham'}
values = tuple(values.items())
interp = interpreters.create()
interp.bind(values)
out = _run_output(interp, dedent("""
print(spam, eggs)
"""))
self.assertEqual(out.strip(), '42 ham')

def test_kwargs(self):
values = {'spam': 42, 'eggs': 'ham'}
interp = interpreters.create()
interp.bind(**values)
out = _run_output(interp, dedent("""
print(spam, eggs)
"""))
self.assertEqual(out.strip(), '42 ham')

def test_dict_and_kwargs(self):
values = {'spam': 42, 'eggs': 'ham'}
interp = interpreters.create()
interp.bind(values, foo='bar')
out = _run_output(interp, dedent("""
print(spam, eggs, foo)
"""))
self.assertEqual(out.strip(), '42 ham bar')

def test_not_shareable(self):
interp = interpreters.create()
# XXX TypeError?
with self.assertRaises(ValueError):
interp.bind(spam={'spam': 'eggs', 'foo': 'bar'})

# Make sure neither was actually bound.
with self.assertRaises(RuntimeError):
interp.run('print(foo)')
with self.assertRaises(RuntimeError):
interp.run('print(spam)')


class TestInterpreterRun(TestBase):

def test_success(self):
Expand Down
54 changes: 54 additions & 0 deletions Modules/_xxsubinterpretersmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,58 @@ PyDoc_STRVAR(get_main_doc,
\n\
Return the ID of main interpreter.");

static PyObject *
interp_bind(PyObject *self, PyObject *args)
{
PyObject *id, *updates;
if (!PyArg_ParseTuple(args, "OO:" MODULE_NAME ".bind", &id, &updates)) {
return NULL;
}

// Look up the interpreter.
PyInterpreterState *interp = PyInterpreterID_LookUp(id);
if (interp == NULL) {
return NULL;
}

// Check the updates.
if (updates != Py_None) {
Py_ssize_t size = PyObject_Size(updates);
if (size < 0) {
return NULL;
}
if (size == 0) {
PyErr_SetString(PyExc_ValueError,
"arg 2 must be a non-empty mapping");
return NULL;
}
}

_PyXI_session session = {0};

// Prep and switch interpreters, including apply the updates.
if (_PyXI_Enter(&session, interp, updates) < 0) {
if (!PyErr_Occurred()) {
_PyXI_ApplyCapturedException(&session, NULL);
assert(PyErr_Occurred());
}
else {
assert(!_PyXI_HasCapturedException(&session));
}
return NULL;
}

// Clean up and switch back.
_PyXI_Exit(&session);

Py_RETURN_NONE;
}

PyDoc_STRVAR(bind_doc,
"bind(id, ns)\n\
\n\
Bind the given attributes in the interpreter's __main__ module.");

static PyUnicodeObject *
convert_script_arg(PyObject *arg, const char *fname, const char *displayname,
const char *expected)
Expand Down Expand Up @@ -698,6 +750,8 @@ static PyMethodDef module_functions[] = {
{"run_func", _PyCFunction_CAST(interp_run_func),
METH_VARARGS | METH_KEYWORDS, run_func_doc},

{"bind", _PyCFunction_CAST(interp_bind),
METH_VARARGS, bind_doc},
{"is_shareable", _PyCFunction_CAST(object_is_shareable),
METH_VARARGS | METH_KEYWORDS, is_shareable_doc},

Expand Down