Skip to content

Commit fe84f25

Browse files
committed
Fix weirdness with import submodules
1 parent b6e061b commit fe84f25

File tree

5 files changed

+76
-44
lines changed

5 files changed

+76
-44
lines changed

bytecode/src/bytecode.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,7 @@ pub enum Instruction {
8484
symbols: Vec<String>,
8585
level: usize,
8686
},
87-
ImportStar {
88-
name: Option<String>,
89-
level: usize,
90-
},
87+
ImportStar,
9188
ImportFrom {
9289
name: String,
9390
},
@@ -429,7 +426,7 @@ impl Instruction {
429426
format!("{:?}", symbols),
430427
level
431428
),
432-
ImportStar { name, level } => w!(ImportStar, format!("{:?}", name), level),
429+
ImportStar => w!(ImportStar),
433430
ImportFrom { name } => w!(ImportFrom, name),
434431
LoadName { name, scope } => w!(LoadName, name, format!("{:?}", scope)),
435432
StoreName { name, scope } => w!(StoreName, name, format!("{:?}", scope)),

compiler/src/compile.rs

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -300,16 +300,11 @@ impl<O: OutputStream> Compiler<O> {
300300
Import { names } => {
301301
// import a, b, c as d
302302
for name in names {
303-
self.emit(Instruction::Import {
304-
name: Some(name.symbol.clone()),
305-
symbols: vec![],
306-
level: 0,
307-
});
308-
303+
self.compile_import(Some(&name.symbol), vec![], 0, name.alias.is_some());
309304
if let Some(alias) = &name.alias {
310305
self.store_name(alias);
311306
} else {
312-
self.store_name(&name.symbol);
307+
self.store_name(name.symbol.split('.').next().unwrap());
313308
}
314309
}
315310
}
@@ -322,21 +317,25 @@ impl<O: OutputStream> Compiler<O> {
322317

323318
if import_star {
324319
// from .... import *
325-
self.emit(Instruction::ImportStar {
326-
name: module.clone(),
327-
level: *level,
328-
});
320+
self.compile_import(
321+
module.as_ref().map(String::as_str),
322+
vec!["*".to_owned()],
323+
*level,
324+
false,
325+
);
326+
self.emit(Instruction::ImportStar);
329327
} else {
330328
// from mod import a, b as c
331329
// First, determine the fromlist (for import lib):
332330
let from_list = names.iter().map(|n| n.symbol.clone()).collect();
333331

334332
// Load module once:
335-
self.emit(Instruction::Import {
336-
name: module.clone(),
337-
symbols: from_list,
338-
level: *level,
339-
});
333+
self.compile_import(
334+
module.as_ref().map(String::as_str),
335+
from_list,
336+
*level,
337+
false,
338+
);
340339

341340
for name in names {
342341
// import symbol from module:
@@ -574,6 +573,30 @@ impl<O: OutputStream> Compiler<O> {
574573
Ok(())
575574
}
576575

576+
fn compile_import(
577+
&mut self,
578+
name: Option<&str>,
579+
symbols: Vec<String>,
580+
level: usize,
581+
get_final_module: bool,
582+
) {
583+
self.emit(Instruction::Import {
584+
name: name.map(ToOwned::to_owned),
585+
symbols,
586+
level,
587+
});
588+
589+
if get_final_module {
590+
if let Some(name) = name {
591+
for part in name.split('.').skip(1) {
592+
self.emit(Instruction::LoadAttr {
593+
name: part.to_owned(),
594+
});
595+
}
596+
}
597+
}
598+
}
599+
577600
fn compile_delete(&mut self, expression: &ast::Expression) -> Result<(), CompileError> {
578601
match &expression.node {
579602
ast::ExpressionType::Identifier { name } => {

compiler/src/symboltable.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,10 @@ impl SymbolTableBuilder {
388388
self.register_name(alias, SymbolUsage::Assigned)?;
389389
} else {
390390
// `import module`
391-
self.register_name(&name.symbol, SymbolUsage::Assigned)?;
391+
self.register_name(
392+
name.symbol.split('.').next().unwrap(),
393+
SymbolUsage::Assigned,
394+
)?;
392395
}
393396
}
394397
}

vm/src/frame.rs

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -198,10 +198,7 @@ impl Frame {
198198
ref symbols,
199199
ref level,
200200
} => self.import(vm, name, symbols, *level),
201-
bytecode::Instruction::ImportStar {
202-
ref name,
203-
ref level,
204-
} => self.import_star(vm, name, *level),
201+
bytecode::Instruction::ImportStar => self.import_star(vm),
205202
bytecode::Instruction::ImportFrom { ref name } => self.import_from(vm, name),
206203
bytecode::Instruction::LoadName {
207204
ref name,
@@ -740,25 +737,23 @@ impl Frame {
740737
// Load attribute, and transform any error into import error.
741738
let obj = vm
742739
.get_attribute(module, name)
743-
.map_err(|_| vm.new_import_error(format!("cannot import name '{}'", name)));
744-
self.push_value(obj?);
740+
.map_err(|_| vm.new_import_error(format!("cannot import name '{}'", name)))?;
741+
self.push_value(obj);
745742
Ok(None)
746743
}
747744

748745
#[cfg_attr(feature = "flame-it", flame("Frame"))]
749-
fn import_star(
750-
&self,
751-
vm: &VirtualMachine,
752-
module: &Option<String>,
753-
level: usize,
754-
) -> FrameResult {
755-
let module = module.clone().unwrap_or_default();
756-
let module = vm.import(&module, &vm.ctx.new_tuple(vec![]), level)?;
746+
fn import_star(&self, vm: &VirtualMachine) -> FrameResult {
747+
let module = self.pop_value();
757748

758749
// Grab all the names from the module and put them in the context
759750
if let Some(dict) = &module.dict {
760751
for (k, v) in dict {
761-
self.scope.store_name(&vm, &objstr::get_value(&k), v);
752+
let k = vm.to_str(&k)?;
753+
let k = k.as_str();
754+
if !k.starts_with('_') {
755+
self.scope.store_name(&vm, k, v);
756+
}
762757
}
763758
}
764759
Ok(None)

vm/src/vm.rs

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -407,10 +407,23 @@ impl VirtualMachine {
407407
}
408408

409409
pub fn import(&self, module: &str, from_list: &PyObjectRef, level: usize) -> PyResult {
410-
let sys_modules = self.get_attribute(self.sys_module.clone(), "modules")?;
411-
sys_modules
412-
.get_item(module.to_string(), self)
413-
.or_else(|_| {
410+
// if the import inputs seem weird, e.g a package import or something, rather than just
411+
// a straight `import ident`
412+
let weird = module.contains('.')
413+
|| level != 0
414+
|| objbool::boolval(self, from_list.clone()).unwrap_or(true);
415+
416+
let module = self.new_str(module.to_owned());
417+
418+
let sys_module = if weird {
419+
None
420+
} else {
421+
let sys_modules = self.get_attribute(self.sys_module.clone(), "modules")?;
422+
sys_modules.get_item(module.clone(), self).ok()
423+
};
424+
match sys_module {
425+
Some(module) => Ok(module),
426+
None => {
414427
let import_func = self
415428
.get_attribute(self.builtins.clone(), "__import__")
416429
.map_err(|_| self.new_import_error("__import__ not found".to_string()))?;
@@ -426,15 +439,16 @@ impl VirtualMachine {
426439
self.invoke(
427440
&import_func,
428441
vec![
429-
self.ctx.new_str(module.to_string()),
442+
module,
430443
globals,
431444
locals,
432445
from_list.clone(),
433446
self.ctx.new_int(level),
434447
],
435448
)
436-
})
437-
.map_err(|exc| import::remove_importlib_frames(self, &exc))
449+
.map_err(|exc| import::remove_importlib_frames(self, &exc))
450+
}
451+
}
438452
}
439453

440454
/// Determines if `obj` is an instance of `cls`, either directly, indirectly or virtually via

0 commit comments

Comments
 (0)