Skip to content

Commit 4ac79d0

Browse files
authored
Fix annotation scope, deadlock, MRO, and HEAPTYPE issues (RustPython#7087)
* Fix annotation scope, deadlock, MRO, and HEAPTYPE issues Annotation scope: - Remove module-level annotation re-scan that created phantom sub_tables, breaking annotation closure for comprehensions - Add async comprehension check in symbol table with is_in_async_context(); annotation/type-params scopes are always non-async - Save/restore CompileContext in enter/exit_annotation_scope to reset in_async_scope Deadlock prevention: - Fix TypeVar/ParamSpec/TypeVarTuple __default__ and evaluate_default by cloning lock contents before acquiring a second lock or calling Python Other fixes: - Add HEAPTYPE flag to Generic for correct pickle behavior - Guard heaptype_ext access in name_inner/set___name__/ set___qualname__ with safe checks instead of unwrap - Fix MRO error message to include base class names - Add "format" to varnames in TypeAlias annotation scopes - Fix single-element tuple repr to include trailing comma * Unmark failing markers
2 parents 5d08063 + cb19372 commit 4ac79d0

11 files changed

Lines changed: 157 additions & 85 deletions

File tree

Lib/ctypes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ def __repr__(self):
163163
return super().__repr__()
164164
except ValueError:
165165
return "%s(<NULL>)" % type(self).__name__
166+
__class_getitem__ = classmethod(_types.GenericAlias)
166167
_check_size(py_object, "P")
167168

168169
class c_short(_SimpleCData):

Lib/test/test_dataclasses/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ def test_field_recursive_repr(self):
8686

8787
self.assertIn(",type=...,", repr_output)
8888

89-
@unittest.expectedFailure # TODO: RUSTPYTHON; recursive annotation type not shown as ...
9089
def test_recursive_annotation(self):
9190
class C:
9291
pass

Lib/test/test_genericalias.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,6 @@ class BaseTest(unittest.TestCase):
152152
if Event is not None:
153153
generic_types.append(Event)
154154

155-
@unittest.expectedFailure # TODO: RUSTPYTHON; memoryview, Template, Interpolation, py_object not subscriptable
156155
def test_subscriptable(self):
157156
for t in self.generic_types:
158157
if t is None:
@@ -507,7 +506,6 @@ def test_dir(self):
507506
with self.subTest(entry=entry):
508507
getattr(ga, entry) # must not raise `AttributeError`
509508

510-
@unittest.expectedFailure # TODO: RUSTPYTHON; memoryview, Template, Interpolation, py_object not subscriptable
511509
def test_weakref(self):
512510
for t in self.generic_types:
513511
if t is None:

Lib/test/test_type_annotations.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,6 @@ class Nested: ...
413413
self.assertEqual(Outer.meth.__annotations__, {"x": Outer.Nested})
414414
self.assertEqual(Outer.__annotations__, {"x": Outer.Nested})
415415

416-
@unittest.expectedFailure # TODO: RUSTPYTHON
417416
def test_no_exotic_expressions(self):
418417
preludes = [
419418
"",
@@ -431,7 +430,6 @@ def test_no_exotic_expressions(self):
431430
check_syntax_error(self, prelude + "def func(x: {y async for y in x}): ...", "asynchronous comprehension outside of an asynchronous function")
432431
check_syntax_error(self, prelude + "def func(x: {y: y async for y in x}): ...", "asynchronous comprehension outside of an asynchronous function")
433432

434-
@unittest.expectedFailure # TODO: RUSTPYTHON
435433
def test_no_exotic_expressions_in_unevaluated_annotations(self):
436434
preludes = [
437435
"",

Lib/test/test_type_params.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,6 @@ def test_disallowed_expressions(self):
148148
check_syntax_error(self, "def f[T: [(x := 3) for _ in range(2)]](): pass")
149149
check_syntax_error(self, "type T = [(x := 3) for _ in range(2)]")
150150

151-
@unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: "\(MRO\) for bases object, Generic" does not match "Unable to find mro order which keeps local precedence ordering"
152151
def test_incorrect_mro_explicit_object(self):
153152
with self.assertRaisesRegex(TypeError, r"\(MRO\) for bases object, Generic"):
154153
class My[X](object): ...
@@ -1215,7 +1214,6 @@ def test_pickling_functions(self):
12151214
pickled = pickle.dumps(thing, protocol=proto)
12161215
self.assertEqual(pickle.loads(pickled), thing)
12171216

1218-
@unittest.expectedFailure # TODO: RUSTPYTHON
12191217
def test_pickling_classes(self):
12201218
things_to_test = [
12211219
Class1,

Lib/test/test_typing.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4209,7 +4209,6 @@ class P(Protocol):
42094209
Alias2 = typing.Union[P, typing.Iterable]
42104210
self.assertEqual(Alias, Alias2)
42114211

4212-
@unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: Generic() takes no arguments
42134212
def test_protocols_pickleable(self):
42144213
global P, CP # pickle wants to reference the class by name
42154214
T = TypeVar('T')
@@ -5287,7 +5286,6 @@ def test_all_repr_eq_any(self):
52875286
self.assertNotEqual(repr(base), '')
52885287
self.assertEqual(base, base)
52895288

5290-
@unittest.expectedFailure # TODO: RUSTPYTHON; TypeError: Generic() takes no arguments
52915289
def test_pickle(self):
52925290
global C # pickle wants to reference the class by name
52935291
T = TypeVar('T')

crates/codegen/src/compile.rs

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1226,21 +1226,35 @@ impl Compiler {
12261226
}
12271227

12281228
/// Exit annotation scope - similar to exit_scope but restores annotation_block to parent
1229-
fn exit_annotation_scope(&mut self) -> CodeObject {
1229+
fn exit_annotation_scope(&mut self, saved_ctx: CompileContext) -> CodeObject {
12301230
self.pop_annotation_symbol_table();
1231+
self.ctx = saved_ctx;
12311232

12321233
let pop = self.code_stack.pop();
12331234
let stack_top = compiler_unwrap_option(self, pop);
12341235
unwrap_internal(self, stack_top.finalize_code(&self.opts))
12351236
}
12361237

1237-
/// Enter annotation scope using the symbol table's annotation_block
1238-
/// Returns false if no annotation_block exists
1239-
fn enter_annotation_scope(&mut self, _func_name: &str) -> CompileResult<bool> {
1238+
/// Enter annotation scope using the symbol table's annotation_block.
1239+
/// Returns None if no annotation_block exists.
1240+
/// On success, returns the saved CompileContext to pass to exit_annotation_scope.
1241+
fn enter_annotation_scope(
1242+
&mut self,
1243+
_func_name: &str,
1244+
) -> CompileResult<Option<CompileContext>> {
12401245
if !self.push_annotation_symbol_table() {
1241-
return Ok(false);
1246+
return Ok(None);
12421247
}
12431248

1249+
// Annotation scopes are never async (even inside async functions)
1250+
let saved_ctx = self.ctx;
1251+
self.ctx = CompileContext {
1252+
loop_data: None,
1253+
in_class: saved_ctx.in_class,
1254+
func: FunctionContext::Function,
1255+
in_async_scope: false,
1256+
};
1257+
12441258
let key = self.symbol_table_stack.len() - 1;
12451259
let lineno = self.get_source_line_number().get();
12461260
self.enter_scope(
@@ -1261,7 +1275,7 @@ impl Compiler {
12611275
// VALUE_WITH_FAKE_GLOBALS = 2 (from annotationlib.Format)
12621276
self.emit_format_validation()?;
12631277

1264-
Ok(true)
1278+
Ok(Some(saved_ctx))
12651279
}
12661280

12671281
/// Emit format parameter validation for annotation scope
@@ -2477,6 +2491,10 @@ impl Compiler {
24772491
// Evaluator takes a positional-only format parameter
24782492
self.current_code_info().metadata.argcount = 1;
24792493
self.current_code_info().metadata.posonlyargcount = 1;
2494+
self.current_code_info()
2495+
.metadata
2496+
.varnames
2497+
.insert("format".to_owned());
24802498
self.emit_format_validation()?;
24812499
self.compile_expression(value)?;
24822500
emit!(self, Instruction::ReturnValue);
@@ -2514,6 +2532,10 @@ impl Compiler {
25142532
// Evaluator takes a positional-only format parameter
25152533
self.current_code_info().metadata.argcount = 1;
25162534
self.current_code_info().metadata.posonlyargcount = 1;
2535+
self.current_code_info()
2536+
.metadata
2537+
.varnames
2538+
.insert("format".to_owned());
25172539
self.emit_format_validation()?;
25182540

25192541
let prev_ctx = self.ctx;
@@ -2659,6 +2681,10 @@ impl Compiler {
26592681
// Evaluator takes a positional-only format parameter
26602682
self.current_code_info().metadata.argcount = 1;
26612683
self.current_code_info().metadata.posonlyargcount = 1;
2684+
self.current_code_info()
2685+
.metadata
2686+
.varnames
2687+
.insert("format".to_owned());
26622688

26632689
self.emit_format_validation()?;
26642690

@@ -3787,10 +3813,10 @@ impl Compiler {
37873813
parameters: &ast::Parameters,
37883814
returns: Option<&ast::Expr>,
37893815
) -> CompileResult<bool> {
3790-
// Try to enter annotation scope - returns false if no annotation_block exists
3791-
if !self.enter_annotation_scope(func_name)? {
3816+
// Try to enter annotation scope - returns None if no annotation_block exists
3817+
let Some(saved_ctx) = self.enter_annotation_scope(func_name)? else {
37923818
return Ok(false);
3793-
}
3819+
};
37943820

37953821
// Count annotations
37963822
let parameters_iter = core::iter::empty()
@@ -3842,7 +3868,7 @@ impl Compiler {
38423868
emit!(self, Instruction::ReturnValue);
38433869

38443870
// Exit the annotation scope and get the code object
3845-
let annotate_code = self.exit_annotation_scope();
3871+
let annotate_code = self.exit_annotation_scope(saved_ctx);
38463872

38473873
// Make a closure from the code object
38483874
self.make_closure(annotate_code, bytecode::MakeFunctionFlags::empty())?;
@@ -3935,6 +3961,15 @@ impl Compiler {
39353961
return Ok(false);
39363962
}
39373963

3964+
// Annotation scopes are never async (even inside async functions)
3965+
let saved_ctx = self.ctx;
3966+
self.ctx = CompileContext {
3967+
loop_data: None,
3968+
in_class: saved_ctx.in_class,
3969+
func: FunctionContext::Function,
3970+
in_async_scope: false,
3971+
};
3972+
39383973
// Enter annotation scope for code generation
39393974
let key = self.symbol_table_stack.len() - 1;
39403975
let lineno = self.get_source_line_number().get();
@@ -4031,6 +4066,8 @@ impl Compiler {
40314066
.last_mut()
40324067
.expect("no module symbol table")
40334068
.annotation_block = Some(Box::new(annotation_table));
4069+
// Restore context
4070+
self.ctx = saved_ctx;
40344071
// Exit code scope
40354072
let pop = self.code_stack.pop();
40364073
let annotate_code = unwrap_internal(

crates/codegen/src/symboltable.rs

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,6 +1026,26 @@ impl SymbolTableBuilder {
10261026
.insert(SymbolFlags::REFERENCED | SymbolFlags::FREE_CLASS);
10271027
}
10281028

1029+
/// Walk up the scope chain to determine if we're inside an async function.
1030+
/// Annotation and TypeParams scopes act as async barriers (always non-async).
1031+
/// Comprehension scopes are transparent (inherit parent's async context).
1032+
fn is_in_async_context(&self) -> bool {
1033+
for table in self.tables.iter().rev() {
1034+
match table.typ {
1035+
CompilerScope::AsyncFunction => return true,
1036+
CompilerScope::Function
1037+
| CompilerScope::Lambda
1038+
| CompilerScope::Class
1039+
| CompilerScope::Module
1040+
| CompilerScope::Annotation
1041+
| CompilerScope::TypeParams => return false,
1042+
// Comprehension inherits parent's async context
1043+
CompilerScope::Comprehension => continue,
1044+
}
1045+
}
1046+
false
1047+
}
1048+
10291049
fn line_index_start(&self, range: TextRange) -> u32 {
10301050
self.source_file
10311051
.to_source_code()
@@ -1128,15 +1148,6 @@ impl SymbolTableBuilder {
11281148

11291149
self.leave_annotation_scope();
11301150

1131-
// Module scope: re-scan to register symbols (builtins like str, int)
1132-
// Class scope: do NOT re-scan to preserve class-local symbol resolution
1133-
if matches!(current_scope, Some(CompilerScope::Module)) {
1134-
let was_in_annotation = self.in_annotation;
1135-
self.in_annotation = true;
1136-
let _ = self.scan_expression(annotation, ExpressionContext::Load);
1137-
self.in_annotation = was_in_annotation;
1138-
}
1139-
11401151
result
11411152
}
11421153

@@ -1939,6 +1950,20 @@ impl SymbolTableBuilder {
19391950
range: TextRange,
19401951
is_generator: bool,
19411952
) -> SymbolTableResult {
1953+
// Check for async comprehension outside async function
1954+
// (list/set/dict comprehensions only, not generator expressions)
1955+
let has_async_gen = generators.iter().any(|g| g.is_async);
1956+
if has_async_gen && !is_generator && !self.is_in_async_context() {
1957+
return Err(SymbolTableError {
1958+
error: "asynchronous comprehension outside of an asynchronous function".to_owned(),
1959+
location: Some(
1960+
self.source_file
1961+
.to_source_code()
1962+
.source_location(range.start(), PositionEncoding::Utf8),
1963+
),
1964+
});
1965+
}
1966+
19421967
// Comprehensions are compiled as functions, so create a scope for them:
19431968
self.enter_scope(
19441969
scope_name,

crates/vm/src/builtins/type.rs

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -605,10 +605,10 @@ impl PyType {
605605
static_f: impl FnOnce(&'static str) -> R,
606606
heap_f: impl FnOnce(&'a HeapTypeExt) -> R,
607607
) -> R {
608-
if !self.slots.flags.has_feature(PyTypeFlags::HEAPTYPE) {
609-
static_f(self.slots.name)
608+
if let Some(ref ext) = self.heaptype_ext {
609+
heap_f(ext)
610610
} else {
611-
heap_f(self.heaptype_ext.as_ref().unwrap())
611+
static_f(self.slots.name)
612612
}
613613
}
614614

@@ -849,13 +849,7 @@ impl PyType {
849849

850850
#[pygetset(setter)]
851851
fn set___qualname__(&self, value: PySetterValue, vm: &VirtualMachine) -> PyResult<()> {
852-
// TODO: we should replace heaptype flag check to immutable flag check
853-
if !self.slots.flags.has_feature(PyTypeFlags::HEAPTYPE) {
854-
return Err(vm.new_type_error(format!(
855-
"cannot set '__qualname__' attribute of immutable type '{}'",
856-
self.name()
857-
)));
858-
};
852+
self.check_set_special_type_attr(identifier!(vm, __qualname__), vm)?;
859853
let value = value.ok_or_else(|| {
860854
vm.new_type_error(format!(
861855
"cannot delete '__qualname__' attribute of immutable type '{}'",
@@ -865,10 +859,12 @@ impl PyType {
865859

866860
let str_value = downcast_qualname(value, vm)?;
867861

868-
let heap_type = self
869-
.heaptype_ext
870-
.as_ref()
871-
.expect("HEAPTYPE should have heaptype_ext");
862+
let heap_type = self.heaptype_ext.as_ref().ok_or_else(|| {
863+
vm.new_type_error(format!(
864+
"cannot set '__qualname__' attribute of immutable type '{}'",
865+
self.name()
866+
))
867+
})?;
872868

873869
// Use std::mem::replace to swap the new value in and get the old value out,
874870
// then drop the old value after releasing the lock
@@ -1160,10 +1156,17 @@ impl PyType {
11601156
}
11611157
name.ensure_valid_utf8(vm)?;
11621158

1159+
let heap_type = self.heaptype_ext.as_ref().ok_or_else(|| {
1160+
vm.new_type_error(format!(
1161+
"cannot set '__name__' attribute of immutable type '{}'",
1162+
self.slot_name()
1163+
))
1164+
})?;
1165+
11631166
// Use std::mem::replace to swap the new value in and get the old value out,
1164-
// then drop the old value after releasing the lock (similar to CPython's Py_SETREF)
1167+
// then drop the old value after releasing the lock
11651168
let _old_name = {
1166-
let mut name_guard = self.heaptype_ext.as_ref().unwrap().name.write();
1169+
let mut name_guard = heap_type.name.write();
11671170
core::mem::replace(&mut *name_guard, name)
11681171
};
11691172
// old_name is dropped here, outside the lock scope
@@ -2129,9 +2132,10 @@ fn linearise_mro(mut bases: Vec<Vec<PyTypeRef>>) -> Result<Vec<PyTypeRef>, Strin
21292132
// We start at index 1 to skip direct bases.
21302133
// This will not catch duplicate bases, but such a thing is already tested for.
21312134
if later_mro[1..].iter().any(|cls| cls.is(base)) {
2132-
return Err(
2133-
"Unable to find mro order which keeps local precedence ordering".to_owned(),
2134-
);
2135+
return Err(format!(
2136+
"Cannot create a consistent method resolution order (MRO) for bases {}",
2137+
bases.iter().map(|x| x.first().unwrap()).format(", ")
2138+
));
21352139
}
21362140
}
21372141
}

0 commit comments

Comments
 (0)