Skip to content

Commit 9c09aef

Browse files
committed
fix(vm): provide detailed error for circular from imports
1 parent 6342ad4 commit 9c09aef

File tree

4 files changed

+81
-17
lines changed

4 files changed

+81
-17
lines changed

Lib/test/test_import/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1380,8 +1380,6 @@ def test_crossreference2(self):
13801380
self.assertIn('partially initialized module', errmsg)
13811381
self.assertIn('circular import', errmsg)
13821382

1383-
# TODO: RUSTPYTHON
1384-
@unittest.expectedFailure
13851383
def test_circular_from_import(self):
13861384
with self.assertRaises(ImportError) as cm:
13871385
import test.test_import.data.circular_imports.from_cycle1

vm/src/builtins/module.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::sync::atomic::AtomicBool;
2+
13
use super::{PyDict, PyDictRef, PyStr, PyStrRef, PyType, PyTypeRef};
24
use crate::{
35
AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
@@ -51,6 +53,8 @@ pub struct PyModule {
5153
// weaklist
5254
// for logging purposes after md_dict is cleared
5355
pub name: Option<&'static PyStrInterned>,
56+
57+
pub(crate) initializing: AtomicBool,
5458
}
5559

5660
impl PyPayload for PyModule {
@@ -73,13 +77,15 @@ impl PyModule {
7377
Self {
7478
def: None,
7579
name: None,
80+
initializing: AtomicBool::new(false),
7681
}
7782
}
7883

7984
pub const fn from_def(def: &'static PyModuleDef) -> Self {
8085
Self {
8186
def: Some(def),
8287
name: Some(def.name),
88+
initializing: AtomicBool::new(false),
8389
}
8490
}
8591

vm/src/frame.rs

Lines changed: 53 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1362,20 +1362,50 @@ impl ExecutingFrame<'_> {
13621362
#[cfg_attr(feature = "flame-it", flame("Frame"))]
13631363
fn import_from(&mut self, vm: &VirtualMachine, idx: bytecode::NameIdx) -> PyResult {
13641364
let module = self.top_value();
1365-
let name = self.code.names[idx as usize];
1366-
let err = || vm.new_import_error(format!("cannot import name '{name}'"), name.to_owned());
1365+
let name_to_import = self.code.names[idx as usize];
1366+
13671367
// Load attribute, and transform any error into import error.
1368-
if let Some(obj) = vm.get_attribute_opt(module.to_owned(), name)? {
1368+
if let Some(obj) = vm.get_attribute_opt(module.to_owned(), name_to_import)? {
13691369
return Ok(obj);
13701370
}
1371-
// fallback to importing '{module.__name__}.{name}' from sys.modules
1372-
let mod_name = module
1373-
.get_attr(identifier!(vm, __name__), vm)
1374-
.map_err(|_| err())?;
1375-
let mod_name = mod_name.downcast::<PyStr>().map_err(|_| err())?;
1376-
let full_mod_name = format!("{mod_name}.{name}");
1377-
let sys_modules = vm.sys_module.get_attr("modules", vm).map_err(|_| err())?;
1378-
sys_modules.get_item(&full_mod_name, vm).map_err(|_| err())
1371+
1372+
let fallback_result: Option<PyResult> = module
1373+
.get_attr(&vm.ctx.new_str("__name__"), vm)
1374+
.ok()
1375+
.and_then(|mod_name| mod_name.downcast_ref::<PyStr>().map(|s| s.to_owned()))
1376+
.and_then(|mod_name_str| {
1377+
let full_mod_name =
1378+
format!("{}.{}", mod_name_str.as_str(), name_to_import.as_str());
1379+
vm.sys_module
1380+
.get_attr("modules", vm)
1381+
.ok()
1382+
.and_then(|sys_modules| sys_modules.get_item(&full_mod_name, vm).ok())
1383+
})
1384+
.map(Ok);
1385+
1386+
if let Some(Ok(sub_module)) = fallback_result {
1387+
return Ok(sub_module);
1388+
}
1389+
1390+
if is_module_initializing(module, vm) {
1391+
let module_name = module
1392+
.get_attr(&vm.ctx.new_str("__name__"), vm)
1393+
.ok()
1394+
.and_then(|n| n.downcast_ref::<PyStr>().map(|s| s.as_str().to_owned()))
1395+
.unwrap_or_else(|| "<unknown>".to_owned());
1396+
1397+
let msg = format!(
1398+
"cannot import name '{}' from partially initialized module '{}' (most likely due to a circular import)",
1399+
name_to_import.as_str(),
1400+
module_name
1401+
);
1402+
Err(vm.new_import_error(msg, name_to_import.to_owned()))
1403+
} else {
1404+
Err(vm.new_import_error(
1405+
format!("cannot import name '{}'", name_to_import.as_str()),
1406+
name_to_import.to_owned(),
1407+
))
1408+
}
13791409
}
13801410

13811411
#[cfg_attr(feature = "flame-it", flame("Frame"))]
@@ -2372,3 +2402,15 @@ impl fmt::Debug for Frame {
23722402
)
23732403
}
23742404
}
2405+
2406+
fn is_module_initializing(module: &PyObject, vm: &VirtualMachine) -> bool {
2407+
let spec = match module.get_attr(&vm.ctx.new_str("__spec__"), vm) {
2408+
Ok(spec) if !vm.is_none(&spec) => spec,
2409+
_ => return false,
2410+
};
2411+
let initializing_attr = match spec.get_attr(&vm.ctx.new_str("_initializing"), vm) {
2412+
Ok(attr) => attr,
2413+
Err(_) => return false,
2414+
};
2415+
initializing_attr.try_to_bool(vm).unwrap_or(false)
2416+
}

vm/src/import.rs

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
//! Import mechanics
22
3+
use std::sync::atomic::Ordering;
4+
35
use crate::{
46
AsObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject,
5-
builtins::{PyBaseExceptionRef, PyCode, list, traceback::PyTraceback},
7+
builtins::{PyBaseExceptionRef, PyCode, PyModule, list, traceback::PyTraceback},
68
scope::Scope,
79
version::get_git_revision,
810
vm::{VirtualMachine, thread},
@@ -156,9 +158,25 @@ pub fn import_code_obj(
156158
let sys_modules = vm.sys_module.get_attr("modules", vm)?;
157159
sys_modules.set_item(module_name, module.clone().into(), vm)?;
158160

159-
// Execute main code in module:
160-
let scope = Scope::with_builtins(None, attrs, vm);
161-
vm.run_code_obj(code_obj, scope)?;
161+
{
162+
struct InitializingGuard<'a> {
163+
module: &'a PyModule,
164+
}
165+
166+
impl<'a> Drop for InitializingGuard<'a> {
167+
fn drop(&mut self) {
168+
self.module.initializing.store(false, Ordering::Relaxed);
169+
}
170+
}
171+
172+
module.initializing.store(true, Ordering::Relaxed);
173+
let _guard = InitializingGuard { module: &module };
174+
175+
// Execute main code in module:
176+
let scope = Scope::with_builtins(None, attrs, vm);
177+
vm.run_code_obj(code_obj, scope)?;
178+
}
179+
162180
Ok(module.into())
163181
}
164182

0 commit comments

Comments
 (0)